OpenDAS / ColossalAI · Commits

Commit 376533a5 (unverified), authored Aug 28, 2023 by Jianghai, committed by GitHub on Aug 28, 2023

[shardformer] zero1+pp and the corresponding tests (#4517)

* pause
* finish pp+zero1
* Update test_shard_vit.py

Parent: 44eab2b2
Changes: 10 changed files with 109 additions and 35 deletions (+109 / -35)
    colossalai/pipeline/schedule/one_f_one_b.py               +1   -2
    colossalai/zero/low_level/low_level_optim.py               +7   -2
    tests/test_shardformer/test_model/test_shard_bert.py      +9   -0
    tests/test_shardformer/test_model/test_shard_bloom.py     +9   -0
    tests/test_shardformer/test_model/test_shard_gpt2.py      +9   -0
    tests/test_shardformer/test_model/test_shard_llama.py     +9   -0
    tests/test_shardformer/test_model/test_shard_opt.py       +9   -0
    tests/test_shardformer/test_model/test_shard_t5.py        +9   -0
    tests/test_shardformer/test_model/test_shard_vit.py       +10  -1
    tests/test_shardformer/test_model/test_shard_whisper.py   +37  -30
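The change set wires ZeRO stage 1 into the pipeline-parallel execution path and extends the shardformer model tests with a pp_size=2 + zero_stage=1 configuration. As orientation before the diffs, here is a minimal sketch of how such a combination is typically requested through ColossalAI's HybridParallelPlugin; the keyword names mirror the test_config dicts added below, but the surrounding launch/booster calls are an assumption for illustration and are not part of this commit.

    # Minimal sketch (assumed API usage, not from this diff): requesting
    # pipeline parallelism together with ZeRO-1 via the hybrid parallel plugin.
    import torch
    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch(config={})   # requires a torchrun launch with >= 2 GPUs

    plugin = HybridParallelPlugin(
        tp_size=1,             # no tensor parallelism
        pp_size=2,             # two pipeline stages, scheduled with 1F1B
        num_microbatches=2,    # micro-batches per pipeline step
        zero_stage=1,          # ZeRO-1: shard optimizer states inside each data-parallel group
        precision='fp16',
        initial_scale=1,
    )
    booster = Booster(plugin=plugin)

    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)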
colossalai/pipeline/schedule/one_f_one_b.py  (view file @ 376533a5)

@@ -128,11 +128,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
        """
        micro_batch = self.load_micro_batch()
        # for the first stage, input_obj is None
        # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
        output_obj = model_forward(model, micro_batch, input_obj)
        if self.stage_manager.is_last_stage():
            loss = criterion(output_obj, micro_batch) / self.num_microbatches
            if accum_loss is not None:
                accum_loss.add_(loss.detach())

@@ -158,7 +158,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        # Retain the grad on the input_obj.
        tree_map(retain_grad, input_obj)
        # Backward pass.
        if output_obj_grad is None:
            optimizer.backward(output_obj)
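The first hunk divides each micro-batch loss by num_microbatches before accumulating it into accum_loss, so the accumulated value (and the gradients summed across micro-batches) reproduce the full-batch mean. A small self-contained check of that identity, in plain PyTorch with a hypothetical toy model rather than code from the repo:

    # Why the per-microbatch loss is scaled by 1/num_microbatches: accumulating
    # the scaled losses and their gradients matches the full-batch mean loss.
    import torch

    num_microbatches = 4
    full_batch = torch.randn(8, 3)
    w = torch.ones(3, requires_grad=True)

    # Full-batch reference.
    (full_batch @ w).mean().backward()
    ref_grad = w.grad.clone()

    # Micro-batched accumulation, mirroring the schedule above.
    w.grad = None
    accum_loss = torch.zeros(1)
    for micro_batch in full_batch.chunk(num_microbatches):
        loss = (micro_batch @ w).mean() / num_microbatches
        accum_loss.add_(loss.detach())
        loss.backward()

    assert torch.allclose(ref_grad, w.grad)
    assert torch.allclose(accum_loss, (full_batch @ w).mean().detach())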
colossalai/zero/low_level/low_level_optim.py  (view file @ 376533a5)

@@ -316,7 +316,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
    def backward(self, loss, retain_graph=False):
        assert not (self._partition_grads and not self.require_grad_sync), \
            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)

@@ -333,6 +332,13 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        self.zero_grad()

    def backward_by_grad(self, tensor, grad):
        # in lower stage which grad is transfered by higher stage
        # we need to pass the optim state down.
        if self.mixed_precision_mixin is not None:
            grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
        torch.autograd.backward(tensor, grad)

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, gradient

@@ -358,7 +364,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

        if not self.require_grad_sync:
            return
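The new backward_by_grad hook lets non-final pipeline stages run the optimizer's backward pass from a gradient received from the next stage instead of from a scalar loss, applying the mixed-precision mixin first when present. At its core it is torch.autograd.backward(tensor, grad); a self-contained illustration of why that is equivalent to backpropagating through the downstream loss (toy example, not repo code):

    # Backprop from an explicit upstream gradient, which is what
    # backward_by_grad boils down to for intermediate pipeline stages.
    import torch

    x = torch.randn(4, 8, requires_grad=True)

    # Last pipeline stage: it owns the scalar loss, so a normal backward works.
    loss = (x * 2.0).sum()
    loss.backward()
    grad_from_loss = x.grad.clone()

    # Intermediate stage: it only receives d(loss)/d(output) from the next stage.
    x.grad = None
    output = x * 2.0
    upstream_grad = torch.ones_like(output)          # gradient sent back by the next stage
    torch.autograd.backward(output, upstream_grad)   # the call backward_by_grad wraps

    assert torch.allclose(grad_from_loss, x.grad)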
tests/test_shardformer/test_model/test_shard_bert.py  (view file @ 376533a5)

@@ -107,6 +107,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_bert_test(test_config):
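Each test file gains one more entry in its @parameterize list: tp_size=1, pp_size=2 with zero_stage=1, i.e. pure pipeline parallelism on top of a ZeRO-1 optimizer. The decorator simply re-runs the test once per config dict; below is a simplified stand-in to show how the new entry is consumed (ColossalAI's own colossalai.testing.parameterize behaves like this, but treat the exact semantics as an assumption):

    # Simplified stand-in for the @parameterize decorator used by these tests.
    def parameterize(arg_name, values):
        def decorator(fn):
            def wrapper(*args, **kwargs):
                for value in values:                       # run the test once per config
                    fn(*args, **{**kwargs, arg_name: value})
            return wrapper
        return decorator

    @parameterize('test_config', [
        {'tp_size': 1, 'pp_size': 2, 'num_microbatches': 2,
         'enable_all_optimization': True, 'use_lazy_init': True,
         'zero_stage': 1, 'precision': 'fp16', 'initial_scale': 1},
    ])
    def run_demo_test(test_config):
        print('would build a booster from', test_config)

    run_demo_test()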
tests/test_shardformer/test_model/test_shard_bloom.py  (view file @ 376533a5)

@@ -110,6 +110,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_bloom_test(test_config):
tests/test_shardformer/test_model/test_shard_gpt2.py  (view file @ 376533a5)

@@ -128,6 +128,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):
tests/test_shardformer/test_model/test_shard_llama.py  (view file @ 376533a5)

@@ -142,6 +142,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_llama_test(test_config):
tests/test_shardformer/test_model/test_shard_opt.py  (view file @ 376533a5)

@@ -135,6 +135,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_opt_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
tests/test_shardformer/test_model/test_shard_t5.py  (view file @ 376533a5)

@@ -118,6 +118,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
@clear_cache_before_run()
def run_t5_test(test_config):
tests/test_shardformer/test_model/test_shard_vit.py  (view file @ 376533a5)

@@ -45,7 +45,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        if org_model.__class__.__name__ == 'ViTModel':
            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # unwrap model

@@ -97,6 +96,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    torch.cuda.empty_cache()


#TODO: num_microbatch size = 2 inf loss
@parameterize('test_config', [{
    'tp_size': 2,
    'pp_size': 2,

@@ -132,6 +132,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def run_vit_test(test_config):
tests/test_shardformer/test_model/test_shard_whisper.py  (view file @ 376533a5)

@@ -112,10 +112,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    torch.cuda.empty_cache()


#TODO fix WhisperForConditionalGeneration enable jit fused operato
# TODO(jianghai) fix fp16
#TODO fix WhisperForConditionalGeneration enable jit fused operator
@parameterize('test_config', [
    {
        'tp_size': 2,
        'pp_size': 2,
        'num_microbatches': 2,

@@ -123,26 +125,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        'use_lazy_init': True,
        'precision': 'fp32',
        'initial_scale': 1,
    },
    {
        'tp_size': 1,
        'pp_size': 2,
        'num_microbatches': 4,
        'use_lazy_init': False,
        'precision': 'fp32',
        'initial_scale': 1,
    },
    {
        'tp_size': 4,
        'pp_size': 1,
        'enable_all_optimization': True,
        'use_lazy_init': False,
        'precision': 'fp32',
    },
    {
        'tp_size': 1,
        'pp_size': 4,
        'num_microbatches': 4,
        'use_lazy_init': False,
        'precision': 'fp32',
    },
    # whisper is not supported fp16 for now.
])
def run_whisper_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():