OpenDAS / ColossalAI / Commits / efba0f44

Unverified commit efba0f44, authored Sep 05, 2023 by Hongxin Liu, committed by GitHub on Sep 05, 2023.

Merge pull request #4612 from hpcaitech/feature/shardformer

[shardformer] update hybrid parallel plugin and fix bugs

Parents: ac178ca5, fae6c92e

Showing 17 changed files with 1470 additions and 349 deletions (+1470, -349):
tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py   +44   -0
tests/test_pipeline/test_schedule/test_interleaved.py                    +161  -0
tests/test_pipeline/test_schedule/test_oneF_oneB.py                      +1    -1
tests/test_pipeline/test_stage_manager.py                                +0    -9
tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py       +25   -13
tests/test_shardformer/test_layer/test_linear_1d.py                      +51   -24
tests/test_shardformer/test_model/_utils.py                              +82   -6
tests/test_shardformer/test_model/test_shard_bert.py                     +114  -24
tests/test_shardformer/test_model/test_shard_bloom.py                    +114  -25
tests/test_shardformer/test_model/test_shard_chatglm2.py                 +106  -40
tests/test_shardformer/test_model/test_shard_gpt2.py                     +128  -31
tests/test_shardformer/test_model/test_shard_llama.py                    +116  -39
tests/test_shardformer/test_model/test_shard_opt.py                      +112  -39
tests/test_shardformer/test_model/test_shard_t5.py                       +105  -25
tests/test_shardformer/test_model/test_shard_vit.py                      +115  -42
tests/test_shardformer/test_model/test_shard_whisper.py                  +196  -30
tests/test_utils/test_activation_checkpointing.py                        +0    -1
tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py  (new file, mode 100644)

from colossalai.shardformer.policies.whisper import WhisperPolicy


def test_whisper_pipeline_distribution():
    num_test_cases = 8
    test_dict = {
        'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
        'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
        'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
        'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
    }

    for i in range(num_test_cases):
        _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i],
                                                                            test_dict['num_decoder_layers'][i],
                                                                            test_dict['num_stages'][i])
        assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage


def test_whisper_pipeline_layers():
    num_test_cases = 4
    test_dict = {
        'num_encoder_layers': [2, 3, 2, 4],
        'num_decoder_layers': [2, 0, 2, 8],
        'num_stages': [2, 2, 4, 4],
        'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
                             [[0, 4], [0, 3], [3, 6], [6, 8]]]
    }

    for i in range(num_test_cases):
        layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
            test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])

        for stage in range(test_dict['num_stages'][i]):
            start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
            predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage,
                                                                                   decoder_starting_stage)
            assert start_idx == predicted_start
            assert end_idx == predicted_end


if __name__ == '__main__':
    test_whisper_pipeline_distribution()
    test_whisper_pipeline_layers()
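The two tests above pin down how WhisperPolicy divides pipeline stages between Whisper's encoder and decoder before slicing layers per stage. As a rough intuition aid only (this sketch is hypothetical and not the actual WhisperPolicy code; the policy's own rounding rules are exactly what the expected values above encode), a proportional split of stages looks like this:

# Hypothetical sketch: give the encoder a share of stages proportional to its
# layer count. The real WhisperPolicy differs in how it rounds edge cases.
def naive_decoder_starting_stage(num_encoder_layers, num_decoder_layers, num_stages):
    if num_decoder_layers == 0:
        return num_stages    # no decoder: the encoder occupies every stage
    total = num_encoder_layers + num_decoder_layers
    return max(1, round(num_stages * num_encoder_layers / total))

print(naive_decoder_starting_stage(2, 2, 2))     # 1, matching the first case above
print(naive_decoder_starting_stage(3, 0, 2))     # 2, matching the third case above
print(naive_decoder_starting_stage(10, 6, 8))    # 5, matching the seventh case above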
tests/test_pipeline/test_schedule/test_interleaved.py  (new file, mode 100644)

import copy
from functools import partial
from types import MethodType

import pytest
import torch
import torch.nn as nn

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all


class MlpModel(nn.Module):

    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(4, 8)
        self.linear2 = nn.Linear(8, 8)
        self.linear3 = nn.Linear(8, 8)
        self.linear4 = nn.Linear(8, 8)
        self.linear5 = nn.Linear(8, 8)
        self.linear6 = nn.Linear(8, 8)
        self.linear7 = nn.Linear(8, 8)
        self.linear8 = nn.Linear(8, 4)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        x = self.linear5(x)
        x = self.linear6(x)
        x = self.linear7(x)
        x = self.linear8(x)
        return x


def pp_linear_fwd(forward,
                  data: torch.Tensor = None,
                  input_obj: torch.Tensor = None,
                  stage_mgr: PipelineStageManager = None,
                  num_chunks: int = None,
                  model_chunk_id: int = None):
    if stage_mgr.is_first_stage() and model_chunk_id == 0:
        return {'input_obj': forward(data)}
    elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
        return forward(input_obj)
    else:
        return {'input_obj': forward(input_obj)}


@parameterize("num_micro_batches", [4, 8, 12])
def examine_pp(num_micro_batches):
    """
    This test is to examine the correctness of interleaved 1F1B, compared with torch.
    Be aware it contains some hardcodes.
    """
    world_size = torch.distributed.get_world_size()
    local_rank = torch.distributed.get_rank()
    seed_all(1453)

    NUM_MICRO_BATCHS = num_micro_batches
    BATCH_SIZE = num_micro_batches
    NUM_CHUNKS = 2

    # create model
    torch_model = MlpModel().cuda()
    pp_model = copy.deepcopy(torch_model).cuda()

    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
    pg_mesh = ProcessGroupMesh(1, world_size, 1)
    stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
    schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)

    sharded_model = torch.nn.ModuleList()
    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
        if idx % (world_size) == local_rank:
            sub_model._forward = sub_model.forward
            sub_model.forward = MethodType(
                partial(pp_linear_fwd,
                        stage_mgr=stage_manager,
                        num_chunks=NUM_CHUNKS,
                        model_chunk_id=len(sharded_model)), sub_model._forward)
            sharded_model.append(sub_model.cuda())

    # create optimizer
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))

    # create
    seed_all(1453)
    if local_rank == 0:
        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
    else:
        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
    torch.distributed.all_reduce(input_list[0])

    criterion = lambda x, y: torch.mean(x)

    # forward and backward
    torch_output = torch_model(input_list[0])
    torch_loss = criterion(torch_output, _)
    torch_loss.backward()

    pp_ret = schedule.forward_backward_step(sharded_model,
                                            pp_optimizer,
                                            iter(input_list),
                                            criterion,
                                            return_loss=True,
                                            return_outputs=True)

    # check loss
    if stage_manager.is_last_stage():
        assert torch.allclose(torch_loss, pp_ret['loss'])

    # check gradients
    torch_grad = []
    for torch_p in torch_model.parameters():
        torch_grad.append(torch_p.grad.data)
    for idx, pp_p in enumerate(sharded_model.parameters()):
        if idx < 2:
            assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
        else:
            assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)

    # step
    torch_optimizer.step()
    pp_optimizer.step()

    # check updated param
    torch_param = []
    for torch_p in torch_model.parameters():
        torch_param.append(torch_p.data)
    for idx, pp_p in enumerate(sharded_model.parameters()):
        if idx < 2:
            assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
        else:
            assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    examine_pp()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp():
    spawn(run_dist, 4)


if __name__ == '__main__':
    test_pp()
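The gradient and parameter checks at the end of examine_pp hardcode the chunk-to-rank mapping. A standalone sketch (plain Python, no GPUs or process groups needed) of where those offsets come from:

# With world_size = 4 and NUM_CHUNKS = 2, rank r holds layer r (chunk 0) and
# layer r + world_size (chunk 1), because layers are dealt out by idx % world_size.
world_size, num_layers = 4, 8
for rank in range(world_size):
    owned = [idx for idx in range(num_layers) if idx % world_size == rank]
    print(f"rank {rank} owns layers {owned}")    # e.g. rank 1 owns [1, 5]

# Each nn.Linear contributes two parameters (weight, bias). A local parameter
# index idx < 2 belongs to the rank's first layer, so its global index is
# idx + rank * 2; idx >= 2 belongs to the second layer, giving idx + rank * 2 + 6.
rank = 1
for idx in range(4):
    global_idx = idx + rank * 2 if idx < 2 else idx + rank * 2 + 6
    print(f"local param {idx} -> global param {global_idx}")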
tests/test_pipeline/test_schedule/test_oneF_oneB.py

@@ -61,7 +61,7 @@ def examine_pp():
     DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
     pg_mesh = ProcessGroupMesh(1, world_size, 1)
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager)
+    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
 
     for idx, (_, sub_model) in enumerate(pp_model.named_children()):
         if idx % (world_size) == local_rank:
tests/test_pipeline/test_stage_manager.py

@@ -49,15 +49,6 @@ def check_stage_manager():
         next_rank = ranks_in_group[ranks_in_group.index(rank) + 1]
         assert stage_manager.get_next_rank() == next_rank
 
-    # check virtual stage
-    stage_manager.set_num_virtual_stages(PP_SIZE * 2)
-    assert stage_manager.num_virtual_stages == PP_SIZE * 2
-    stage_manager.set_virtual_stage(stage_manager.stage * 2)
-    assert stage_manager.virtual_stage == stage_manager.stage * 2
-    with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1):
-        assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1
-    assert stage_manager.virtual_stage == stage_manager.stage * 2
-
     # check p2p groups
     for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
        if rank in [prev, cur]:
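The neighbor bookkeeping this test keeps exercising is plain list indexing over the ordered ranks of one pipeline group. A standalone sketch, using a hypothetical group layout:

# ranks_in_group is the ordered list of global ranks in one pipeline group;
# the next/previous stage is simply the adjacent entry (values here are made up).
ranks_in_group = [0, 2, 4, 6]
rank = 2
next_rank = ranks_in_group[ranks_in_group.index(rank) + 1]
prev_rank = ranks_in_group[ranks_in_group.index(rank) - 1]
assert (prev_rank, next_rank) == (0, 4)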
tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py

@@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor
 
 
-@parameterize('lazy_init', [False, True])
-def check_linear_conv_1d_col(lazy_init: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
     with ctx:
@@ -62,7 +61,9 @@ def check_linear_conv_1d_col(lazy_init: bool):
-    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None,
-                                                                   gather_output=True, n_fused=3)
+    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
+                                                                   process_group=None,
+                                                                   gather_output=True,
+                                                                   seq_parallel=seq_parallel,
+                                                                   n_fused=3,
+                                                                   overlap=overlap)
 
     assert linear.weight.shape == torch.Size([48, 192])
     assert linear.bias.shape == torch.Size([192])
@@ -76,10 +77,11 @@ def check_linear_conv_1d_col(lazy_init: bool):
     linear.load_state_dict(linear_conv_col.state_dict())
 
     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(1, 4, 48).cuda()
     out = linear(x)
-    gather_out = linear_conv_col(x)
-    assert_close(rearrange(out, 1), gather_out)
+    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    gather_out = linear_conv_col(x_for_shard)
+    assert_close(rearrange(out, -1), gather_out)
 
     # check backward correctness
     out.sum().backward()
@@ -89,14 +91,16 @@ def check_linear_conv_1d_col(lazy_init: bool):
     assert_close(target_grad, linear_conv_col.weight.grad)
 
 
-@parameterize('lazy_init', [False, True])
-def check_linear_conv_1d_row(lazy_init: bool):
+def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
     with ctx:
         linear_copy = Conv1D(192, 48).cuda()
-    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy,
+                                                              process_group=None,
+                                                              parallel_input=False,
+                                                              seq_parallel=seq_parallel)
 
     assert linear.weight.shape == torch.Size([48, 192])
     assert linear_row.weight.shape == torch.Size([24, 192])
@@ -109,10 +113,11 @@ def check_linear_conv_1d_row(lazy_init: bool):
     linear.load_state_dict(linear_row.state_dict())
 
     # check computation correctness
-    x = torch.rand(4, 48).cuda()
+    x = torch.rand(1, 4, 48).cuda()
     out = linear(x)
     gather_out = linear_row(x)
-    assert_close(out, gather_out)
+    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_out, gather_out)
 
     # check backward correctness
     out.sum().backward()
@@ -123,12 +128,19 @@ def check_linear_conv_1d_row(lazy_init: bool):
     assert_close(target_grad, linear_row.weight.grad)
 
 
+@parameterize('lazy_init', [False, True])
+@parameterize('seq_parallel', [False, True])
+@parameterize('overlap', [True])
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
+    check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
+    check_linear_conv_1d_row(lazy_init, seq_parallel)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     # test for linear conv
-    check_linear_conv_1d_col()
-    check_linear_conv_1d_row()
+    check_gpt2_qkv_fused_linear_1d()
 
 
 @rerun_if_address_is_in_use()
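Both updated checks rely on the same sequence-parallel convention: each rank keeps one chunk of the input's sequence dimension and sees only its chunk of any sequence-shaped output. A standalone CPU sketch of that arithmetic, independent of the distributed layers:

import torch

x = torch.rand(1, 4, 48)                    # [batch_size, seq_len, hidden_size]
chunks = torch.chunk(x.clone(), 2, dim=1)   # what rank 0 and rank 1 would hold
assert chunks[0].shape == (1, 2, 48)
assert torch.equal(torch.cat(chunks, dim=1), x)   # gathering chunks restores x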
tests/test_shardformer/test_layer/test_linear_1d.py

@@ -12,13 +12,16 @@ from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
-@parameterize('lazy_init', [False, True])
-def check_linear_1d_col(lazy_init: bool):
+def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-    linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
+    linear_col = Linear1D_Col.from_native_module(linear_copy,
+                                                 process_group=None,
+                                                 gather_output=True,
+                                                 seq_parallel=seq_parallel,
+                                                 overlap=overlap)
 
     # ensure that the parameters are distributed
     assert is_distributed_tensor(linear_col.weight)
@@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool):
     linear_col.load_state_dict(linear.state_dict())
 
     # check computation correctness
-    x = torch.rand(4, 32).cuda()
+    # [batch_size, seq_len, hidden_size]
+    x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone())
+    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
     x_for_shard.requires_grad_(True)
 
     out = linear(x_for_unshard)
@@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool):
     # check the input gradients
     assert x_for_shard.grad is not None
     assert x_for_unshard.grad is not None
-    assert_close(x_for_unshard.grad, x_for_shard.grad)
+    target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
+        x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_unshard_gard, x_for_shard.grad)
 
 
-@parameterize('lazy_init', [False, True])
-def check_linear_1d_row(lazy_init: bool):
+def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
-    linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
+    linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel)
 
     assert linear_row.weight.shape == torch.Size([128, 16])
     assert linear_row.bias.shape == torch.Size([128])
@@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool):
     linear_row.load_state_dict(linear.state_dict())
 
     # check computation correctness
-    x = torch.rand(4, 32).cuda()
+    # [batch_size, seq_len, hidden_size]
+    x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
     x_for_shard = x.expand_as(x.clone())
@@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool):
     # run forward
     out = linear(x_for_unshard)
     gather_out = linear_row(x_for_shard)
-    assert_close(out, gather_out)
+    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_out, gather_out)
 
     # check backward correctness
     out.sum().backward()
@@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool):
     assert_close(x_for_unshard.grad, x_for_shard.grad)
 
 
-@parameterize('lazy_init', [False, True])
-def check_linear_col_plus_row(lazy_init: bool):
+def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear_1 = nn.Linear(32, 128).cuda()
@@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool):
     with ctx:
         linear_1_copy = nn.Linear(32, 128).cuda()
         linear_2_copy = nn.Linear(128, 32).cuda()
-    linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
-    linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
+    linear_col = Linear1D_Col.from_native_module(linear_1_copy,
+                                                 process_group=None,
+                                                 gather_output=False,
+                                                 seq_parallel=seq_parallel,
+                                                 overlap=overlap)
+    linear_row = Linear1D_Row.from_native_module(linear_2_copy,
+                                                 process_group=None,
+                                                 parallel_input=True,
+                                                 seq_parallel=seq_parallel)
 
     linear_1.load_state_dict(linear_col.state_dict())
     linear_col.load_state_dict(linear_1.state_dict())
@@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool):
     linear_row.load_state_dict(linear_2.state_dict())
 
     # check computation correctness
-    x = torch.rand(4, 32).cuda()
+    # [batch_size, seq_len, hidden_size]
+    x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone())
+    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
     x_for_shard.requires_grad_(True)
 
     # run forward
     unshard_out = linear_2(linear_1(x_for_unshard))
     shard_out = linear_row(linear_col(x_for_shard))
-    assert_close(unshard_out, shard_out)
+    target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_out, shard_out)
 
     # check backward correctness
     unshard_out.sum().backward()
@@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool):
     # check the input gradients
     assert x_for_shard.grad is not None
     assert x_for_unshard.grad is not None
-    assert_close(x_for_unshard.grad, x_for_shard.grad)
+    target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
+        x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
+    assert_close(target_unshard_gard, x_for_shard.grad)
 
 
-def run_dist(rank, world_size, port):
+@parameterize('lazy_init', [False, True])
+@parameterize('seq_parallel', [False, True])
+@parameterize('overlap', [True])
+def run_dist_linear_test(lazy_init, seq_parallel, overlap):
+    check_linear_1d_col(lazy_init, seq_parallel, overlap)
+    check_linear_1d_row(lazy_init, seq_parallel)
+    check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
+
+
+def check_dist_linear(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    check_linear_1d_col()
-    check_linear_1d_row()
-    check_linear_col_plus_row()
+    run_dist_linear_test()
 
 
 @rerun_if_address_is_in_use()
 def test_linear():
-    spawn(run_dist, nprocs=2)
+    spawn(check_dist_linear, nprocs=2)
 
 
 if __name__ == '__main__':
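The col-plus-row test works because splitting the first weight by columns and the second by rows makes the partial products sum to the full product. A CPU-only sketch of that identity (biases omitted; the real layers also handle them):

import torch

torch.manual_seed(0)
x = torch.rand(2, 4, 32)
w1 = torch.rand(32, 128)    # split by columns across 2 "ranks" (Linear1D_Col)
w2 = torch.rand(128, 32)    # split by rows across 2 "ranks" (Linear1D_Row)

full = x @ w1 @ w2
partial = sum((x @ w1_r) @ w2_r
              for w1_r, w2_r in zip(w1.chunk(2, dim=1), w2.chunk(2, dim=0)))
assert torch.allclose(full, partial, atol=1e-5)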
tests/test_shardformer/test_model/_utils.py

 import copy
+import math
 from contextlib import nullcontext
 from typing import Any, Callable, Dict, List, Optional
@@ -12,6 +13,7 @@ from torch.optim import Adam, Optimizer
 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -25,6 +27,7 @@ def build_model(model_fn,
                 enable_tensor_parallelism=True,
                 enable_flash_attention=False,
                 enable_jit_fused=False,
+                enable_sequence_parallelism=False,
                 use_lazy_init: bool = False):
     # create new model
     ctx = LazyInitContext() if use_lazy_init else nullcontext()
@@ -38,7 +41,8 @@ def build_model(model_fn,
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                enable_tensor_parallelism=enable_tensor_parallelism,
                                enable_flash_attention=enable_flash_attention,
-                               enable_jit_fused=enable_jit_fused)
+                               enable_jit_fused=enable_jit_fused,
+                               enable_sequence_parallelism=enable_sequence_parallelism)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model, shared_params = shard_former.optimize(model_copy)
@@ -135,6 +139,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
         return loss
 
     data = data_gen_fn()
+
+    if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
+        seq_len = data['input_ids'].shape[1]
+        lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
+        times = lcm // seq_len
+        input_shape = data['input_ids'].shape
+        for k, v in data.items():
+            if v.shape == input_shape:
+                data[k] = v.repeat(1, times)
+
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
         for k, v in data.items():
@@ -177,11 +191,10 @@ def check_output_hidden_state(org_output: Tensor,
     org_hidden_state = org_output.last_hidden_state
 
-    if stage_manager is None:
-        sharded_hidden_state = sharded_output.last_hidden_state
-
     if stage_manager and stage_manager.is_last_stage():
-        sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
+        sharded_hidden_state = sharded_output['outputs']['last_hidden_state']
+    else:
+        sharded_hidden_state = sharded_output.last_hidden_state
 
     assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
         f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
@@ -219,6 +232,43 @@ def check_weight(org_model: Module,
         f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
 
 
+def get_grad_tensors_for_check(org_model: Module,
+                               sharded_model: Module,
+                               layer_suffix: List[str],
+                               tp_group: ProcessGroup = None,
+                               dim: int = 0,
+                               atol: float = 1e-5,
+                               rtol: float = 1e-3,
+                               verbose: bool = False,
+                               name: str = None):
+
+    grad_to_check = {}
+    for suffix in layer_suffix:
+        org_grad = getattr_(org_model, suffix).weight.grad
+        shard_grad = getattr_(sharded_model, suffix).weight.grad
+        shard_weight = getattr_(sharded_model, suffix).weight
+        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+            shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
+            dist.all_gather(shard_grad_list, shard_grad, tp_group)
+            shard_grad = torch.cat(shard_grad_list, dim=dim)
+
+        # embedding may be resized when using tensor parallel
+        if shard_grad.shape[0] > org_grad.shape[0]:
+            shard_grad = shard_grad[:org_grad.shape[0], :]
+        if verbose and dist.get_rank() == 0:
+            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
+
+        grad_to_check[suffix] = {
+            "org_grad": org_grad.float(),
+            "shard_grad": shard_grad.float(),
+            "rtol": rtol,
+            "atol": atol
+        }
+
+    return grad_to_check
+
+
+# used by sam/blip2
 def check_grad(org_model: Module,
                sharded_model: Module,
                layer_suffix: List[str],
@@ -231,7 +281,6 @@ def check_grad(org_model: Module,
         org_grad = getattr_(org_model, suffix).weight.grad
         shard_grad = getattr_(sharded_model, suffix).weight.grad
         shard_weight = getattr_(sharded_model, suffix).weight
-
         if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
             shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
             dist.all_gather(shard_grad_list, shard_grad, tp_group)
@@ -246,3 +295,30 @@ def check_grad(org_model: Module,
         assert torch.allclose(org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol), \
             f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
 
 
+def unwrap_model(module: Module,
+                 base_model_class_name: Optional[str] = None,
+                 base_model_attribute_name: Optional[str] = None):
+    if isinstance(module, HybridParallelModule):
+        module = module.unwrap()
+    if base_model_class_name is None:
+        return module
+    if module.__class__.__name__ == base_model_class_name:
+        return module
+    return getattr(module, base_model_attribute_name, None)
+
+
+def check_all_grad_tensors(check_tensors):
+    """
+    "org_grad": tensor to be compared from the original model
+    "shard_grad": tensor to be compared from the sharded model
+    """
+    for suffix, check_info in check_tensors.items():
+        org_grad = check_info["org_grad"]
+        shard_grad = check_info["shard_grad"]
+        rtol = check_info["rtol"]
+        atol = check_info["atol"]
+        assert torch.allclose(org_grad, shard_grad, atol=atol, rtol=rtol), \
+            f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
tests/test_shardformer/test_model/test_shard_bert.py

@@ -10,11 +10,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
 )
@@ -32,42 +34,58 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                                                                       output_transform_fn, criterion, booster)
 
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-        if org_model.__class__.__name__ == 'BertModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
-    if org_model.__class__.__name__ == 'BertModel':
-        bert = org_model
-        sharded_bert = sharded_model.unwrap()
-    else:
-        bert = org_model.bert
-        sharded_bert = sharded_model.unwrap().bert
+    bert = unwrap_model(org_model, 'BertModel', 'bert')
+    sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
 
     col_layer_for_check = ['encoder.layer[0].output.dense']
     row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
 
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
     if test_config['precision'] == 'fp32':
         atol, rtol = 1e-4, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
-    if stage_manager is None or stage_manager.is_first_stage():
-        #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
-        #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
-        check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
-        check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
-
-    # check weights after optimizer.step()
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+        col_layer_grads = get_grad_tensors_for_check(bert, sharded_bert, col_layer_for_check, tp_group,
+                                                     atol=atol, rtol=rtol, dim=1, verbose=False)
+        row_layer_grads = get_grad_tensors_for_check(bert, sharded_bert, row_layer_for_check, tp_group,
+                                                     atol=atol, rtol=rtol, dim=0, verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
 
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        if org_model.__class__.__name__ == 'BertModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if test_config['precision'] == 'fp32':
         atol, rtol = 5e-3, 1e-3
     else:
@@ -75,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if stage_manager is None or stage_manager.is_first_stage():
         check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
@@ -98,6 +119,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32',
 },
+{
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
+}])
 def run_bert_test(test_config):
@@ -111,12 +155,50 @@ def run_bert_test(test_config):
     torch.cuda.empty_cache()
 
 
+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+}, {
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp16',
+    'zero_stage': 1,
+    'initial_scale': 1,
+}])
+def run_bert_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
 def check_bert(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_bert_test()
 
 
+def check_bert_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_bert_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -124,5 +206,13 @@ def test_bert():
     spawn(check_bert, 4)
 
 
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_bert_3d():
+    spawn(check_bert_3d, 8)
+
+
 if __name__ == "__main__":
     test_bert()
+    test_bert_3d()
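unwrap_model (added to _utils.py above) replaces the per-test if/else unwrapping. A tiny standalone sketch of the resolution logic it applies for the BERT case, using stand-in classes rather than the real transformers ones:

class BertModel:                      # stand-in for transformers.BertModel
    pass

class BertForSequenceClassification:  # stand-in for a task head wrapping .bert
    def __init__(self):
        self.bert = BertModel()

def resolve(module, base_class_name, attr_name):
    # mirrors unwrap_model's class-name check and attribute fallback
    if module.__class__.__name__ == base_class_name:
        return module
    return getattr(module, attr_name, None)

assert isinstance(resolve(BertModel(), 'BertModel', 'bert'), BertModel)
assert isinstance(resolve(BertForSequenceClassification(), 'BertModel', 'bert'), BertModel)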
tests/test_shardformer/test_model/test_shard_bloom.py

@@ -3,16 +3,19 @@ import torch
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
 )
@@ -34,6 +37,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
+    # unwrap model
+    bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
+    sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
+
+    row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
+    col_layer_for_check = ['h[0].self_attention.dense']
+
+    # Save gradient tensors for comparison between the original model and the sharded model.
+    grads_to_check = {}
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-6, 1e-5
+        else:
+            atol, rtol = 5e-3, 5e-3
+        row_layer_grads = get_grad_tensors_for_check(bloom, sharded_bloom, row_layer_for_check, tp_group,
+                                                     atol=atol, rtol=rtol, dim=0, verbose=False)
+        col_layer_grads = get_grad_tensors_for_check(bloom, sharded_bloom, col_layer_for_check, tp_group,
+                                                     atol=atol, rtol=rtol, dim=1, verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
+    org_optimizer.step()
+    sharded_optimizer.step()
+
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
         if test_config['precision'] == 'fp32':
@@ -45,28 +85,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
 
-    # unwrap model
-    if org_model.__class__.__name__ == 'BloomModel':
-        bloom = org_model
-        sharded_bloom = sharded_model.unwrap()
-    else:
-        bloom = org_model.transformer
-        sharded_bloom = sharded_model.unwrap().transformer
-
-    # check grad
-    row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
-    col_layer_for_check = ['h[0].self_attention.dense']
-    if stage_manager is None or stage_manager.is_first_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-6, 1e-5
-        else:
-            atol, rtol = 5e-3, 5e-3
-        check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
-        check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
-
-    # check weights after optimizer.step()
-    org_optimizer.step()
-    sharded_optimizer.step()
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-4, 1e-3
@@ -74,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 5e-3, 5e-3
         check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
@@ -97,18 +118,72 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': False,
-    'precision': 'fp32',
+    'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def run_bloom_test(test_config):
 
     # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()
 
 
+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+}, {
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp16',
+    'zero_stage': 1,
+    'initial_scale': 1,
+}])
+def run_bloom_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
@@ -118,6 +193,12 @@ def check_bloom(rank, world_size, port):
     run_bloom_test()
 
 
+def check_bloom_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_bloom_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -125,5 +206,13 @@ def test_bloom():
     spawn(check_bloom, 4)
 
 
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_bloom_3d():
+    spawn(check_bloom_3d, 8)
+
+
 if __name__ == "__main__":
     test_bloom()
+    test_bloom_3d()
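These tests keep switching tolerances with precision (here 1e-6/1e-5 for fp32 gradients versus 5e-3/5e-3 for fp16). torch.allclose accepts |a - b| <= atol + rtol * |b| elementwise, so the fp16 budget is far looser:

import torch

a = torch.tensor([1.0000, 2.0000])
b = torch.tensor([1.0004, 2.0004])
print(torch.allclose(a, b, atol=1e-6, rtol=1e-5))   # False under the fp32 budget
print(torch.allclose(a, b, atol=5e-3, rtol=5e-3))   # True under the fp16 budget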
tests/test_shardformer/test_model/test_shard_chatglm.py → tests/test_shardformer/test_model/test_shard_chatglm2.py  (renamed)

@@ -4,16 +4,19 @@ from torch import distributed as dist
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
 )
@@ -35,35 +38,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-        if org_model.__class__.__name__ == 'ChatGLMModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
-    if org_model.__class__.__name__ == 'ChatGLMModel':
-        chatglm_model = org_model
-        shard_chatglm_model = sharded_model.unwrap()
-    else:
-        chatglm_model = org_model.transformer
-        shard_chatglm_model = sharded_model.unwrap().transformer
+    chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
+    shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
 
     # check grad
     row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
     col_layer_for_check = ['encoder.layers[0].self_attention.dense']
 
-    if stage_manager is None or stage_manager.is_first_stage():
+    # Save gradient tensors for comparison between the original model and the sharded model.
+    grads_to_check = {}
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_grad(chatglm_model,
+        row_layer_grads = get_grad_tensors_for_check(chatglm_model,
                                                      shard_chatglm_model,
                                                      row_layer_for_check,
                                                      tp_group,
@@ -72,7 +61,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                                                      dim=0,
                                                      verbose=False)
-        check_grad(chatglm_model,
+        col_layer_grads = get_grad_tensors_for_check(chatglm_model,
                                                      shard_chatglm_model,
                                                      col_layer_for_check,
                                                      tp_group,
@@ -80,10 +69,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                                                      rtol=rtol,
                                                      dim=1,
                                                      verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
 
-    # check weights after optimizer.step()
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
 
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        if org_model.__class__.__name__ == 'ChatGLMModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-4, 1e-3
@@ -98,6 +103,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      dim=1,
                      verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
+    Randomizer.reset_index()
     torch.cuda.empty_cache()
@@ -121,12 +130,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': False,
-    'precision': 'fp32',
+    'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32'
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
 def run_chatglm_test(test_config):
 
     # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
     sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
     torch.cuda.empty_cache()
 
 
+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+}, {
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp16',
+    'zero_stage': 1,
+    'initial_scale': 1,
+}])
+def run_chatglm_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@@ -142,6 +194,12 @@ def check_chatglm(rank, world_size, port):
     run_chatglm_test()
 
 
+def check_chatglm_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_chatglm_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -149,5 +207,13 @@ def test_chatglm():
     spawn(check_chatglm, 4)
 
 
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_chatglm_3d():
+    spawn(check_chatglm_3d, 8)
+
+
 if __name__ == "__main__":
     test_chatglm()
+    test_chatglm_3d()
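The new *_3d tests spawn 8 processes with tp_size=2 and pp_size=2; the remaining factor of the world size becomes the data-parallel degree, which is what makes these runs 3D (DP x TP x PP). The arithmetic:

world_size, tp_size, pp_size = 8, 2, 2
dp_size = world_size // (tp_size * pp_size)
print(f"dp={dp_size}, tp={tp_size}, pp={pp_size}")   # dp=2, tp=2, pp=2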
tests/test_shardformer/test_model/test_shard_gpt2.py

@@ -3,18 +3,20 @@ import torch
 from torch import distributed as dist
 
 import colossalai
-from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
 )
@@ -36,6 +38,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
 
+    # unwrap model
+    gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
+    sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
+
+    col_layer_for_check = ['h[0].mlp.c_fc']
+    row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
+
+    # Save gradient tensors for comparison between the original model and the sharded model.
+    grads_to_check = {}
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-4, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        col_layer_grads = get_grad_tensors_for_check(gpt2, sharded_gpt2, col_layer_for_check, tp_group,
+                                                     atol=atol, rtol=rtol, dim=1, verbose=False)
+        row_layer_grads = get_grad_tensors_for_check(gpt2, sharded_gpt2, row_layer_for_check, tp_group,
+                                                     atol=atol, rtol=rtol, dim=0, verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
+    org_optimizer.step()
+    sharded_optimizer.step()
+
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
         if test_config['precision'] == 'fp32':
@@ -48,32 +87,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
 
-    def unwrap(module):
-        if isinstance(module, HybridParallelModule):
-            module = module.unwrap()
-        if module.__class__.__name__ == 'GPT2Model':
-            return module
-        return module.transformer
-
-    # unwrap model
-    gpt2 = unwrap(org_model)
-    sharded_gpt2 = unwrap(sharded_model)
-
-    col_layer_for_check = ['h[0].mlp.c_fc']
-    row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
-
-    # check grad
-    if stage_manager is None or stage_manager.is_first_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-4, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-        check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
-        check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
-
-    # check weights after optimizer.step()
-    org_optimizer.step()
-    sharded_optimizer.step()
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 5e-3, 1e-3
@@ -81,6 +95,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 5e-3, 5e-3
         check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
 
+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
+    Randomizer.reset_index()
     torch.cuda.empty_cache()
@@ -106,12 +124,80 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32',
 }, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+}, {
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'enable_sequence_parallelism': True,
+    'precision': 'fp32',
+}, {
+    'tp_size': 4,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'enable_sequence_parallelism': True,
+    'precision': 'fp32',
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
 }])
+@clear_cache_before_run()
 def run_gpt2_test(test_config):
 
     # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
     torch.cuda.empty_cache()
 
 
+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+}, {
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp16',
+    'zero_stage': 1,
+    'initial_scale': 1,
+}])
+@clear_cache_before_run()
+def run_gpt2_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@@ -127,10 +213,13 @@ def check_gpt2(rank, world_size, port):
     run_gpt2_test()
 
 
-# TODO(ver217): fix this
-@pytest.mark.skip("this will stuck in CI")
+def check_gpt2_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_gpt2_3d_test()
+
+
+@pytest.mark.skip(reason="This test will hang in CI")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -138,5 +227,13 @@ def test_gpt2():
     spawn(check_gpt2, 4)
 
 
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_gpt2_3d():
+    spawn(check_gpt2_3d, 8)
+
+
 if __name__ == "__main__":
     test_gpt2()
+    test_gpt2_3d()
tests/test_shardformer/test_model/test_shard_llama.py
View file @
efba0f44
...
...
@@ -6,16 +6,19 @@ from torch import distributed as dist
import
colossalai
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.shardformer.layer.utils
import
Randomizer
from
colossalai.tensor.d_tensor.api
import
clear_layout_converter
from
colossalai.testing
import
clear_cache_before_run
,
parameterize
,
rerun_if_address_is_in_use
,
spawn
from
tests.kit.model_zoo
import
model_zoo
from
tests.test_shardformer.test_model._utils
import
(
build_model_from_hybrid_plugin
,
check_
grad
,
check_
all_grad_tensors
,
check_loss
,
check_output_hidden_state
,
check_weight
,
get_grad_tensors_for_check
,
run_forward_backward_with_hybrid_plugin
,
unwrap_model
,
)
os
.
environ
[
'TRANSFORMERS_NO_ADVISORY_WARNINGS'
]
=
'true'
...
...
@@ -39,35 +42,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager
=
booster
.
plugin
.
stage_manager
tp_group
=
booster
.
plugin
.
tp_group
# check last hidden state & loss
if
stage_manager
is
None
or
stage_manager
.
is_last_stage
():
if
test_config
[
'precision'
]
==
'fp32'
:
atol
,
rtol
=
1e-5
,
1e-3
else
:
atol
,
rtol
=
5e-3
,
5e-3
if
org_model
.
__class__
.
__name__
==
'LlamaModel'
:
check_output_hidden_state
(
org_output
,
sharded_output
,
stage_manager
,
atol
=
atol
,
rtol
=
rtol
)
check_loss
(
org_loss
,
sharded_loss
,
atol
=
atol
,
rtol
=
rtol
)
# unwrap model
if
org_model
.
__class__
.
__name__
==
'LlamaModel'
:
llama_model
=
org_model
shard_llama_model
=
sharded_model
.
unwrap
()
else
:
llama_model
=
org_model
.
model
shard_llama_model
=
sharded_model
.
unwrap
().
model
llama_model
=
unwrap_model
(
org_model
,
'LlamaModel'
,
'model'
)
shard_llama_model
=
unwrap_model
(
sharded_model
,
'LlamaModel'
,
'model'
)
# check grad
row_layer_for_check
=
[
'layers[0].self_attn.q_proj'
,
'embed_tokens'
]
col_layer_for_check
=
[
'layers[0].self_attn.o_proj'
]
if
stage_manager
is
None
or
stage_manager
.
is_first_stage
():
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check
=
{}
if
(
stage_manager
is
None
or
stage_manager
.
is_first_stage
())
and
booster
.
plugin
.
zero_stage
==
0
:
if
test_config
[
'precision'
]
==
'fp32'
:
atol
,
rtol
=
1e-6
,
1e-4
else
:
atol
,
rtol
=
5e-3
,
5e-3
check
_grad
(
llama_model
,
row_layer_grads
=
get_grad_tensors_for_
check
(
llama_model
,
shard_llama_model
,
row_layer_for_check
,
tp_group
,
...
...
@@ -75,7 +64,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol
=
rtol
,
dim
=
0
,
verbose
=
False
)
c
heck_grad
(
llama_model
,
c
ol_layer_grads
=
get_grad_tensors_for_check
(
llama_model
,
shard_llama_model
,
col_layer_for_check
,
tp_group
,
...
...
@@ -83,10 +72,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                                                      rtol=rtol,
                                                      dim=1,
                                                      verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)

-    # check weights after optimizer.step()
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()

+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        if org_model.__class__.__name__ == 'LlamaModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-4, 1e-3
...
...
@@ -101,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      dim=1,
                      verbose=False)

+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
...
...
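The refactor above stashes gradients with `get_grad_tensors_for_check` *before* `optimizer.step()` and only asserts afterwards via `check_all_grad_tensors`. A rough sketch of the consuming half, assuming each dict entry carries the gathered original/sharded grads plus tolerances (the exact dict layout is an assumption; the real helper lives in `_utils.py`):

import torch


def check_all_grad_tensors(grads_to_check):
    # Each entry was recorded before optimizer.step(), so stepping both
    # optimizers in between does not disturb the comparison.
    for name, info in grads_to_check.items():
        org_grad, shard_grad = info['org_grad'], info['shard_grad']
        assert torch.allclose(org_grad, shard_grad, atol=info['atol'], rtol=info['rtol']), \
            f"grad mismatch at {name}"

Deferring the assert lets one pass check both gradients (pre-step) and weights (post-step) without re-running forward/backward.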
@@ -128,19 +136,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'tp_size': 1,
     'pp_size': 4,
     'num_microbatches': 4,
     'enable_all_optimization': False,
     'use_lazy_init': False,
     'precision': 'fp32'
 }, {
     'tp_size': 2,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
 }, {
     'tp_size': 2,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': True,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 2,
     'enable_all_optimization': True,
     'use_lazy_init': True,
     'zero_stage': 1,
     'precision': 'fp16',
     'initial_scale': 1
 }])
 def run_llama_test(test_config):

     # TODO(baizhou): add test_config for TP+DP after supporting & debugging it

     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

     clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()


+@parameterize('test_config', [
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
+])
+def run_llama_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
...
...
@@ -150,6 +213,12 @@ def check_llama(rank, world_size, port):
     run_llama_test()


+def check_llama_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_llama_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
...
...
@@ -157,5 +226,13 @@ def test_llama():
     spawn(check_llama, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama_3d():
+    spawn(check_llama_3d, 8)
+
+
 if __name__ == "__main__":
     test_llama()
+    test_llama_3d()
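For orientation: `test_llama` covers the 4-GPU configs, while `test_llama_3d` needs 8 ranks because the ranks left over after TP and PP become data parallelism. A quick sanity check of that bookkeeping (assuming the hybrid plugin derives the DP degree as world size divided by tp·pp, which is consistent with the configs above):

def dp_degree(world_size: int, tp_size: int, pp_size: int) -> int:
    assert world_size % (tp_size * pp_size) == 0
    return world_size // (tp_size * pp_size)


# spawn(check_llama_3d, 8) with tp_size=2, pp_size=2 leaves dp_degree(8, 2, 2) == 2,
# so the "3d" tests exercise TP + PP + DP together.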
tests/test_shardformer/test_model/test_shard_opt.py
View file @
efba0f44
...
...
@@ -6,16 +6,19 @@ from torch import distributed as dist

 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
 )

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
...
...
@@ -39,34 +42,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group

-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-        if org_model.__class__.__name__ == 'OPTModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
-    if org_model.__class__.__name__ == 'OPTModel':
-        opt_model = org_model
-        shard_opt_model = sharded_model.unwrap()
-    else:
-        opt_model = org_model.model
-        shard_opt_model = sharded_model.unwrap().model
+    opt_model = unwrap_model(org_model, 'OPTModel', 'model')
+    shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')

     # check grad
     row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']    # 'decoder.embed_tokens'
     col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
-    if stage_manager is None or stage_manager.is_first_stage():
+
+    # Save gradient tensors for comparison between the original model and the sharded model.
+    grads_to_check = {}
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-6, 1e-3
         else:
-            atol, rtol = 3e-2, 3e-2
-        check_grad(opt_model,
+            atol, rtol = 4e-2, 4e-2
+        row_layer_grads = get_grad_tensors_for_check(opt_model,
                                                      shard_opt_model,
                                                      row_layer_for_check,
                                                      tp_group,
...
...
@@ -74,7 +64,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                                                      rtol=rtol,
                                                      dim=0,
                                                      verbose=False)
-        check_grad(opt_model,
+        col_layer_grads = get_grad_tensors_for_check(opt_model,
                                                      shard_opt_model,
                                                      col_layer_for_check,
                                                      tp_group,
...
...
@@ -82,10 +72,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                                                      rtol=rtol,
                                                      dim=1,
                                                      verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)

-    # check weights after optimizer.step()
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()

+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        if org_model.__class__.__name__ == 'OPTModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-3, 1e-3
...
...
@@ -100,6 +105,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      dim=1,
                      verbose=False)

+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
+    Randomizer.reset_index()
     torch.cuda.empty_cache()
...
...
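The OPT gradients get the loosest tolerance in this PR: the non-fp32 branch moves from 3e-2 to 4e-2. As a reminder of what that buys, `torch.allclose` accepts |a − b| ≤ atol + rtol·|b|, so for a reference grad of magnitude ~1 the relaxed branch tolerates roughly an 8e-2 absolute gap:

import torch

a = torch.tensor([1.00])
b = torch.tensor([1.07])    # off by 7e-2
print(torch.allclose(a, b, atol=4e-2, rtol=4e-2))    # True: 0.07 <= 0.04 + 0.04 * 1.07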
@@ -123,12 +132,62 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
 }, {
     'tp_size': 2,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
 }, {
     'tp_size': 2,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': True,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 2,
     'enable_all_optimization': True,
     'use_lazy_init': True,
     'zero_stage': 1,
     'precision': 'fp16',
     'initial_scale': 1
 }])
 def run_opt_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

     clear_layout_converter()
     torch.cuda.empty_cache()


+# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
+@parameterize('test_config', [
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
+])
+def run_opt_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
...
...
@@ -144,6 +203,12 @@ def check_OPTModel(rank, world_size, port):
     run_opt_test()


+def check_opt_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_opt_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
...
...
@@ -151,5 +216,13 @@ def test_OPTModel():
     spawn(check_OPTModel, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_opt_3d():
+    spawn(check_opt_3d, 8)
+
+
 if __name__ == '__main__':
     test_OPTModel()
+    test_opt_3d()
tests/test_shardformer/test_model/test_shard_t5.py
View file @
efba0f44
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.logging import disable_existing_loggers
...
...
@@ -9,11 +10,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
 )
...
...
@@ -35,6 +38,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group

+    # unwrap model
+    t5 = unwrap_model(org_model)
+    sharded_t5 = unwrap_model(sharded_model)
+
+    row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
+
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-5, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+        row_layer_grads = get_grad_tensors_for_check(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol,
+                                                     rtol=rtol, dim=0)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
+    org_optimizer.step()
+    sharded_optimizer.step()
+
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
         if test_config['precision'] == 'fp32':
...
...
@@ -47,30 +76,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

-    # unwrap model
-    t5 = org_model
-    sharded_t5 = sharded_model.unwrap()
-
-    row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
-
-    # check weights and gradients
-    if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-5, 1e-3
-    else:
-        atol, rtol = 5e-3, 5e-3
-    if stage_manager is None or stage_manager.is_first_stage():
-        check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
-
-    # check weights after optimizer.step()
-    org_optimizer.step()
-    sharded_optimizer.step()
+    # check weights
     if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-4, 1e-3
+        atol, rtol = 5e-4, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
         check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)

+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()
...
...
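A note on the `dim=0` used for the T5 `shared` embedding and attention `q` projection: those layers are row-sharded under tensor parallelism, so before comparing against the unsharded reference the helper presumably all-gathers each shard's gradient along dim 0 (column-sharded layers use `dim=1`). A sketch under that assumption; the real logic is in `_utils.py`:

import torch
import torch.distributed as dist


def gather_sharded_grad(shard_grad: torch.Tensor, dim: int, tp_group) -> torch.Tensor:
    # Reassemble the full gradient from the tensor-parallel shards.
    shards = [torch.zeros_like(shard_grad) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, shard_grad, group=tp_group)
    return torch.cat(shards, dim=dim)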
@@ -99,17 +115,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'tp_size': 1,
     'pp_size': 4,
     'num_microbatches': 4,
     'enable_all_optimization': False,
     'use_lazy_init': False,
     'precision': 'fp32'
 }, {
     'tp_size': 2,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
 }, {
     'tp_size': 2,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': True,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 2,
     'enable_all_optimization': True,
     'use_lazy_init': True,
     'zero_stage': 1,
     'precision': 'fp16',
     'initial_scale': 1
 }])
 @clear_cache_before_run()
 def run_t5_test(test_config):

     # TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it
     # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
     # TODO(baizhou): add test_config for flash attention & jit operator after supporting

     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
...
...
@@ -125,12 +160,49 @@ def run_t5_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
+])
+def run_t5_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
 def check_t5(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_t5_test()


+def check_t5_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_t5_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
...
...
@@ -138,5 +210,13 @@ def test_t5():
     spawn(check_t5, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_t5_3d():
+    spawn(check_t5_3d, 8)
+
+
 if __name__ == "__main__":
     test_t5()
+    test_t5_3d()
tests/test_shardformer/test_model/test_shard_vit.py
View file @
efba0f44
...
...
@@ -9,11 +9,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
-    check_grad,
+    check_all_grad_tensors,
     check_loss,
     check_output_hidden_state,
     check_weight,
+    get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
 )
...
...
@@ -35,35 +37,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group

-    # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
-        if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-5, 1e-3
-        else:
-            atol, rtol = 5e-3, 5e-3
-        if org_model.__class__.__name__ == 'ViTModel':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
-
     # unwrap model
-    if org_model.__class__.__name__ == 'ViTModel':
-        vit_model = org_model
-        shard_vit_model = sharded_model.unwrap()
-    else:
-        vit_model = org_model.vit
-        shard_vit_model = sharded_model.unwrap().vit
+    vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
+    shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')

     # check grad
     row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
     col_layer_for_check = ['encoder.layer[0].attention.output.dense']
-    if stage_manager is None or stage_manager.is_first_stage():
+
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         if test_config['precision'] == 'fp32':
             atol, rtol = 1e-5, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_grad(vit_model,
+        row_layer_grads = get_grad_tensors_for_check(vit_model,
                                                      shard_vit_model,
                                                      row_layer_for_check,
                                                      tp_group,
...
...
@@ -71,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                                                      rtol=rtol,
                                                      dim=0,
                                                      verbose=False)
-        check_grad(vit_model,
+        col_layer_grads = get_grad_tensors_for_check(vit_model,
                                                      shard_vit_model,
                                                      col_layer_for_check,
                                                      tp_group,
...
...
@@ -79,10 +68,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                                                      rtol=rtol,
                                                      dim=1,
                                                      verbose=False)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)

-    # check weights after optimizer.step()
+    # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()

+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
+        if org_model.__class__.__name__ == 'ViTModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
     if stage_manager is None or stage_manager.is_first_stage():
         if test_config['precision'] == 'fp32':
             atol, rtol = 5e-3, 1e-3
...
...
@@ -97,9 +101,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      dim=1,
                      verbose=False)

+    # check grads
+    check_all_grad_tensors(grads_to_check)
+
     torch.cuda.empty_cache()


+# TODO: num_microbatch size = 2 inf loss
 @parameterize('test_config', [{
     'tp_size': 2,
     'pp_size': 2,
...
...
@@ -120,15 +128,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
 }, {
     'tp_size': 2,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'precision': 'fp32'
 }, {
     'tp_size': 2,
     'pp_size': 1,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'zero_stage': 2,
     'precision': 'fp16',
     'initial_scale': 1
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
     'enable_all_optimization': True,
     'use_lazy_init': False,
     'zero_stage': 1,
     'precision': 'fp16',
     'initial_scale': 1
 }])
 def run_vit_test(test_config):

     # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
-    # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models
+    # TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models

     sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
...
...
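On the `num_microbatch size = 2 inf loss` TODO above: pipeline schedules slice the global batch into `num_microbatches` chunks along dim 0, so a config is only valid when the batch divides evenly — and tiny microbatches can destabilize fp16 loss scaling. A hedged sketch of that split (the actual slicing lives in the pipeline schedule, not in this test file):

import torch


def split_into_microbatches(batch: torch.Tensor, num_microbatches: int):
    assert batch.size(0) % num_microbatches == 0, "batch size must divide num_microbatches"
    return batch.chunk(num_microbatches, dim=0)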
@@ -137,12 +166,48 @@ def run_vit_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 2,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
+])
+def run_vit_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
 def check_vit(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_vit_test()


+def check_vit_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_vit_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
...
...
@@ -150,5 +215,13 @@ def test_vit():
     spawn(check_vit, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_vit_3d():
+    spawn(check_vit_3d, 8)
+
+
 if __name__ == "__main__":
     test_vit()
+    test_vit_3d()
tests/test_shardformer/test_model/test_shard_whisper.py
View file @
efba0f44
...
...
@@ -3,6 +3,8 @@ import torch

 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import (
     assert_hf_output_close,
     clear_cache_before_run,
...
...
@@ -11,55 +13,205 @@ from colossalai.testing import (
     spawn,
 )
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_all_grad_tensors,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    get_grad_tensors_for_check,
+    run_forward_backward_with_hybrid_plugin,
+)


-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # check forward
-    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
-                                                                 output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5)
-
-    # do backward
-    org_loss.backward()
-    shard_loss.backward()
-
-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+        build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+    org_loss, org_output, sharded_loss, sharded_output = \
+        run_forward_backward_with_hybrid_plugin(org_model, sharded_model, sharded_optimizer, data_gen_fn,
+                                                output_transform_fn, criterion, booster)
+
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group

     # unwarp the model
     if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
         whisper = org_model.model
-        sharded_whisper = sharded_model.model
+        sharded_whisper = sharded_model.unwrap().model
     else:
         whisper = org_model
-        sharded_whisper = sharded_model
+        sharded_whisper = sharded_model.unwrap()

     # check grad
-    if org_model.__class__.__name__ == 'WhisperForAudioClassification':
-        col_layer_for_check = ['encoder.layers[0].self_attn.q_proj']
-        row_layer_for_check = ['encoder.layers[0].self_attn.out_proj']
-    else:
-        col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj']
-        row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj']
-    check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
-    check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
+    col_layer_for_check = [
+        'encoder.layers[0].self_attn.q_proj',
+        # 'decoder.layers[0].self_attn.q_proj'
+    ]
+    row_layer_for_check = [
+        'encoder.layers[0].self_attn.out_proj',
+        # 'decoder.layers[0].self_attn.out_proj'
+    ]
+
+    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+    grads_to_check = {}
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 2e-4, 2e-4
+    else:
+        atol, rtol = 5e-3, 5e-3
+    if stage_manager is None or stage_manager.is_first_stage():
+        row_layer_grads = get_grad_tensors_for_check(whisper,
+                                                     sharded_whisper,
+                                                     row_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=1)
+        col_layer_grads = get_grad_tensors_for_check(whisper,
+                                                     sharded_whisper,
+                                                     col_layer_for_check,
+                                                     tp_group,
+                                                     atol=atol,
+                                                     rtol=rtol,
+                                                     dim=0)
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+
+    # optimizer executes step
+    org_optimizer.step()
+    sharded_optimizer.step()
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 2e-4, 2e-4
+        else:
+            atol, rtol = 5e-3, 5e-3
+        if org_model.__class__.__name__ == 'WhisperModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # check weights
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-3, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_weight(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+        check_weight(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
+
+    # check grads
+    check_all_grad_tensors(grads_to_check)

     torch.cuda.empty_cache()


-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
-    #TODO fix WhisperForConditionalGeneration enable jit fused operato
+# TODO(jianghai) fix fp16
+@parameterize('test_config', [
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 2,
+        'enable_all_optimization': True,
+        'use_lazy_init': True,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
+    {
+        'tp_size': 1,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
+    {
+        'tp_size': 4,
+        'pp_size': 1,
+        'enable_all_optimization': True,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+    },
+    {
+        'tp_size': 1,
+        'pp_size': 4,
+        'num_microbatches': 4,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+    },
+    # whisper is not supported fp16 for now.
+])
+def run_whisper_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn,
-                                               enable_fused_normalization=enable_fused_normalization,
-                                               enable_tensor_parallelism=enable_tensor_parallelism,
-                                               enable_flash_attention=enable_flash_attention,
-                                               enable_jit_fused=enable_jit_fused)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+        if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification':
+            continue
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()


+@parameterize('test_config', [
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 2,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
+])
+def run_whisper_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
...
...
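The Whisper test now follows the same hybrid-plugin harness as the other files. The six-way unpacking at the top of `check_forward_backward` fixes the helper's return shape; a hedged sketch of `build_model_from_hybrid_plugin` consistent with that call site (the optimizer choice, learning rate, and how `use_lazy_init` is handled are assumptions — the real helper lives in `_utils.py`):

import copy

from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def build_model_from_hybrid_plugin(model_fn, loss_fn, test_config):
    org_model = model_fn().cuda()
    sharded_model = copy.deepcopy(org_model)
    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
    criterion = loss_fn

    test_config = dict(test_config)
    test_config.pop('use_lazy_init', False)    # the real helper handles lazy init separately
    booster = Booster(plugin=HybridParallelPlugin(**test_config))
    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer,
                                                                      criterion)
    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster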
@@ -69,12 +221,26 @@ def check_whisper(rank, world_size, port):
     run_whisper_test()


+def check_whisper_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_whisper_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_whisper():
-    spawn(check_whisper, 2)
+    spawn(check_whisper, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_whisper_3d():
+    spawn(check_whisper_3d, 8)
+
+
 if __name__ == "__main__":
     test_whisper()
+    test_whisper_3d()
tests/test_utils/test_activation_checkpointing.py
View file @
efba0f44
...
...
@@ -40,7 +40,6 @@ def forward_inplace(x, weight):
     return out


-@pytest.mark.gpu
 @clear_cache_before_run()
 @parameterize("use_reentrant", [True, False])
 @parameterize("cpu_offload", [True, False])
...
...
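Unrelated cleanup in this last file: the `gpu` marker comes off the activation-checkpointing test. Note that `colossalai.testing.parameterize` stacks like `pytest.mark.parametrize`: calling the decorated function runs the body once per argument combination. A small usage sketch (the print body is illustrative, not from the test):

from colossalai.testing import parameterize


@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
def run_case(use_reentrant, cpu_offload):
    print(f"use_reentrant={use_reentrant}, cpu_offload={cpu_offload}")


# run_case() executes the body 4 times, once per (use_reentrant, cpu_offload) pair.
run_case()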