OpenDAS / ColossalAI · Commit 8d3250d7 (unverified)
Authored Mar 21, 2022 by ver217; committed by GitHub, Mar 21, 2022
Parent: 7f5e4592

[zero] ZeRO supports pipeline parallel (#477)

Showing 3 changed files with 113 additions and 95 deletions (+113 -95)
colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py  (+6 -4)
colossalai/engine/schedule/_pipeline_schedule.py                           (+95 -91)
colossalai/zero/sharded_model/sharded_model_v2.py                          (+12 -0)
colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py

 #!/usr/bin/env python
-import torch.distributed as dist
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from collections import defaultdict
+
+import torch
+import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
 from ._base_gradient_handler import BaseGradientHandler
-from collections import defaultdict

 @GRADIENT_HANDLER.register_module
@@ -35,7 +37,7 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
         for group, group_buckets in buckets.items():
             for tp, bucket in group_buckets.items():
                 grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
+                coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
                 dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                 for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                     buf.copy_(synced)
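
The flatten/all-reduce/unflatten idiom in this hunk coalesces every gradient in a bucket into one contiguous buffer so a single collective call synchronizes all of them; the added .to(torch.cuda.current_device()) ensures the coalesced buffer sits on the local GPU before the all-reduce. A minimal single-process sketch of the round-trip (CPU-only, with the collective stubbed out by a division; shapes and values are made up):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Two gradients of different shapes, standing in for param.grad.data.
grads = [torch.ones(2, 3), torch.full((4,), 2.0)]

# Coalesce into one contiguous 1-D buffer (6 + 4 = 10 elements).
coalesced = _flatten_dense_tensors(grads)
assert coalesced.shape == (10,)

# In the handler, dist.all_reduce(coalesced, ...) would run here;
# dividing by a fake world size of 2 stands in for the reduction.
coalesced /= 2

# Split the buffer back and copy each synced piece into the originals.
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)
assert grads[0][0, 0].item() == 0.5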
colossalai/engine/schedule/_pipeline_schedule.py

@@ -12,7 +12,8 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero import ShardedOptimizer, ShardedModel
+from colossalai.zero import ShardedModel, ShardedOptimizer
+from colossalai.zero.sharded_model import ShardedModelV2
 from ._base_schedule import BaseSchedule
@@ -79,8 +80,8 @@ class PipelineSchedule(BaseSchedule):
     def _get_data_slice(self, data, offset):
         if isinstance(data, torch.Tensor):
             return data[offset:offset + self.microbatch_size]
-        else:
+        elif isinstance(data, dict):
             return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}

     def load_micro_batch(self):
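
The new elif branch makes microbatch slicing work for dict-style batches (keyword inputs) as well as plain tensors. A small illustrative sketch of both paths (shapes and microbatch_size are made up):

import torch

microbatch_size = 2
offset = 2

# Tensor batch: slice directly along the batch dimension.
tensor_batch = torch.arange(12).reshape(6, 2)
micro = tensor_batch[offset:offset + microbatch_size]
assert micro.shape[0] == microbatch_size

# Dict batch: slice every value the same way.
dict_batch = {'input_ids': torch.arange(8).reshape(4, 2), 'mask': torch.ones(4, 2)}
micro = {k: v[offset:offset + microbatch_size] for k, v in dict_batch.items()}
assert micro['input_ids'].shape[0] == microbatch_size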
@@ -92,11 +93,9 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
         if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel):
-            raise TypeError(
-                "Pipeline schedule is currently not compatible with ZeRO"
-            )
+            raise TypeError("Pipeline schedule is currently not compatible with ZeRO")
         model = engine.model
-        if isinstance(model, NaiveAMPModel):
+        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
             self.dtype = torch.half
             model = model.model
         sig = inspect.signature(model.forward)
@@ -107,6 +106,8 @@ class PipelineSchedule(BaseSchedule):
     def _call_engine(model, input_tensor, batch_data):
         if isinstance(model, NaiveAMPModel):
             sig = inspect.signature(model.model.forward)
+        elif isinstance(model, ShardedModelV2):
+            sig = inspect.signature(model.module.forward)
         else:
             sig = inspect.signature(model.forward)
         if isinstance(batch_data, torch.Tensor):
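
_call_engine needs the signature of the innermost forward, and each wrapper stores the wrapped model under a different attribute: NaiveAMPModel under .model, ShardedModelV2 under .module (hence the new elif). A hedged sketch of the same dispatch with stand-in wrapper classes (the class names here are invented; only the attribute convention comes from the diff):

import inspect
import torch.nn as nn

class Inner(nn.Module):
    def forward(self, x, mask=None):
        return x

class AmpLikeWrapper(nn.Module):    # stands in for NaiveAMPModel (.model)
    def __init__(self):
        super().__init__()
        self.model = Inner()

class ShardLikeWrapper(nn.Module):  # stands in for ShardedModelV2 (.module)
    def __init__(self):
        super().__init__()
        self.module = Inner()

def forward_signature(model):
    # Mirror of the dispatch in _call_engine: unwrap first, then inspect.
    if isinstance(model, AmpLikeWrapper):
        return inspect.signature(model.model.forward)
    elif isinstance(model, ShardLikeWrapper):
        return inspect.signature(model.module.forward)
    return inspect.signature(model.forward)

assert 'mask' in forward_signature(ShardLikeWrapper()).parameters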
@@ -162,9 +163,11 @@ class PipelineSchedule(BaseSchedule):
             return output_tensor
         else:
-            assert isinstance(output_tensor, torch.Tensor), \
-                'Output of model using pipeline parallelism must be a tensor (except the last stage).'
-            self._logger.debug(
-                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
+            assert isinstance(
+                output_tensor, torch.Tensor
+            ), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} '
+                f'forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor

     def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
@@ -203,12 +206,7 @@ class PipelineSchedule(BaseSchedule):
         return input_tensor_grad

-    def forward_backward_step(self,
-                              engine,
-                              data_iter,
-                              forward_only=False,
-                              return_loss=True,
-                              return_output_label=True):
+    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -231,10 +229,9 @@ class PipelineSchedule(BaseSchedule):
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'

         self.load_batch(data_iter)
         num_warmup_microbatches = \
-            (gpc.get_world_size(ParallelMode.PIPELINE) -
-             gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
+            (gpc.get_world_size(ParallelMode.PIPELINE)
+             - gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
         num_warmup_microbatches = min(num_warmup_microbatches,
                                       self.num_microbatches)
         num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

         # Input, output tensors only need to be saved when doing backward passes
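
The expression above is the usual non-interleaved 1F1B warmup bound: stage r of a p-stage pipeline runs p - r - 1 forward-only microbatches before entering the steady 1F1B phase, capped at the total microbatch count. A standalone check of the arithmetic (sizes chosen purely for illustration):

def warmup(pipeline_world_size, pipeline_rank, num_microbatches):
    n = pipeline_world_size - pipeline_rank - 1
    return min(n, num_microbatches)

# 4 stages, 8 microbatches: the first stage warms up with 3 forwards,
# the last stage with none (it can enter 1F1B immediately).
assert [warmup(4, r, 8) for r in range(4)] == [3, 2, 1, 0]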
@@ -257,13 +254,14 @@ class PipelineSchedule(BaseSchedule):
         for i in range(num_warmup_microbatches):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 ft_shape = comm.recv_tensor_meta(ft_shape)
             input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
                                              scatter_gather_tensors=self.scatter_gather_tensors)
-            output_tensor = self.forward_step(
-                engine,
-                input_tensor, return_tensors,
-                return_output_label=return_output_label,
-                accum_loss=accum_loss
-            )
+            output_tensor = self.forward_step(
+                engine,
+                input_tensor,
+                return_tensors,
+                return_output_label=return_output_label,
+                accum_loss=accum_loss)
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
                 fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
@@ -279,28 +277,32 @@ class PipelineSchedule(BaseSchedule):
         if num_microbatches_remaining > 0:
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 ft_shape = comm.recv_tensor_meta(ft_shape)
             input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
                                              scatter_gather_tensors=self.scatter_gather_tensors)

         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
             last_iteration = (i == (num_microbatches_remaining - 1))

-            output_tensor = self.forward_step(
-                engine,
-                input_tensor, return_tensors,
-                return_output_label=return_output_label,
-                accum_loss=accum_loss
-            )
+            output_tensor = self.forward_step(
+                engine,
+                input_tensor,
+                return_tensors,
+                return_output_label=return_output_label,
+                accum_loss=accum_loss)
             if forward_only:
                 comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)

                 if not last_iteration:
                     input_tensor = comm.recv_forward(ft_shape, dtype=self.dtype,
                                                      scatter_gather_tensors=self.scatter_gather_tensors)
             else:
                 output_tensor_grad = comm.send_forward_recv_backward(
                     output_tensor,
                     bt_shape, dtype=self.dtype,
                     scatter_gather_tensors=self.scatter_gather_tensors)

                 # Add input_tensor and output_tensor to end of list.
                 input_tensors.append(input_tensor)
@@ -311,18 +313,16 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)

                 input_tensor_grad = self.backward_step(
                     engine, input_tensor, output_tensor, output_tensor_grad)

                 if last_iteration:
                     input_tensor = None
                     comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
                 else:
                     input_tensor = comm.send_backward_recv_forward(
                         input_tensor_grad,
                         ft_shape, dtype=self.dtype,
                         scatter_gather_tensors=self.scatter_gather_tensors)

         # Run cooldown backward passes.
         if not forward_only:
@@ -330,14 +330,11 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)

                 output_tensor_grad = comm.recv_backward(bt_shape, dtype=self.dtype,
                                                         scatter_gather_tensors=self.scatter_gather_tensors)

                 input_tensor_grad = self.backward_step(
                     engine, input_tensor, output_tensor, output_tensor_grad)

                 comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
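
Taken together, the three phases of forward_backward_step give each stage a fixed per-microbatch order: warmup forwards, alternating forward/backward in the steady state, then cooldown backwards. A tiny simulation of that order for one stage (illustrative only; the real schedule also interleaves the sends and receives shown in the hunks above):

def one_f_one_b_order(world_size, rank, num_microbatches):
    warmup = min(world_size - rank - 1, num_microbatches)
    remaining = num_microbatches - warmup
    ops = [('F', i) for i in range(warmup)]
    for i in range(remaining):          # steady state: one forward, one backward
        ops.append(('F', warmup + i))
        ops.append(('B', i))
    ops += [('B', remaining + i) for i in range(warmup)]  # cooldown
    return ops

# Rank 2 of 4 with 4 microbatches: one warmup forward, then alternation.
assert one_f_one_b_order(4, 2, 4) == [
    ('F', 0), ('F', 1), ('B', 0), ('F', 2), ('B', 1), ('F', 3), ('B', 2), ('B', 3)]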
@@ -349,6 +346,7 @@ class PipelineSchedule(BaseSchedule):

 class InterleavedPipelineSchedule(PipelineSchedule):
     def __init__(self,
                  num_microbatches,
                  num_model_chunks,
@@ -372,21 +370,19 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         """
         assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
             'num_microbatches must be an integer multiple of pipeline parallel world size'
-        super().__init__(num_microbatches, batch_data_process_func=batch_data_process_func,
-                         tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather_tensors)
+        super().__init__(num_microbatches,
+                         batch_data_process_func=batch_data_process_func,
+                         tensor_shape=tensor_shape,
+                         scatter_gather_tensors=scatter_gather_tensors)
         gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
         gpc.set_virtual_pipeline_parallel_rank(0)
         self.num_model_chunks = num_model_chunks

     def pre_processing(self, engine):
-        if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
-            raise TypeError(
-                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
-            )
-        if isinstance(engine.model[0], NaiveAMPModel):
+        if isinstance(engine.model, ShardedModelV2):
+            self.dtype = torch.half
+        elif isinstance(engine.model[0], NaiveAMPModel):
             self.dtype = torch.half
         for model in engine.model:
             if isinstance(model, NaiveAMPModel):
                 model = model.model
@@ -405,7 +401,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
         return self._move_to_device(data), self._move_to_device(label)

-    def forward_step(self, engine, model_chunk_id, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
+    def forward_step(self,
+                     engine,
+                     model_chunk_id,
+                     input_tensor,
+                     return_tensors,
+                     return_output_label=True,
+                     accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.
@@ -425,9 +427,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             return output_tensor
         else:
-            assert isinstance(output_tensor, torch.Tensor), \
-                'Output of model using pipeline parallelism must be a tensor (except the last stage).'
-            self._logger.debug(
-                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
+            assert isinstance(
+                output_tensor, torch.Tensor
+            ), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} '
+                f'forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor

     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
@@ -488,10 +492,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         else:
             num_warmup_microbatches = \
                 (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
             num_warmup_microbatches += (
                 num_model_chunks - 1) * pipeline_parallel_size
             num_warmup_microbatches = min(num_warmup_microbatches,
                                           num_microbatches)
             num_microbatches_remaining = \
                 num_microbatches - num_warmup_microbatches
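
For the interleaved schedule the warmup count from this hunk is larger: (size - rank - 1) * 2 plus one extra round of pipeline_parallel_size microbatches per additional model chunk, again capped at num_microbatches. A standalone check of that formula (numbers are illustrative):

def interleaved_warmup(size, rank, num_model_chunks, num_microbatches):
    n = (size - rank - 1) * 2
    n += (num_model_chunks - 1) * size
    return min(n, num_microbatches)

# 4 stages, 2 chunks per stage, 16 microbatches:
# rank 0 warms up with 3 * 2 + 1 * 4 = 10 microbatches, rank 3 with 4.
assert interleaved_warmup(4, 0, 2, 16) == 10
assert interleaved_warmup(4, 3, 2, 16) == 4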
@@ -516,8 +518,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     len(output_tensors[model_chunk_id]):
                 input_tensors[model_chunk_id].append(None)
             input_tensor = input_tensors[model_chunk_id][-1]
-            output_tensor = self.forward_step(engine, model_chunk_id, input_tensor,
-                                              return_tensors, return_output_label=return_output_label, accum_loss=accum_loss)
+            output_tensor = self.forward_step(engine,
+                                              model_chunk_id,
+                                              input_tensor,
+                                              return_tensors,
+                                              return_output_label=return_output_label,
+                                              accum_loss=accum_loss)
             output_tensors[model_chunk_id].append(output_tensor)

             # if forward-only, no need to save tensors for a backward pass
@@ -548,18 +554,20 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         gpc.set_virtual_pipeline_parallel_rank(0)
         if not gpc.is_pipeline_first_stage():
             input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
-        input_tensors[0].append(comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype,
-                                                  scatter_gather_tensors=self.scatter_gather_tensors))
+        input_tensors[0].append(
+            comm.recv_forward(input_tensor_shapes[0], dtype=self.dtype,
+                              scatter_gather_tensors=self.scatter_gather_tensors))

         for k in range(num_warmup_microbatches):
             model_chunk_id = get_model_chunk_id(k, forward=True)
             output_tensor = forward_step_helper(k)
             if not gpc.is_pipeline_last_stage():
                 output_tensor_shapes[model_chunk_id] = output_tensor.shape
                 send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(
                     output_tensor,
                     send_tensor_shape_flags[model_chunk_id])
             # Determine if tensor should be received from previous stage.
             next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
             recv_prev = True
             if gpc.is_pipeline_first_stage(ignore_virtual=True):
                 if next_forward_model_chunk_id == 0:
@@ -584,7 +592,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                 recv_next = True
                 if gpc.is_pipeline_last_stage(ignore_virtual=True):
                     recv_next = False
                 output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None
                 input_tensor, output_tensor_grad = \
                     comm.send_forward_backward_recv_forward_backward(
                         output_tensor, input_tensor_grad,
@@ -593,7 +601,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                         recv_prev=recv_prev, recv_next=recv_next,
                         dtype=self.dtype,
                         scatter_gather_tensors=self.scatter_gather_tensors)
                 output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
             else:
                 input_tensor = \
                     comm.send_forward_recv_forward(
@@ -634,26 +642,23 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             recv_prev = True
             if gpc.is_pipeline_first_stage(ignore_virtual=True):
                 # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                 next_forward_model_chunk_id = get_model_chunk_id(
                     forward_k - (pipeline_parallel_size - 1), forward=True)
                 if next_forward_model_chunk_id == (num_model_chunks - 1):
                     recv_prev = False
                 next_forward_model_chunk_id += 1
             else:
                 next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                                  forward=True)

             recv_next = True
             if gpc.is_pipeline_last_stage(ignore_virtual=True):
                 # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                 next_backward_model_chunk_id = get_model_chunk_id(
                     backward_k - (pipeline_parallel_size - 1),
                     forward=False)
                 if next_backward_model_chunk_id == 0:
                     recv_next = False
                 next_backward_model_chunk_id -= 1
             else:
                 next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                                   forward=False)

             # If last iteration, don't receive; we already received one extra
             # before the start of the for loop.
@@ -677,17 +682,17 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             if recv_prev:
                 input_tensors[next_forward_model_chunk_id].append(input_tensor)
             if recv_next:
                 output_tensor_grads[next_backward_model_chunk_id].append(
                     output_tensor_grad)

         # Run cooldown backward passes (flush out pipeline).
         if not forward_only:
             if all_warmup_microbatches:
                 output_tensor_grads[num_model_chunks - 1].append(
                     comm.recv_backward(output_tensor_shapes[num_model_chunks - 1], scatter_gather_tensors=self.scatter_gather_tensors))
             for k in range(num_microbatches_remaining, num_microbatches):
                 input_tensor_grad = backward_step_helper(k)
                 next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
                 recv_next = True
                 if gpc.is_pipeline_last_stage(ignore_virtual=True):
                     if next_backward_model_chunk_id == (num_model_chunks - 1):
@@ -696,12 +701,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
                 output_tensor_grads[next_backward_model_chunk_id].append(
                     comm.send_backward_recv_backward(
                         input_tensor_grad,
                         output_shape,
                         recv_next=recv_next,
                         dtype=self.dtype,
                         scatter_gather_tensors=self.scatter_gather_tensors))

         if len(return_tensors) > 0:
             output, label = pack_return_tensors(return_tensors)
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -262,3 +262,15 @@ class ShardedModelV2(nn.Module):
     def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
         raise NotImplementedError
+
+    def __getitem__(self, idx: int):
+        assert isinstance(self.module, nn.ModuleList)
+        return self.module[idx]
+
+    def __len__(self):
+        assert isinstance(self.module, nn.ModuleList)
+        return len(self.module)
+
+    def __iter__(self):
+        assert isinstance(self.module, nn.ModuleList)
+        return iter(self.module)
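
These three dunder methods are what lets the schedules treat a ZeRO-sharded model as a list of pipeline chunks: the interleaved pre_processing indexes engine.model[0] and iterates over engine.model, which now also works when engine.model is a ShardedModelV2 wrapping an nn.ModuleList. A minimal sketch of the same container protocol on a stand-in wrapper (not the real class):

import torch.nn as nn

class ListWrapper(nn.Module):
    # Mimics the ModuleList pass-through added to ShardedModelV2.
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def __getitem__(self, idx: int):
        assert isinstance(self.module, nn.ModuleList)
        return self.module[idx]

    def __len__(self):
        assert isinstance(self.module, nn.ModuleList)
        return len(self.module)

    def __iter__(self):
        assert isinstance(self.module, nn.ModuleList)
        return iter(self.module)

chunks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
wrapped = ListWrapper(chunks)
assert len(wrapped) == 2 and wrapped[0] is chunks[0]
assert list(iter(wrapped)) == list(chunks)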