OpenDAS / ColossalAI · Commit d66e6988

Unverified commit d66e6988, authored Jan 18, 2024 by Frank Lee, committed by GitHub on Jan 18, 2024.

Merge pull request #5278 from ver217/sync/npu

[sync] sync npu branch with main

Parents: 9102d655, 14846934
Changes: 152 · Showing 20 changed files with 212 additions and 704 deletions (+212 −704)
tests/test_booster/test_plugin/test_low_level_zero_plugin.py  +9 −3
tests/test_booster/test_plugin/test_torch_ddp_plugin.py  +9 −3
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py  +20 −3
tests/test_checkpoint_io/test_gemini_checkpoint_io.py  +7 −7
tests/test_checkpoint_io/test_gemini_torch_compability.py  +1 −1
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py  +19 −13
tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py  +1 −1
tests/test_infer_ops/triton/kernel_utils.py  +0 −27
tests/test_infer_ops/triton/test_bloom_context_attention.py  +0 −52
tests/test_infer_ops/triton/test_copy_kv_dest.py  +0 −39
tests/test_infer_ops/triton/test_layernorm_triton.py  +0 −43
tests/test_infer_ops/triton/test_llama_act_combine.py  +0 −56
tests/test_infer_ops/triton/test_llama_context_attention.py  +0 −50
tests/test_infer_ops/triton/test_self_attention_nonfusion.py  +0 −143
tests/test_infer_ops/triton/test_softmax.py  +0 −36
tests/test_infer_ops/triton/test_token_attn_fwd.py  +0 −72
tests/test_infer_ops/triton/test_token_softmax.py  +0 −48
tests/test_lazy/test_models.py  +8 −3
tests/test_pipeline/test_p2p_communication.py  +49 −23
tests/test_pipeline/test_schedule/test_interleaved.py  +89 −81
tests/test_booster/test_plugin/test_low_level_zero_plugin.py

@@ -10,8 +10,8 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin

 # from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo

 # These models are not compatible with AMP
 _AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"]

@@ -21,6 +21,7 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
 _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]


+@clear_cache_before_run()
 def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
     device = get_accelerator().get_current_device()
     try:

@@ -62,7 +63,12 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
     ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
     skipped_models = []

-    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
+    if IS_FAST_TEST:
+        registry = model_zoo.get_sub_registry(COMMON_MODELS)
+    else:
+        registry = model_zoo
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
         # FIXME(ver217): fix these models
         if name in ignore_models:
             skipped_models.append(name)
...
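The same fast-test gate recurs across every booster-plugin test in this sync: when IS_FAST_TEST is set, only the COMMON_MODELS sub-registry is exercised; otherwise the full model zoo runs. Below is a minimal standalone sketch of the pattern, with a plain dict standing in for the model-zoo registry; the dict contents, the get_sub_registry helper, and the FAST_TEST variable are illustrative, not the real tests.kit.model_zoo API.

    # Sketch of the IS_FAST_TEST / COMMON_MODELS gating used by the tests above.
    # The registry below is a hypothetical stand-in for tests.kit.model_zoo.
    import os

    IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
    COMMON_MODELS = "transformers_llama"  # illustrative sub-registry keyword

    model_zoo = {
        "transformers_llama_for_casual_lm": "<llama entry>",
        "timm_convit": "<timm entry>",
    }

    def get_sub_registry(registry, keyword):
        # keep only entries whose name contains the keyword
        return {name: entry for name, entry in registry.items() if keyword in name}

    registry = get_sub_registry(model_zoo, COMMON_MODELS) if IS_FAST_TEST else model_zoo
    for name, entry in registry.items():
        print(name, entry)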
tests/test_booster/test_plugin/test_torch_ddp_plugin.py

@@ -10,10 +10,11 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import TorchDDPPlugin
 from colossalai.interface import OptimizerWrapper
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


+@clear_cache_before_run()
 def run_fn(model_fn, data_gen_fn, output_transform_fn):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)

@@ -40,7 +41,12 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
 def check_torch_ddp_plugin():
-    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
+    if IS_FAST_TEST:
+        registry = model_zoo.get_sub_registry(COMMON_MODELS)
+    else:
+        registry = model_zoo
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
         if name == "dlrm_interactionarch":
             continue
         run_fn(model_fn, data_gen_fn, output_transform_fn)
...
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py

@@ -11,11 +11,12 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
 from colossalai.booster.plugin import TorchFSDPPlugin
 from colossalai.interface import OptimizerWrapper
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


 # test basic fsdp function
+@clear_cache_before_run()
 def run_fn(model_fn, data_gen_fn, output_transform_fn):
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)

@@ -40,9 +41,20 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
     optimizer.clip_grad_by_norm(1.0)
     optimizer.step()

+    del model
+    del optimizer
+    del criterion
+    del booster
+    del plugin
+

 def check_torch_fsdp_plugin():
-    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
+    if IS_FAST_TEST:
+        registry = model_zoo.get_sub_registry(COMMON_MODELS)
+    else:
+        registry = model_zoo.get_sub_registry("transformers_gptj")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
         if any(element in name for element in [

@@ -54,6 +66,7 @@ def check_torch_fsdp_plugin():
         ]):
             continue
+        print(name)
         run_fn(model_fn, data_gen_fn, output_transform_fn)
         torch.cuda.empty_cache()

@@ -68,3 +81,7 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_torch_fsdp_plugin():
     spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_torch_fsdp_plugin()
tests/test_checkpoint_io/test_gemini_checkpoint_io.py

@@ -7,6 +7,7 @@ from transformers import LlamaForCausalLM
 from utils import shared_tempdir

 import colossalai
+from colossalai.testing import skip_if_not_enough_gpus
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.lazy import LazyInitContext

@@ -68,7 +69,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @clear_cache_before_run()
 @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])

@@ -156,13 +157,12 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_gemini_ckpIO(world_size):
-    spawn(run_dist, world_size)
+def test_gemini_ckpIO():
+    spawn(run_dist, 4)


-@pytest.mark.parametrize("world_size", [8])
+@pytest.mark.largedist
+@skip_if_not_enough_gpus(min_gpus=8)
 @rerun_if_address_is_in_use()
-def test_gemini_ckpIO_3d(world_size):
-    spawn(run_dist, world_size)
\ No newline at end of file
+def test_gemini_ckpIO_3d():
+    spawn(run_dist, 8)
\ No newline at end of file
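The 8-GPU 3D variant is now gated on available hardware instead of a pytest parameter. A hedged sketch of how such a guard can be written with plain pytest and torch.cuda.device_count() follows; the real skip_if_not_enough_gpus lives in colossalai.testing and may be implemented differently.

    # Sketch of a GPU-count guard; an assumption, not colossalai's implementation.
    import functools

    import pytest
    import torch

    def skip_if_not_enough_gpus(min_gpus: int):
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                if torch.cuda.device_count() < min_gpus:
                    pytest.skip(f"requires at least {min_gpus} GPUs")
                return fn(*args, **kwargs)
            return wrapper
        return decorator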
tests/test_checkpoint_io/test_gemini_torch_compability.py

@@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo
 @clear_cache_before_run()
 @parameterize("shard", [False, True])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 def exam_torch_load_from_gemini(shard: bool, model_name: str):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
...
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py

@@ -38,11 +38,11 @@ else:
     ]


+@clear_cache_before_run()
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
-@clear_cache_before_run()
 def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))

@@ -104,30 +104,32 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     # Check whether the loaded model & optimizer works smoothly.
     model.train()
     new_model.train()
+    data_for_shard = data_gen_fn()
+    data_for_origin = data_gen_fn()
     if booster.plugin.stage_manager is not None:
         booster.execute_pipeline(
-            _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False
+            _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False
         )
         booster.execute_pipeline(
-            _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False
+            _preprocess_data(data_for_origin),
+            new_model,
+            _criterion,
+            new_optimizer,
+            return_loss=True,
+            return_outputs=False,
         )
     else:
-        old_model_loss = criterion(model(**_preprocess_data(data)))
+        old_model_loss = criterion(model(**_preprocess_data(data_for_shard)))
         optimizer.backward(old_model_loss)
-        new_model_loss = criterion(new_model(**_preprocess_data(data)))
+        new_model_loss = criterion(new_model(**_preprocess_data(data_for_origin)))
         new_optimizer.backward(new_model_loss)

     optimizer.step()
     new_optimizer.step()

     # Check updated weights.
-    stage_manager = booster.plugin.stage_manager
-    if stage_manager is None or stage_manager.is_first_stage():
-        assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
-        assert_close_loose(
-            model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3
-        )
+    for p1, p2 in zip(model.unwrap().parameters(), new_model.unwrap().parameters()):
+        assert_close_loose(p1, p2, atol=5e-3, rtol=5e-3)

     dist.barrier()
     Randomizer.reset_index()

@@ -145,3 +147,7 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_hybrid_ckpIO(world_size):
     spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_hybrid_ckpIO(4)
tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py

@@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo
 @clear_cache_before_run()
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("plugin_type", ["ddp", "zero", "gemini"])
 def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
...
tests/test_infer_ops/triton/kernel_utils.py (deleted, 100644 → 0)

import math

import torch
from torch.nn import functional as F


def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    """
    adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
    """
    xq = xq.view(bs, seqlen, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)
    mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
    mask[mask == 0.0] = -100000000.0
    mask = mask.repeat(bs, num_head, 1, 1)
    keys = xk
    values = xv
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
    sm_scale = 1 / math.sqrt(head_dim)
    scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
    scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)
    output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
    return output
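This deleted helper is the eager-mode reference the Triton context-attention tests below compare against: it computes causal attention, softmax(Q·Kᵀ/√d + mask)·V, with an additive lower-triangular mask. A usage sketch with the flattened (bs * seq_len, num_head, head_dim) layout those tests use (assumes torch_context_attention as defined above):

    # Usage sketch for the reference helper above; shapes mirror the tests below.
    import torch

    bs, seq_len, num_head, head_dim = 4, 1024, 8, 64
    q = torch.randn((bs * seq_len, num_head, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = torch_context_attention(q, k, v, bs, seq_len, num_head, head_dim)
    assert out.shape == (bs * seq_len, num_head, head_dim)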
tests/test_infer_ops/triton/test_bloom_context_attention.py (deleted, 100644 → 0)

import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton import bloom_context_attn_fwd
    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_bloom_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
    bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(
        torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2
    ), "outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_bloom_context_attention()
tests/test_infer_ops/triton/test_copy_kv_dest.py (deleted, 100644 → 0)

import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_kv_cache_copy_op():
    B_NTX = 32 * 2048
    head_num = 8
    head_dim = 64

    cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
    dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)

    dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)

    copy_kv_cache_to_dest(cache, dest_index, dest_data)

    assert torch.allclose(
        cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3
    ), "copy_kv_cache_to_dest outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_kv_cache_copy_op()
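Judging from the assertion (the random cache and the all-ones destination end up equal after calling the kernel with an identity index), the op scatters the source rows into the destination buffer at dest_index. An eager-mode equivalent, inferred from this test rather than taken from the kernel source:

    # Inferred eager-mode equivalent of copy_kv_cache_to_dest (an assumption).
    import torch

    def copy_kv_cache_to_dest_reference(src: torch.Tensor, dest_index: torch.Tensor, dest: torch.Tensor) -> None:
        # row i of src lands at row dest_index[i] of dest
        dest[dest_index.long()] = src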
tests/test_infer_ops/triton/test_layernorm_triton.py (deleted, 100644 → 0)

import pytest
import torch
from packaging import version

from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize

try:
    pass

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.rand(w_shape, dtype=dtype, device="cuda")
    bias = torch.rand(w_shape, dtype=dtype, device="cuda")
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")

    y_triton = layer_norm(x, weight, bias, eps)
    y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)

    assert y_triton.shape == y_torch.shape
    assert y_triton.dtype == y_torch.dtype
    print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_layer_norm()
tests/test_infer_ops/triton/test_llama_act_combine.py (deleted, 100644 → 0)

import pytest
import torch
from packaging import version
from torch import nn

from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine

try:
    import triton

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

BATCH_SIZE = 4
SEQ_LEN = 16
HIDDEN_SIZE = 32


def SwiGLU(x):
    """Gated linear unit activation function.

    Args:
        x : input array
        axis: the axis along which the split should be computed (default: -1)
    """
    size = x.shape[-1]
    assert size % 2 == 0, "axis size must be divisible by 2"
    x1, x2 = torch.split(x, size // 2, -1)
    return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype))


@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_llama_act_combine(dtype: str):
    x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda()
    x_gate_torch = nn.Parameter(x_gate.detach().clone())
    x_gate_kernel = nn.Parameter(x_gate.detach().clone())
    x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda()
    x_up_torch = nn.Parameter(x_up.detach().clone())
    x_up_kernel = nn.Parameter(x_up.detach().clone())

    torch_out = SwiGLU(x_gate_torch) * x_up_torch
    kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel)
    atol = 1e-5 if dtype == torch.float32 else 5e-2
    assert torch.allclose(torch_out, kernel_out, atol=atol)

    torch_out.mean().backward()
    kernel_out.mean().backward()
    assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad])
    assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol)
    assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol)


if __name__ == '__main__':
    test_llama_act_combine(torch.float16)
tests/test_infer_ops/triton/test_llama_context_attention.py (deleted, 100644 → 0)

import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton import llama_context_attn_fwd
    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_llama_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(
        torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
    ), "outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_llama_context_attention()
tests/test_infer_ops/triton/test_self_attention_nonfusion.py (deleted, 100644 → 0)

import pytest
import torch
import torch.nn.functional as F
from packaging import version

try:
    import triton

    from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
    from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_qkv_matmul():
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    scale = 1.2
    head_size = 32
    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model : d_model * 2]

    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    q_copy = q.clone()
    k_copy = k.clone()
    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    k = torch.transpose(k, 2, 3).contiguous()

    torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k)
    torch_ouput *= 1.2

    q, k = q_copy, k_copy
    batches, M, H, K = q.shape
    N = k.shape[1]
    score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)

    grid = lambda meta: (
        batches,
        H,
        triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )

    K = q.shape[3]
    qkv_gemm_4d_kernel[grid](
        q,
        k,
        score_output,
        M,
        N,
        K,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        q.stride(3),
        k.stride(0),
        k.stride(2),
        k.stride(3),
        k.stride(1),
        score_output.stride(0),
        score_output.stride(1),
        score_output.stride(2),
        score_output.stride(3),
        scale=scale,
        # currently manually setting, later on we can use auto-tune config to match best setting
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_N=32,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
    )

    check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
    assert check is True, "the outputs of triton and torch are not matched"


def self_attention_compute_using_torch(qkv, input_mask, scale, head_size):
    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model : d_model * 2]
    v = qkv[:, :, d_model * 2 :]
    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    v = v.view(batches, -1, num_of_heads, head_size)

    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    v = torch.transpose(v, 1, 2).contiguous()

    k = torch.transpose(k, -1, -2).contiguous()

    score_output = torch.einsum("bnij,bnjk->bnik", q, k)
    score_output *= scale

    softmax_output = F.softmax(score_output, dim=-1)
    res = torch.einsum("bnij,bnjk->bnik", softmax_output, v)
    res = torch.transpose(res, 1, 2)
    res = res.contiguous()

    return res.view(batches, -1, d_model), score_output, softmax_output


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_self_atttention_test():
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
        qkv.clone(), input_mask=None, scale=1.2, head_size=32
    )

    data_output_triton = self_attention_compute_using_triton(
        qkv.clone(),
        alibi=None,
        head_size=32,
        scale=1.2,
        input_mask=None,
        layer_past=None,
        use_flash=False,
        triangular=True,
    )

    check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
    assert check is True, "the triton output is not matched with torch output"


if __name__ == "__main__":
    test_qkv_matmul()
    test_self_atttention_test()
tests/test_infer_ops/triton/test_softmax.py (deleted, 100644 → 0)

import pytest
import torch
from packaging import version
from torch import nn

try:
    from colossalai.kernel.triton.softmax import softmax

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_softmax_op():
    data_samples = [
        torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32),
        torch.randn((320, 320, 78), device="cuda", dtype=torch.float32),
        torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16),
    ]

    for data in data_samples:
        module = nn.Softmax(dim=-1)
        data_torch_out = module(data)
        data_triton_out = softmax(data)
        check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)
        assert check is True, "softmax outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_softmax_op()
tests/test_infer_ops/triton/test_token_attn_fwd.py (deleted, 100644 → 0)

import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

import importlib.util

HAS_LIGHTLLM_KERNEL = True

if importlib.util.find_spec("lightllm") is None:
    HAS_LIGHTLLM_KERNEL = False

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")


def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
    xq = xq.view(bs, 1, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)

    logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
    prob = torch.softmax(logics, dim=1)
    prob = prob.view(bs, seqlen, num_head, 1)

    return torch.sum(prob * xv, dim=1, keepdim=False)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_LIGHTLLM_KERNEL,
    reason="triton requires cuda version to be higher than 11.4 or not install lightllm",
)
def test():
    Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128
    dtype = torch.float16

    q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
    v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")

    max_kv_cache_len = seq_len
    kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
    kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")

    kv_cache_seq_len[:] = seq_len
    kv_cache_start_loc[0] = 0
    kv_cache_start_loc[1] = seq_len
    kv_cache_start_loc[2] = 2 * seq_len
    kv_cache_start_loc[3] = 3 * seq_len

    for i in range(Z):
        kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")

    token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)

    torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)

    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test()
tests/test_infer_ops/triton/test_token_softmax.py (deleted, 100644 → 0)

import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_softmax():
    import torch

    batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128
    dtype = torch.float16

    Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
    ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)

    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len

    token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len)

    torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len)
    o = ProbOut
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_softmax()
tests/test_lazy/test_models.py

 import pytest
 from lazy_init_utils import SUPPORT_LAZY, check_lazy_init

-from tests.kit.model_zoo import model_zoo
+from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


 @pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
-@pytest.mark.parametrize("subset", ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"])
+@pytest.mark.parametrize(
+    "subset",
+    [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"],
+)
 @pytest.mark.parametrize("default_device", ["cpu", "cuda"])
 def test_torchvision_models_lazy_init(subset, default_device):
-    sub_model_zoo = model_zoo.get_sub_registry(subset)
+    sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
     for name, entry in sub_model_zoo.items():
         # TODO(ver217): lazy init does not support weight norm, skip these models
         if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
...
tests/test_pipeline/test_p2p_communication.py

@@ -5,43 +5,69 @@ import torch.distributed as dist
 import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn

+WORLD_SIZE = 2
+

 def check_p2p_communication():
-    pg_mesh = ProcessGroupMesh(2)
+    pg_mesh = ProcessGroupMesh(WORLD_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, 0)
     p2p = PipelineP2PCommunication(stage_manager)

     rank = dist.get_rank()

     tensor = torch.ones(1, device=get_accelerator().get_current_device())
+    data = [
+        "tensor",
+        tensor,
+        [tensor],
+        {"tensor": tensor},
+    ]

     if rank == 0:
-        p2p.send_forward(tensor)
-        p2p.send_forward([tensor])
-        p2p.send_forward({"tensor": tensor})
-    else:
-        obj = p2p.recv_forward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_forward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 1:
+        for obj in data:
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            p2p.send_backward(data[-(i + 1)])
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == data[i]

     if rank == 1:
-        p2p.send_backward(tensor)
-        p2p.send_backward([tensor])
-        p2p.send_backward({"tensor": tensor})
-    else:
-        obj = p2p.recv_backward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_backward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 0:
+        for obj in data:
+            recv_obj = p2p.recv_backward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            recv_obj = p2p.recv_backward()
+            p2p.send_forward(data[-(i + 1)])
+            assert recv_obj == data[i]
+
+    if rank == 0:
+        recv_obj = p2p.send_forward_recv_backward(
+            tensor,
+            send_metadata=False,
+            metadata_recv=create_send_metadata(tensor),
+        )
+        assert recv_obj == tensor
+    elif rank == 1:
+        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
+        assert recv_obj == tensor
+        p2p.send_backward(tensor, send_metadata=False)


 def run_dist(rank, world_size, port):
...

@@ -52,7 +78,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_pipeline_p2p():
-    spawn(run_dist, 2)
+    spawn(run_dist, WORLD_SIZE)


 if __name__ == "__main__":
...
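The final block exercises the new metadata fast path: create_send_metadata(obj) captures the structure of the object being sent (its nesting plus tensor shapes and dtypes), so a receiver holding that metadata can pre-allocate buffers and both sides can skip the per-call metadata exchange (send_metadata=False / metadata_recv=...). A hedged illustration of the idea with made-up helper names; the real implementation lives in colossalai.pipeline.p2p and may differ:

    # Illustration only: hypothetical names sketching a metadata-cached p2p recv.
    from dataclasses import dataclass

    import torch

    @dataclass
    class TensorMetadata:
        shape: torch.Size
        dtype: torch.dtype

    def make_metadata(t: torch.Tensor) -> TensorMetadata:
        return TensorMetadata(t.shape, t.dtype)

    def recv_with_known_metadata(meta: TensorMetadata) -> torch.Tensor:
        # with known shape/dtype, the receiver can post the exact buffer
        # without first receiving a structure header
        buf = torch.empty(meta.shape, dtype=meta.dtype)
        # dist.recv(buf, src=...)  # actual transport elided in this sketch
        return buf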
tests/test_pipeline/test_schedule/test_interleaved.py

@@ -4,6 +4,7 @@ from types import MethodType
 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn

 import colossalai

@@ -11,31 +12,21 @@ from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all

+NUM_LAYER = 8
+DIM = 4
+

 class MlpModel(nn.Module):
     def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(4, 8)
-        self.linear2 = nn.Linear(8, 8)
-        self.linear3 = nn.Linear(8, 8)
-        self.linear4 = nn.Linear(8, 8)
-        self.linear5 = nn.Linear(8, 8)
-        self.linear6 = nn.Linear(8, 8)
-        self.linear7 = nn.Linear(8, 8)
-        self.linear8 = nn.Linear(8, 4)
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = self.linear4(x)
-        x = self.linear5(x)
-        x = self.linear6(x)
-        x = self.linear7(x)
-        x = self.linear8(x)
+        for layer in self.layers:
+            x = layer(x)
         return x

@@ -44,70 +35,72 @@ def pp_linear_fwd(
     data: torch.Tensor = None,
     input_obj: torch.Tensor = None,
     stage_mgr: PipelineStageManager = None,
-    num_chunks: int = None,
     model_chunk_id: int = None,
 ):
-    if stage_mgr.is_first_stage() and model_chunk_id == 0:
-        return {"input_obj": forward(data)}
-    elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
-        return forward(input_obj)
-    else:
-        return {"input_obj": forward(input_obj)}
+    with stage_mgr.switch_model_chunk_id(model_chunk_id):
+        if stage_mgr.is_first_stage():
+            return {"input_obj": forward(data)}
+        elif stage_mgr.is_last_stage():
+            return forward(input_obj)
+        else:
+            return {"input_obj": forward(input_obj)}


-@parameterize("num_micro_batches", [4, 8, 12])
-def examine_pp(num_micro_batches):
+def run_pp(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
     """
     This test is to examine the correctness of interleaved 1F1B, compared with torch.
     Be aware it contains some hardcodes.
     """
-    world_size = torch.distributed.get_world_size()
-    local_rank = torch.distributed.get_rank()
-    seed_all(1453)
-
-    NUM_MICRO_BATCHS = num_micro_batches
-    BATCH_SIZE = num_micro_batches
-    NUM_CHUNKS = 2
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")

     # create model
+    seed_all(1453)
     torch_model = MlpModel().cuda()

     pp_model = copy.deepcopy(torch_model).cuda()

-    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
-    pg_mesh = ProcessGroupMesh(1, world_size, 1)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
-    schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)
+    pg_mesh = ProcessGroupMesh(world_size)
+    stage_manager = PipelineStageManager(
+        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
+    )
+    schedule = InterleavedSchedule(
+        stage_manager=stage_manager,
+        num_model_chunks=num_model_chunk,
+        num_microbatch=num_microbatch,
+    )

     sharded_model = torch.nn.ModuleList()
-    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
-        if idx % (world_size) == local_rank:
+    for idx, sub_model in enumerate(pp_model.layers):
+        if idx % world_size == rank:
             sub_model._forward = sub_model.forward
             sub_model.forward = MethodType(
-                partial(pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model)),
+                partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)),
                 sub_model._forward,
             )
             sharded_model.append(sub_model.cuda())
+    assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct"

     # create optimizer
-    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
-    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5)
+    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5))

-    # create
-    seed_all(1453)
-    if local_rank == 0:
-        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
-    else:
-        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
-    torch.distributed.all_reduce(input_list[0])
+    # create data
+    seed_all(115)
+    input_list = [torch.rand(batch_size, DIM).cuda()]
+    dist.all_reduce(input_list[0])

-    criterion = lambda x, y: torch.mean(x)
+    def criterion(x, *args, **kwargs):
+        return (x * x).mean()

     # forward and backward
     torch_output = torch_model(input_list[0])
-    torch_loss = criterion(torch_output, _)
+    torch_loss = criterion(torch_output)
     torch_loss.backward()

     pp_ret = schedule.forward_backward_step(

@@ -115,45 +108,60 @@ def examine_pp(num_micro_batches):
     )

     # check loss
-    if stage_manager.is_last_stage():
+    if stage_manager.is_last_stage(ignore_chunk=True):
         assert torch.allclose(torch_loss, pp_ret["loss"])

     # check gradients
-    torch_grad = []
-    for torch_p in torch_model.parameters():
-        torch_grad.append(torch_p.grad.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        if idx < 2:
-            assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
-        else:
-            assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)
+    for i in range(num_model_chunk):
+        idx = world_size * i + rank
+        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
     pp_optimizer.step()
+    pp_optimizer.zero_grad()

     # check updated param
-    torch_param = []
-    for torch_p in torch_model.parameters():
-        torch_param.append(torch_p.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        if idx < 2:
-            assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
-        else:
-            assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
+    for i in range(num_model_chunk):
+        idx = world_size * i + rank
+        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+
+    # forward only
+    with torch.no_grad():
+        torch_output = torch_model(input_list[0])
+        torch_loss = criterion(torch_output)
+
+        pp_ret = schedule.forward_backward_step(
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+        )
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            assert torch.allclose(torch_loss, pp_ret["loss"])
+
+        for layer in sharded_model:
+            if layer.weight.grad is None:
+                assert layer.weight.grad is None and layer.bias.grad is None
+            else:
+                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))


-def run_dist(rank, world_size, port):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    examine_pp()


 @pytest.mark.dist
+@pytest.mark.parametrize("num_microbatch", [4, 12])
+@pytest.mark.parametrize("batch_size", [12])
+@pytest.mark.parametrize("num_model_chunk", [2, 4])
 @rerun_if_address_is_in_use()
-def test_pp():
-    spawn(run_dist, 4)
+def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
+    assert NUM_LAYER % num_model_chunk == 0
+    spawn(
+        run_pp,
+        nprocs=NUM_LAYER // num_model_chunk,
+        num_microbatch=num_microbatch,
+        batch_size=batch_size,
+        num_model_chunk=num_model_chunk,
+    )


 if __name__ == "__main__":
-    test_pp()
+    test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
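The rewritten test assigns the NUM_LAYER = 8 linear layers round-robin (idx % world_size == rank), so rank r holds layer world_size * i + r as its i-th model chunk; that is exactly the indexing the gradient and parameter checks use. A quick sanity sketch of the mapping:

    # Round-robin layer-to-chunk mapping used by the interleaved test above.
    NUM_LAYER = 8
    num_model_chunk = 4
    world_size = NUM_LAYER // num_model_chunk  # 2 ranks spawned by test_pp

    for rank in range(world_size):
        chunks = [world_size * i + rank for i in range(num_model_chunk)]
        print(f"rank {rank} holds layers {chunks}")
    # rank 0 holds layers [0, 2, 4, 6]
    # rank 1 holds layers [1, 3, 5, 7]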