OpenDAS / ColossalAI / Commits / 14846934

Commit 14846934, authored Jan 18, 2024 by ver217

    Merge branch 'main' into sync/npu

Parents: 9102d655, 5d9a0ae7

Changes: 152 files in the full commit. This page shows 20 changed files with 212 additions and 704 deletions (+212 −704).
Files changed on this page:

tests/test_booster/test_plugin/test_low_level_zero_plugin.py             +9   −3
tests/test_booster/test_plugin/test_torch_ddp_plugin.py                  +9   −3
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py                 +20  −3
tests/test_checkpoint_io/test_gemini_checkpoint_io.py                    +7   −7
tests/test_checkpoint_io/test_gemini_torch_compability.py                +1   −1
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py    +19  −13
tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py       +1   −1
tests/test_infer_ops/triton/kernel_utils.py                              +0   −27
tests/test_infer_ops/triton/test_bloom_context_attention.py              +0   −52
tests/test_infer_ops/triton/test_copy_kv_dest.py                         +0   −39
tests/test_infer_ops/triton/test_layernorm_triton.py                     +0   −43
tests/test_infer_ops/triton/test_llama_act_combine.py                    +0   −56
tests/test_infer_ops/triton/test_llama_context_attention.py              +0   −50
tests/test_infer_ops/triton/test_self_attention_nonfusion.py             +0   −143
tests/test_infer_ops/triton/test_softmax.py                              +0   −36
tests/test_infer_ops/triton/test_token_attn_fwd.py                       +0   −72
tests/test_infer_ops/triton/test_token_softmax.py                        +0   −48
tests/test_lazy/test_models.py                                           +8   −3
tests/test_pipeline/test_p2p_communication.py                            +49  −23
tests/test_pipeline/test_schedule/test_interleaved.py                    +89  −81
tests/test_booster/test_plugin/test_low_level_zero_plugin.py

@@ -10,8 +10,8 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin

 # from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo

 # These models are not compatible with AMP
 _AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"]

@@ -21,6 +21,7 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
 _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]


+@clear_cache_before_run()
 def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
     device = get_accelerator().get_current_device()
     try:

@@ -62,7 +63,12 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
     ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
     skipped_models = []

-    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
+    if IS_FAST_TEST:
+        registry = model_zoo.get_sub_registry(COMMON_MODELS)
+    else:
+        registry = model_zoo
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
         # FIXME(ver217): fix these models
         if name in ignore_models:
             skipped_models.append(name)
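Note: the same fast-test gate recurs in most of the test files below — when IS_FAST_TEST is set, only a common subset of the model zoo is exercised. A minimal sketch of the pattern, assuming a plain dict-backed registry and an environment toggle (the real model_zoo registry in tests/kit/model_zoo may differ):

import os

IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"   # assumed env toggle
COMMON_MODELS = "transformers_gpt"                        # assumed common-subset keyword

model_zoo = {
    "transformers_gpt_lm": ("model_fn", "data_gen_fn"),
    "timm_convit": ("model_fn", "data_gen_fn"),
}

def get_sub_registry(registry: dict, keyword: str) -> dict:
    """Return only the entries whose names contain the keyword."""
    return {k: v for k, v in registry.items() if keyword in k}

registry = get_sub_registry(model_zoo, COMMON_MODELS) if IS_FAST_TEST else model_zoo
for name, entry in registry.items():
    print(name, entry)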
tests/test_booster/test_plugin/test_torch_ddp_plugin.py

@@ -10,10 +10,11 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import TorchDDPPlugin
 from colossalai.interface import OptimizerWrapper
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


+@clear_cache_before_run()
 def run_fn(model_fn, data_gen_fn, output_transform_fn):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)

@@ -40,7 +41,12 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
 def check_torch_ddp_plugin():
-    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
+    if IS_FAST_TEST:
+        registry = model_zoo.get_sub_registry(COMMON_MODELS)
+    else:
+        registry = model_zoo
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
         if name == "dlrm_interactionarch":
             continue
         run_fn(model_fn, data_gen_fn, output_transform_fn)
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py

@@ -11,11 +11,12 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
     from colossalai.booster.plugin import TorchFSDPPlugin

 from colossalai.interface import OptimizerWrapper
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


 # test basic fsdp function
+@clear_cache_before_run()
 def run_fn(model_fn, data_gen_fn, output_transform_fn):
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)

@@ -40,9 +41,20 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):
     optimizer.clip_grad_by_norm(1.0)
     optimizer.step()

+    del model
+    del optimizer
+    del criterion
+    del booster
+    del plugin
+

 def check_torch_fsdp_plugin():
-    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
+    if IS_FAST_TEST:
+        registry = model_zoo.get_sub_registry(COMMON_MODELS)
+    else:
+        registry = model_zoo.get_sub_registry("transformers_gptj")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
         if any(
             element in name
             for element in [

@@ -54,6 +66,7 @@ def check_torch_fsdp_plugin():
             ]
         ):
             continue
+        print(name)
         run_fn(model_fn, data_gen_fn, output_transform_fn)
         torch.cuda.empty_cache()

@@ -68,3 +81,7 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_torch_fsdp_plugin():
     spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_torch_fsdp_plugin()
tests/test_checkpoint_io/test_gemini_checkpoint_io.py

@@ -7,6 +7,7 @@ from transformers import LlamaForCausalLM
 from utils import shared_tempdir

 import colossalai
+from colossalai.testing import skip_if_not_enough_gpus
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.lazy import LazyInitContext

@@ -68,7 +69,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @clear_cache_before_run()
 @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])

@@ -156,13 +157,12 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_gemini_ckpIO(world_size):
-    spawn(run_dist, world_size)
+def test_gemini_ckpIO():
+    spawn(run_dist, 4)


 @pytest.mark.largedist
-@pytest.mark.parametrize("world_size", [8])
+@skip_if_not_enough_gpus(min_gpus=8)
 @rerun_if_address_is_in_use()
-def test_gemini_ckpIO_3d(world_size):
-    spawn(run_dist, world_size)
\ No newline at end of file
+def test_gemini_ckpIO_3d():
+    spawn(run_dist, 8)
\ No newline at end of file
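Note: the rewritten test_gemini_ckpIO_3d relies on skip_if_not_enough_gpus from colossalai.testing. A hedged pytest equivalent, for illustration only (not the library's actual implementation):

import functools

import pytest
import torch


def skip_if_not_enough_gpus(min_gpus: int):
    """Skip the decorated test when fewer than min_gpus CUDA devices are visible."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if torch.cuda.device_count() < min_gpus:
                pytest.skip(f"requires at least {min_gpus} GPUs")
            return fn(*args, **kwargs)
        return wrapper
    return decorator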
tests/test_checkpoint_io/test_gemini_torch_compability.py

@@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo
 @clear_cache_before_run()
 @parameterize("shard", [False, True])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 def exam_torch_load_from_gemini(shard: bool, model_name: str):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py

@@ -38,11 +38,11 @@ else:
     ]


+@clear_cache_before_run()
 @parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
-@clear_cache_before_run()
 def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
         iter(model_zoo.get_sub_registry(model_name).values())

@@ -104,30 +104,32 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     # Check whether the loaded model & optimizer works smoothly.
     model.train()
     new_model.train()
+    data_for_shard = data_gen_fn()
+    data_for_origin = data_gen_fn()
     if booster.plugin.stage_manager is not None:
         booster.execute_pipeline(
-            _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False
+            _preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False
         )
         booster.execute_pipeline(
-            _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False
+            _preprocess_data(data_for_origin),
+            new_model,
+            _criterion,
+            new_optimizer,
+            return_loss=True,
+            return_outputs=False,
         )
     else:
-        old_model_loss = criterion(model(**_preprocess_data(data)))
+        old_model_loss = criterion(model(**_preprocess_data(data_for_shard)))
         optimizer.backward(old_model_loss)
-        new_model_loss = criterion(new_model(**_preprocess_data(data)))
+        new_model_loss = criterion(new_model(**_preprocess_data(data_for_origin)))
         new_optimizer.backward(new_model_loss)

     optimizer.step()
     new_optimizer.step()

     # Check updated weights.
-    stage_manager = booster.plugin.stage_manager
-    if stage_manager is None or stage_manager.is_first_stage():
-        assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
-        assert_close_loose(
-            model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3
-        )
+    for p1, p2 in zip(model.unwrap().parameters(), new_model.unwrap().parameters()):
+        assert_close_loose(p1, p2, atol=5e-3, rtol=5e-3)

     dist.barrier()
+    Randomizer.reset_index()

@@ -145,3 +147,7 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_hybrid_ckpIO(world_size):
     spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_hybrid_ckpIO(4)
tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py

@@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo
 @clear_cache_before_run()
-@parameterize("model_name", ["transformers_gpt"])
+@parameterize("model_name", ["transformers_llama_for_casual_lm"])
 @parameterize("plugin_type", ["ddp", "zero", "gemini"])
 def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
tests/test_infer_ops/triton/kernel_utils.py (deleted, 100644 → 0)

Entire file removed:

import math

import torch
from torch.nn import functional as F


def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    """
    adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
    """
    xq = xq.view(bs, seqlen, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)
    mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
    mask[mask == 0.0] = -100000000.0
    mask = mask.repeat(bs, num_head, 1, 1)
    keys = xk
    values = xv
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
    sm_scale = 1 / math.sqrt(head_dim)
    scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
    scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)

    output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
    return output
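Note: a hypothetical smoke test for the reference above. Shapes follow the (batch * seqlen, num_head, head_dim) layout the Triton tests use, and a CUDA device is required because torch_context_attention builds its causal mask on the GPU:

import torch

bs, seqlen, num_head, head_dim = 2, 16, 4, 8
shape = (bs * seqlen, num_head, head_dim)
q = torch.randn(shape, dtype=torch.float16, device="cuda")
k = torch.randn(shape, dtype=torch.float16, device="cuda")
v = torch.randn(shape, dtype=torch.float16, device="cuda")

out = torch_context_attention(q, k, v, bs, seqlen, num_head, head_dim)
assert out.shape == (bs * seqlen, num_head, head_dim)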
tests/test_infer_ops/triton/test_bloom_context_attention.py (deleted, 100644 → 0)

Entire file removed:

import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton import bloom_context_attn_fwd
    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_bloom_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
    bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(
        torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2
    ), "outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_bloom_context_attention()
tests/test_infer_ops/triton/test_copy_kv_dest.py (deleted, 100644 → 0)

Entire file removed:

import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_kv_cache_copy_op():
    B_NTX = 32 * 2048
    head_num = 8
    head_dim = 64

    cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
    dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)

    dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)

    copy_kv_cache_to_dest(cache, dest_index, dest_data)

    assert torch.allclose(
        cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3
    ), "copy_kv_cache_to_dest outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_kv_cache_copy_op()
tests/test_infer_ops/triton/test_layernorm_triton.py (deleted, 100644 → 0)

Entire file removed:

import pytest
import torch
from packaging import version

from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize

try:
    pass

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
@parameterize("M", [2, 4, 8, 16])
@parameterize("N", [64, 128])
def test_layer_norm(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.rand(w_shape, dtype=dtype, device="cuda")
    bias = torch.rand(w_shape, dtype=dtype, device="cuda")
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")

    y_triton = layer_norm(x, weight, bias, eps)
    y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)

    assert y_triton.shape == y_torch.shape
    assert y_triton.dtype == y_torch.dtype
    print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_layer_norm()
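Note: for reference, the quantity this kernel was checked against can be written out by hand — normalize over the last dimension, then apply the learned affine transform. A small self-check against torch's built-in:

import torch

def layer_norm_reference(x, weight, bias, eps):
    # Normalize each row over the last dimension, then scale and shift.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

x = torch.randn(4, 64)
w, b = torch.ones(64), torch.zeros(64)
ref = layer_norm_reference(x, w, b, 1e-5)
assert torch.allclose(ref, torch.nn.functional.layer_norm(x, (64,), w, b, 1e-5), atol=1e-6)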
tests/test_infer_ops/triton/test_llama_act_combine.py (deleted, 100644 → 0)

Entire file removed:

import pytest
import torch
from packaging import version
from torch import nn

from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine

try:
    import triton

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

BATCH_SIZE = 4
SEQ_LEN = 16
HIDDEN_SIZE = 32


def SwiGLU(x):
    """Gated linear unit activation function.

    Args:
        x : input array
        axis: the axis along which the split should be computed (default: -1)
    """
    size = x.shape[-1]
    assert size % 2 == 0, "axis size must be divisible by 2"
    x1, x2 = torch.split(x, size // 2, -1)
    return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype))


@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_llama_act_combine(dtype: str):
    x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda()
    x_gate_torch = nn.Parameter(x_gate.detach().clone())
    x_gate_kernel = nn.Parameter(x_gate.detach().clone())
    x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda()
    x_up_torch = nn.Parameter(x_up.detach().clone())
    x_up_kernel = nn.Parameter(x_up.detach().clone())

    torch_out = SwiGLU(x_gate_torch) * x_up_torch
    kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel)
    atol = 1e-5 if dtype == torch.float32 else 5e-2
    assert torch.allclose(torch_out, kernel_out, atol=atol)

    torch_out.mean().backward()
    kernel_out.mean().backward()
    assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad])
    assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol)
    assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol)


if __name__ == '__main__':
    test_llama_act_combine(torch.float16)
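Aside: up to where the float32 cast is applied, the SwiGLU reference above is x1 * silu(x2), so it can be cross-checked against torch's built-in SiLU:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.float32)
x1, x2 = torch.split(x, x.shape[-1] // 2, -1)
manual = x1 * (x2 * torch.sigmoid(x2))  # the gated half, written out
assert torch.allclose(manual, x1 * F.silu(x2), atol=1e-6)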
tests/test_infer_ops/triton/test_llama_context_attention.py (deleted, 100644 → 0)

Entire file removed:

import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton import llama_context_attn_fwd
    from tests.test_infer_ops.triton.kernel_utils import torch_context_attention

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_llama_context_attention():
    bs = 4
    head_num = 8
    seq_len = 1024
    head_dim = 64

    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

    max_input_len = seq_len
    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

    for i in range(bs):
        b_start[i] = i * seq_len
        b_len[i] = seq_len

    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
    llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(
        torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
    ), "outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_llama_context_attention()
tests/test_infer_ops/triton/test_self_attention_nonfusion.py (deleted, 100644 → 0)

Entire file removed:

import pytest
import torch
import torch.nn.functional as F
from packaging import version

try:
    import triton

    from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel
    from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_qkv_matmul():
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    scale = 1.2
    head_size = 32
    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model : d_model * 2]

    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    q_copy = q.clone()
    k_copy = k.clone()
    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    k = torch.transpose(k, 2, 3).contiguous()

    torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k)
    torch_ouput *= 1.2

    q, k = q_copy, k_copy
    batches, M, H, K = q.shape
    N = k.shape[1]
    score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)

    grid = lambda meta: (
        batches,
        H,
        triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )

    K = q.shape[3]
    qkv_gemm_4d_kernel[grid](
        q,
        k,
        score_output,
        M,
        N,
        K,
        q.stride(0), q.stride(2), q.stride(1), q.stride(3),
        k.stride(0), k.stride(2), k.stride(3), k.stride(1),
        score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
        scale=scale,
        # currently manually setting, later on we can use auto-tune config to match best setting
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_N=32,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
    )

    check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
    assert check is True, "the outputs of triton and torch are not matched"


def self_attention_compute_using_torch(qkv, input_mask, scale, head_size):
    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model : d_model * 2]
    v = qkv[:, :, d_model * 2 :]
    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    v = v.view(batches, -1, num_of_heads, head_size)

    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    v = torch.transpose(v, 1, 2).contiguous()

    k = torch.transpose(k, -1, -2).contiguous()

    score_output = torch.einsum("bnij,bnjk->bnik", q, k)
    score_output *= scale

    softmax_output = F.softmax(score_output, dim=-1)
    res = torch.einsum("bnij,bnjk->bnik", softmax_output, v)
    res = torch.transpose(res, 1, 2)
    res = res.contiguous()

    return res.view(batches, -1, d_model), score_output, softmax_output


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_self_atttention_test():
    qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16)
    data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
        qkv.clone(), input_mask=None, scale=1.2, head_size=32
    )

    data_output_triton = self_attention_compute_using_triton(
        qkv.clone(),
        alibi=None,
        head_size=32,
        scale=1.2,
        input_mask=None,
        layer_past=None,
        use_flash=False,
        triangular=True,
    )

    check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
    assert check is True, "the triton output is not matched with torch output"


if __name__ == "__main__":
    test_qkv_matmul()
    test_self_atttention_test()
tests/test_infer_ops/triton/test_softmax.py (deleted, 100644 → 0)

Entire file removed:

import pytest
import torch
from packaging import version
from torch import nn

try:
    from colossalai.kernel.triton.softmax import softmax

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_softmax_op():
    data_samples = [
        torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32),
        torch.randn((320, 320, 78), device="cuda", dtype=torch.float32),
        torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16),
    ]

    for data in data_samples:
        module = nn.Softmax(dim=-1)
        data_torch_out = module(data)
        data_triton_out = softmax(data)
        check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)
        assert check is True, "softmax outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_softmax_op()
tests/test_infer_ops/triton/test_token_attn_fwd.py (deleted, 100644 → 0)

Entire file removed:

import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

import importlib.util

HAS_LIGHTLLM_KERNEL = True

if importlib.util.find_spec("lightllm") is None:
    HAS_LIGHTLLM_KERNEL = False

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")


def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
    xq = xq.view(bs, 1, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)

    logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
    prob = torch.softmax(logics, dim=1)
    prob = prob.view(bs, seqlen, num_head, 1)

    return torch.sum(prob * xv, dim=1, keepdim=False)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_LIGHTLLM_KERNEL,
    reason="triton requires cuda version to be higher than 11.4 or not install lightllm",
)
def test():
    Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128
    dtype = torch.float16

    q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
    v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")

    max_kv_cache_len = seq_len
    kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
    kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")

    kv_cache_seq_len[:] = seq_len
    kv_cache_start_loc[0] = 0
    kv_cache_start_loc[1] = seq_len
    kv_cache_start_loc[2] = 2 * seq_len
    kv_cache_start_loc[3] = 3 * seq_len

    for i in range(Z):
        kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")

    token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)

    torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)

    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test()
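Note: the torch_att reference models the decode step — a single query token per sequence attends over seqlen cached key/value tokens. A quick shape sanity check (CPU is fine here, since torch_att itself is device-agnostic):

import torch

bs, seqlen, num_head, head_dim = 2, 5, 3, 4
q = torch.randn(bs, num_head, head_dim)            # one query token per sequence
k = torch.randn(bs * seqlen, num_head, head_dim)   # flattened KV cache
v = torch.randn(bs * seqlen, num_head, head_dim)
out = torch_att(q, k, v, bs, seqlen, num_head, head_dim)
assert out.shape == (bs, num_head, head_dim)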
tests/test_infer_ops/triton/test_token_softmax.py (deleted, 100644 → 0)

Entire file removed:

import pytest
import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_softmax():
    import torch

    batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128

    dtype = torch.float16

    Logics = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
    ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)

    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len

    token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len)

    torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len)
    o = ProbOut
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_softmax()
tests/test_lazy/test_models.py

 import pytest
 from lazy_init_utils import SUPPORT_LAZY, check_lazy_init

-from tests.kit.model_zoo import model_zoo
+from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo


 @pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
-@pytest.mark.parametrize("subset", ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"])
+@pytest.mark.parametrize(
+    "subset",
+    [COMMON_MODELS] if IS_FAST_TEST else ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"],
+)
 @pytest.mark.parametrize("default_device", ["cpu", "cuda"])
 def test_torchvision_models_lazy_init(subset, default_device):
-    sub_model_zoo = model_zoo.get_sub_registry(subset)
+    sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
     for name, entry in sub_model_zoo.items():
         # TODO(ver217): lazy init does not support weight norm, skip these models
         if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
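Note: colossalai.lazy.LazyInitContext is the API under test here; the underlying idea can be sketched with torch's built-in meta device (a simplified illustration requiring a recent PyTorch, not the library's implementation):

import torch
import torch.nn as nn

with torch.device("meta"):
    model = nn.Linear(1024, 1024)      # module built without allocating real storage
assert model.weight.is_meta

model = model.to_empty(device="cpu")   # allocate storage only when materializing
nn.init.xavier_uniform_(model.weight)  # then initialize the real parameters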
tests/test_pipeline/test_p2p_communication.py

@@ -5,43 +5,69 @@ import torch.distributed as dist
 import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn

+WORLD_SIZE = 2
+

 def check_p2p_communication():
-    pg_mesh = ProcessGroupMesh(2)
+    pg_mesh = ProcessGroupMesh(WORLD_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, 0)
     p2p = PipelineP2PCommunication(stage_manager)

     rank = dist.get_rank()

     tensor = torch.ones(1, device=get_accelerator().get_current_device())
+    data = ["tensor", tensor, [tensor], {"tensor": tensor}]

     if rank == 0:
-        p2p.send_forward(tensor)
-        p2p.send_forward([tensor])
-        p2p.send_forward({"tensor": tensor})
-    else:
-        obj = p2p.recv_forward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_forward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_forward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_forward_recv_backward(data[i], send_prior_fallback=False)
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 1:
+        for obj in data:
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            p2p.send_backward(data[-(i + 1)])
+            recv_obj = p2p.recv_forward()
+            assert recv_obj == data[i]

     if rank == 1:
-        p2p.send_backward(tensor)
-        p2p.send_backward([tensor])
-        p2p.send_backward({"tensor": tensor})
-    else:
-        obj = p2p.recv_backward()
-        assert torch.equal(obj, tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
-        obj = p2p.recv_backward()
-        assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor)
+        for obj in data:
+            p2p.send_backward(obj)
+        for i in range(len(data)):
+            recv_obj = p2p.send_backward_recv_forward(data[i], send_prior_fallback=True)
+            assert recv_obj == data[-(i + 1)]
+    elif rank == 0:
+        for obj in data:
+            recv_obj = p2p.recv_backward()
+            assert recv_obj == obj
+        for i in range(len(data)):
+            recv_obj = p2p.recv_backward()
+            p2p.send_forward(data[-(i + 1)])
+            assert recv_obj == data[i]
+
+    if rank == 0:
+        recv_obj = p2p.send_forward_recv_backward(
+            tensor,
+            send_metadata=False,
+            metadata_recv=create_send_metadata(tensor),
+        )
+        assert recv_obj == tensor
+    elif rank == 1:
+        recv_obj = p2p.recv_forward(metadata_recv=create_send_metadata(tensor))
+        assert recv_obj == tensor
+        p2p.send_backward(tensor, send_metadata=False)


 def run_dist(rank, world_size, port):

@@ -52,7 +78,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_pipeline_p2p():
-    spawn(run_dist, 2)
+    spawn(run_dist, WORLD_SIZE)


 if __name__ == "__main__":
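Note: for readers unfamiliar with the primitives underneath PipelineP2PCommunication, here is a minimal torch.distributed round trip on the gloo backend; the colossalai wrapper adds object serialization and metadata negotiation on top. The port number is an arbitrary assumption:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29531"  # assumed free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        dist.send(torch.ones(1), dst=1)   # "forward" direction
        grad = torch.zeros(1)
        dist.recv(grad, src=1)            # "backward" direction
        assert torch.equal(grad, torch.full((1,), 2.0))
    else:
        t = torch.zeros(1)
        dist.recv(t, src=0)
        dist.send(t * 2, dst=0)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)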
tests/test_pipeline/test_schedule/test_interleaved.py

@@ -4,6 +4,7 @@ from types import MethodType
 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn

 import colossalai

@@ -11,31 +12,21 @@ from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all

+NUM_LAYER = 8
+DIM = 4
+

 class MlpModel(nn.Module):
     def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(4, 8)
-        self.linear2 = nn.Linear(8, 8)
-        self.linear3 = nn.Linear(8, 8)
-        self.linear4 = nn.Linear(8, 8)
-        self.linear5 = nn.Linear(8, 8)
-        self.linear6 = nn.Linear(8, 8)
-        self.linear7 = nn.Linear(8, 8)
-        self.linear8 = nn.Linear(8, 4)
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = self.linear4(x)
-        x = self.linear5(x)
-        x = self.linear6(x)
-        x = self.linear7(x)
-        x = self.linear8(x)
+        for layer in self.layers:
+            x = layer(x)
         return x

@@ -44,70 +35,72 @@ def pp_linear_fwd(
     data: torch.Tensor = None,
     input_obj: torch.Tensor = None,
     stage_mgr: PipelineStageManager = None,
-    num_chunks: int = None,
     model_chunk_id: int = None,
 ):
-    if stage_mgr.is_first_stage() and model_chunk_id == 0:
-        return {"input_obj": forward(data)}
-    elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
-        return forward(input_obj)
-    else:
-        return {"input_obj": forward(input_obj)}
+    with stage_mgr.switch_model_chunk_id(model_chunk_id):
+        if stage_mgr.is_first_stage():
+            return {"input_obj": forward(data)}
+        elif stage_mgr.is_last_stage():
+            return forward(input_obj)
+        else:
+            return {"input_obj": forward(input_obj)}


-@parameterize("num_micro_batches", [4, 8, 12])
-def examine_pp(num_micro_batches):
+def run_pp(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
     """
     This test is to examine the correctness of interleaved 1F1B, compared with torch.
     Be aware it contains some hardcodes.
     """
-    world_size = torch.distributed.get_world_size()
-    local_rank = torch.distributed.get_rank()
-    seed_all(1453)
-
-    NUM_MICRO_BATCHS = num_micro_batches
-    BATCH_SIZE = num_micro_batches
-    NUM_CHUNKS = 2
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")

     # create model
+    seed_all(1453)
     torch_model = MlpModel().cuda()
     pp_model = copy.deepcopy(torch_model).cuda()

-    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
-    pg_mesh = ProcessGroupMesh(1, world_size, 1)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
-    schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)
+    pg_mesh = ProcessGroupMesh(world_size)
+    stage_manager = PipelineStageManager(
+        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
+    )
+    schedule = InterleavedSchedule(
+        stage_manager=stage_manager,
+        num_model_chunks=num_model_chunk,
+        num_microbatch=num_microbatch,
+    )

     sharded_model = torch.nn.ModuleList()
-    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
-        if idx % (world_size) == local_rank:
+    for idx, sub_model in enumerate(pp_model.layers):
+        if idx % world_size == rank:
             sub_model._forward = sub_model.forward
             sub_model.forward = MethodType(
-                partial(
-                    pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model)
-                ),
+                partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)),
                 sub_model._forward,
             )
             sharded_model.append(sub_model.cuda())
+    assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct"

     # create optimizer
-    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
-    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5)
+    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5))

-    # create
-    seed_all(1453)
-    if local_rank == 0:
-        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
-    else:
-        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
-    torch.distributed.all_reduce(input_list[0])
+    # create data
+    seed_all(115)
+    input_list = [torch.rand(batch_size, DIM).cuda()]
+    dist.all_reduce(input_list[0])

-    criterion = lambda x, y: torch.mean(x)
+    def criterion(x, *args, **kwargs):
+        return (x * x).mean()

     # forward and backward
     torch_output = torch_model(input_list[0])
-    torch_loss = criterion(torch_output, _)
+    torch_loss = criterion(torch_output)
     torch_loss.backward()

     pp_ret = schedule.forward_backward_step(

@@ -115,45 +108,60 @@ def examine_pp(num_micro_batches):
     )

     # check loss
-    if stage_manager.is_last_stage():
+    if stage_manager.is_last_stage(ignore_chunk=True):
         assert torch.allclose(torch_loss, pp_ret["loss"])

     # check gradients
-    torch_grad = []
-    for torch_p in torch_model.parameters():
-        torch_grad.append(torch_p.grad.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        if idx < 2:
-            assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
-        else:
-            assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)
+    for i in range(num_model_chunk):
+        idx = world_size * i + rank
+        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
     pp_optimizer.step()
+    pp_optimizer.zero_grad()

     # check updated param
-    torch_param = []
-    for torch_p in torch_model.parameters():
-        torch_param.append(torch_p.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        if idx < 2:
-            assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
-        else:
-            assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    examine_pp()
+    for i in range(num_model_chunk):
+        idx = world_size * i + rank
+        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+
+    # forward only
+    with torch.no_grad():
+        torch_output = torch_model(input_list[0])
+        torch_loss = criterion(torch_output)
+
+        pp_ret = schedule.forward_backward_step(
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+        )
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            assert torch.allclose(torch_loss, pp_ret["loss"])
+
+        for layer in sharded_model:
+            if layer.weight.grad is None:
+                assert layer.weight.grad is None and layer.bias.grad is None
+            else:
+                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))


 @pytest.mark.dist
+@pytest.mark.parametrize("num_microbatch", [4, 12])
+@pytest.mark.parametrize("batch_size", [12])
+@pytest.mark.parametrize("num_model_chunk", [2, 4])
 @rerun_if_address_is_in_use()
-def test_pp():
-    spawn(run_dist, 4)
+def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
+    assert NUM_LAYER % num_model_chunk == 0
+    spawn(
+        run_pp,
+        nprocs=NUM_LAYER // num_model_chunk,
+        num_microbatch=num_microbatch,
+        batch_size=batch_size,
+        num_model_chunk=num_model_chunk,
+    )


 if __name__ == "__main__":
-    test_pp()
+    test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
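Note: the layer-to-rank assignment in the rewritten test follows from the loop condition `idx % world_size == rank` — rank r owns every world_size-th layer, giving it num_model_chunk chunks. A quick illustration in plain Python:

NUM_LAYER = 8

def chunk_layout(world_size: int) -> dict:
    # Map each rank to the layer indices it owns under the interleaved scheme.
    return {rank: [idx for idx in range(NUM_LAYER) if idx % world_size == rank] for rank in range(world_size)}

# With num_model_chunk=2, spawn launches NUM_LAYER // 2 = 4 processes:
print(chunk_layout(4))  # {0: [0, 4], 1: [1, 5], 2: [2, 6], 3: [3, 7]}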