ColossalAI · Commit 11f54c7b (unverified)

[doc] improved docstring and assertion messages for the engine module (#871)

Authored Apr 26, 2022 by Frank Lee; committed by GitHub on Apr 26, 2022
Parent: 1c343826
Showing 9 changed files with 180 additions and 60 deletions (+180 -60)
  +6   -6   colossalai/engine/_base_engine.py
+108  -16   colossalai/engine/gradient_accumulation/_gradient_accumulation.py
  +4   -0   colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py
  +4   -1   colossalai/engine/gradient_handler/_moe_gradient_handler.py
  +4   -0   colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
  +4   -0   colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
  +4   -0   colossalai/engine/gradient_handler/_zero_gradient_handler.py
  +5   -1   colossalai/engine/paramhooks/_param_hookmgr.py
 +41  -36   colossalai/engine/schedule/_pipeline_schedule.py
colossalai/engine/_base_engine.py

 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-from asyncio.log import logger
 from typing import List, Iterable
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
-from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
...
@@ -23,7 +21,7 @@ class Engine:
     Args:
         model (``torch.nn.Module``): The neural network model.
-        optimizer (``torch.optim.Optimizer``): Optimizer for updating the parameters.
+        optimizer (``colossalai.nn.optimizer.ColossalaiOptimizer``): Optimizer for updating the parameters.
         criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss.
         gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handler used in backward.
         clip_grad_norm (float, optional): The norm of gradient clipping.
...
@@ -57,7 +55,7 @@ class Engine:
     def __init__(self,
                  model: Module,
-                 optimizer: Optimizer,
+                 optimizer: "ColossalaiOptimizer",
                  criterion: Optional[_Loss] = None,
                  gradient_handlers: Optional[List[BaseGradientHandler]] = None,
                  clip_grad_norm: float = 0.0,
...
@@ -84,9 +82,11 @@ class Engine:
             self._ophook_list = []
         else:
             self._ophook_list = ophook_list

         # build schedule
         if schedule:
+            assert isinstance(schedule, BaseSchedule), \
+                f'expected schedule to be of type BaseSchedule, but got {type(schedule)}'
             self._schedule = schedule
         else:
             self._schedule = NonPipelineSchedule()
...
@@ -187,7 +187,7 @@ class Engine:
         """
         for handler in self._gradient_handlers:
             handler.handle_gradient()

     def execute_schedule(self, data_iter: Iterable, **kwargs):
         """Run the forward, loss computation, and backward for the model.
         Returns a tuple of (output, label, loss).
...
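Note: the schedule type check added above follows the pattern this commit applies throughout the engine module: an isinstance assertion whose message reports both the expected type and the type actually received. A minimal, self-contained sketch of that pattern (the BaseSchedule stand-in and build_schedule helper below are illustrative placeholders, not the real ColossalAI classes):

# Sketch of the assertion style used in this commit; BaseSchedule here is a stand-in,
# not colossalai.engine.schedule.BaseSchedule.

class BaseSchedule:
    """Placeholder base class for schedules."""


def build_schedule(schedule=None):
    # Fail early with a message that names the expected type and the actual type received.
    if schedule is not None:
        assert isinstance(schedule, BaseSchedule), \
            f'expected schedule to be of type BaseSchedule, but got {type(schedule)}'
        return schedule
    return BaseSchedule()


if __name__ == '__main__':
    build_schedule(BaseSchedule())   # passes
    try:
        build_schedule(schedule=42)  # raises with an informative message
    except AssertionError as exc:
        print(exc)                   # expected schedule to be of type BaseSchedule, but got <class 'int'>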
colossalai/engine/gradient_accumulation/_gradient_accumulation.py

 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+from typing import Union
 import torch.nn as nn
 from torch import Tensor
-from typing import Iterable, Any
+from typing import Iterable, Any, Tuple
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
...
@@ -33,24 +34,54 @@ class GradAccumOptimizer(ColossalaiOptimizer):
         self.model = model
         self.is_torch_ddp = isinstance(self.model, DistributedDataParallel)

-    def zero_grad(self, *args, **kwargs):
+    def zero_grad(self, *args, **kwargs) -> None:
+        """
+        Set all gradients to zero.
+
+        Args:
+            *args: positional arguments for the optimizer wrapped
+            **kwargs: keyword arguments for the optimizer wrapped
+        """
         if self.accumulate_step == 0:
             self.optim.zero_grad(*args, **kwargs)

-    def step(self, *args, **kwargs):
+    def step(self, *args, **kwargs) -> None:
+        """
+        Update the model parameters.
+
+        Args:
+            *args: positional arguments for the optimizer wrapped
+            **kwargs: keyword arguments for the optimizer wrapped
+        """
         if self.accumulate_step < self.accumulate_size:
             return None
         else:
             self.accumulate_step = 0
             return self.optim.step(*args, **kwargs)

-    def clip_grad_norm(self, model: nn.Module, max_norm: float):
+    def clip_grad_norm(self, model: nn.Module, max_norm: float) -> None:
+        """
+        Clip gradients by norm.
+
+        Args:
+            model (:class:`torch.nn.Module`): a torch module instance
+            max_norm (float): the max norm for gradient clipping
+        """
         if self.accumulate_step < self.accumulate_size:
             pass
         else:
             self.optim.clip_grad_norm(model, max_norm)

-    def backward(self, loss: Tensor):
+    def backward(self, loss: Tensor) -> None:
+        """Execute backward pass.
+
+        Args:
+            loss (:class:`torch.Tensor`): the loss value.
+        """
         self.accumulate_step += 1

         if self.is_torch_ddp:
...
@@ -62,7 +93,14 @@ class GradAccumOptimizer(ColossalaiOptimizer):
             scaled_loss = loss / self.accumulate_size
             self.optim.backward(scaled_loss)

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
+        """Execute backward pass given the gradients of the output.
+
+        Args:
+            loss (:class:`torch.Tensor`): the loss value.
+            grad (:class:`torch.Tensor`): the output gradient.
+        """
         self.accumulate_step += 1
         no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
...
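Note: the wrapper's accumulation logic is easy to replicate outside ColossalAI: backward scales the loss by the accumulation size and bumps a counter, while step and zero_grad only touch the underlying optimizer once per accumulation cycle. A self-contained sketch of that behaviour with a plain PyTorch optimizer (the AccumWrapper class is an illustrative stand-in, not the ColossalAI API):

import torch
import torch.nn as nn
from torch.optim import SGD

class AccumWrapper:
    """Illustrative gradient-accumulation wrapper mirroring the counter logic shown above."""

    def __init__(self, optim, accumulate_size: int):
        self.optim = optim
        self.accumulate_size = accumulate_size
        self.accumulate_step = 0

    def zero_grad(self):
        # Only clear gradients at the start of an accumulation cycle.
        if self.accumulate_step == 0:
            self.optim.zero_grad()

    def backward(self, loss: torch.Tensor):
        # Scale so that the summed gradients match the average over the cycle.
        self.accumulate_step += 1
        (loss / self.accumulate_size).backward()

    def step(self):
        # Apply the update only on the last micro-step of the cycle.
        if self.accumulate_step < self.accumulate_size:
            return
        self.accumulate_step = 0
        self.optim.step()

model = nn.Linear(4, 1)
opt = AccumWrapper(SGD(model.parameters(), lr=0.1), accumulate_size=4)
for batch in torch.randn(8, 4).split(2):   # 4 micro-batches of size 2
    opt.zero_grad()
    loss = model(batch).pow(2).mean()
    opt.backward(loss)
    opt.step()                              # parameters change only on every 4th call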
@@ -84,7 +122,7 @@ class GradAccumDataloader:
...
@@ -84,7 +122,7 @@ class GradAccumDataloader:
(e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
(e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
Args:
Args:
optim
(``Iterable``): Your dataloader object for gradient accumulation.
dataloader
(``Iterable``): Your dataloader object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
accumulate_size (int): The number of steps to accumulate gradients.
"""
"""
...
@@ -96,15 +134,15 @@ class GradAccumDataloader:
...
@@ -96,15 +134,15 @@ class GradAccumDataloader:
def
__getattr__
(
self
,
__name
:
str
)
->
Any
:
def
__getattr__
(
self
,
__name
:
str
)
->
Any
:
return
getattr
(
self
.
dataloader
,
__name
)
return
getattr
(
self
.
dataloader
,
__name
)
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
self
.
steps_per_epoch
return
self
.
steps_per_epoch
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterable
:
self
.
_cur_step
=
0
self
.
_cur_step
=
0
self
.
_dataiter
=
iter
(
self
.
dataloader
)
self
.
_dataiter
=
iter
(
self
.
dataloader
)
return
self
return
self
def
__next__
(
self
)
->
Any
:
def
__next__
(
self
)
->
Union
[
Tensor
,
Tuple
[
Tensor
]]
:
if
self
.
_cur_step
<
self
.
steps_per_epoch
:
if
self
.
_cur_step
<
self
.
steps_per_epoch
:
self
.
_cur_step
+=
1
self
.
_cur_step
+=
1
...
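Note: GradAccumDataloader truncates an epoch to a multiple of the accumulation size, yielding only steps_per_epoch batches and ignoring the remainder. A minimal stand-alone sketch of that iterator behaviour (the TruncatedLoader name and the trailing-batch handling are assumptions based on the docstring above, not the exact ColossalAI implementation):

class TruncatedLoader:
    """Yield only the first steps_per_epoch batches of the wrapped dataloader."""

    def __init__(self, dataloader, accumulate_size: int):
        self.dataloader = dataloader
        self.steps_per_epoch = len(dataloader) // accumulate_size * accumulate_size

    def __len__(self) -> int:
        return self.steps_per_epoch

    def __iter__(self):
        self._cur_step = 0
        self._dataiter = iter(self.dataloader)
        return self

    def __next__(self):
        if self._cur_step < self.steps_per_epoch:
            self._cur_step += 1
            return next(self._dataiter)
        raise StopIteration


loader = TruncatedLoader(list(range(10)), accumulate_size=4)
print(len(loader), list(loader))   # 8 [0, 1, 2, 3, 4, 5, 6, 7] -- the last 2 batches are dropped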
@@ -137,13 +175,30 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
         self.accumulate_step = 0

     @staticmethod
-    def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int):
+    def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int) -> int:
+        """
+        Computes the number of effective training iterations. An effective iteration is defined
+        as the aggregation of <accumulate_size> iterations. For example, if accumulate_size = 4,
+        then 4 iterations are considered as one effective iteration.
+
+        Args:
+            dataloader (``Iterable``): Your dataloader object for gradient accumulation.
+            accumulate_size (int): The number of steps to accumulate gradients.
+        """
         return len(dataloader) // accumulate_size

     def __getattr__(self, __name: str) -> Any:
         return getattr(self.lr_scheduler, __name)

-    def step(self, *args, **kwargs):
+    def step(self, *args, **kwargs) -> None:
+        """
+        Update the learning rate.
+
+        Args:
+            *args: positional arguments for the lr scheduler wrapped.
+            **kwargs: keyword arguments for the lr scheduler wrapped.
+        """
         self.accumulate_step += 1
         if self.accumulate_step < self.accumulate_size:
             pass
...
@@ -151,19 +206,52 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
             self.accumulate_step = 0
             self.lr_scheduler.step(*args, **kwargs)

-    def get_lr(self):
+    def get_lr(self) -> Tensor:
+        """
+        Compute the next learning rate.
+
+        Returns:
+            Tensor: the upcoming learning rate.
+        """
         return self.lr_scheduler.get_lr()

-    def get_last_lr(self):
+    def get_last_lr(self) -> Tensor:
+        """
+        Returns the current learning rate.
+
+        Returns:
+            Tensor: the current learning rate.
+        """
         return self.lr_scheduler.get_last_lr()

-    def print_lr(self, *args, **kwargs):
+    def print_lr(self, *args, **kwargs) -> None:
+        """
+        Print the learning rate.
+
+        Args:
+            *args: positional arguments for the lr scheduler wrapped.
+            **kwargs: keyword arguments for the lr scheduler wrapped.
+        """
         self.lr_scheduler.print_lr(*args, **kwargs)

     def state_dict(self) -> dict:
+        """
+        Returns the states of the lr scheduler as dictionary.
+
+        Returns:
+            dict: the states of the lr scheduler.
+        """
         return self.lr_scheduler.state_dict()

     def load_state_dict(self, state_dict: dict) -> None:
+        """
+        Load the states of the lr scheduler from a dictionary object.
+
+        Returns:
+            dict: the states of the lr scheduler.
+        """
         self.lr_scheduler.load_state_dict(state_dict)
...
@@ -188,7 +276,11 @@ class GradAccumGradientHandler:
         self.accumulate_size = accumulate_size
         self.accumulate_step = 0

-    def handle_gradient(self):
+    def handle_gradient(self) -> None:
+        """
+        Handle gradients reduction only in the last gradient accumulation step.
+        """
         self.accumulate_step += 1
         if self.accumulate_step < self.accumulate_size:
             pass
...
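Note: the relationship between raw dataloader length, accumulation size, and effective steps is integer division, which also determines how often GradAccumLrSchedulerByStep actually advances the wrapped scheduler. A small worked sketch of that behaviour (PyTorch's StepLR is used purely as an example scheduler; the loop below is not the ColossalAI wrapper itself):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

accumulate_size = 4
dataloader = list(range(10))                        # stands in for a 10-batch dataloader
effective_steps = len(dataloader) // accumulate_size
print(effective_steps)                              # 2 effective iterations per epoch

# The by-step wrapper only forwards every accumulate_size-th call to the real scheduler,
# so the learning rate moves once per effective iteration rather than once per micro-batch.
opt = SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
sched = StepLR(opt, step_size=1, gamma=0.5)
accumulate_step = 0
for _ in dataloader:
    accumulate_step += 1
    if accumulate_step == accumulate_size:
        accumulate_step = 0
        sched.step()
print(sched.get_last_lr())                          # [0.25] after 2 scheduler steps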
colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py

@@ -12,6 +12,10 @@ class DataParallelGradientHandler(BaseGradientHandler):
     :func:`handle_gradient` among a data parallel group.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
+
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
     """

     def handle_gradient(self):
...
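Note: the "bucketize by type, then all-reduce" behaviour these handlers describe can be sketched with plain torch.distributed: group gradients by dtype, flatten each group into one buffer, reduce it, and copy the results back. This is a simplified illustration of the idea, not the ColossalAI handler itself; it assumes torch.distributed has already been initialized:

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_gradients_bucketed(model: torch.nn.Module, group=None):
    # Bucket gradients of the same dtype together so each dtype needs one collective call.
    buckets = {}
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            buckets.setdefault(p.grad.dtype, []).append(p.grad.data)

    world_size = dist.get_world_size(group)
    for grads in buckets.values():
        flat = _flatten_dense_tensors(grads)        # one contiguous buffer per bucket
        dist.all_reduce(flat, group=group)
        flat.div_(world_size)                       # average across the data parallel group
        for g, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
            g.copy_(synced)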
colossalai/engine/gradient_handler/_moe_gradient_handler.py

@@ -14,6 +14,10 @@ class MoeGradientHandler(BaseGradientHandler):
     :func:`handle_gradient` among a data parallel group.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
+
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
     """

     def __init__(self, model, optimizer=None):
...
@@ -29,7 +33,6 @@ class MoeGradientHandler(BaseGradientHandler):
         if global_data > 1:
             epsize_param_dict = get_moe_epsize_param_dict(self._model)
             # epsize is 1, indicating the params are replicated among processes in data parallelism
             # use the ParallelMode.DATA to get data parallel group
             # reduce gradients for all parameters in data parallelism
...
colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py

@@ -18,6 +18,10 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
     :func:`handle_gradient` among all sub pipeline parallel groups.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
+
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
     """

     def handle_gradient(self):
...
colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py

@@ -12,6 +12,10 @@ class SequenceParallelGradientHandler(BaseGradientHandler):
     :func:`handle_gradient` among a data parallel group.
     For better performance, it bucketizes the gradients of all parameters that are
     the same type to improve the efficiency of communication.
+
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
     """

     def handle_gradient(self):
...
colossalai/engine/gradient_handler/_zero_gradient_handler.py

@@ -8,6 +8,10 @@ class ZeROGradientHandler(BaseGradientHandler):
     A all-reduce collective communication will be operated in
     :func:`handle_gradient` among a data parallel group.
     This class is specialized with ZeRO optimization.
+
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
     """

     def handle_gradient(self):
...
colossalai/engine/paramhooks/_param_hookmgr.py

@@ -28,7 +28,11 @@ class BaseParamHookMgr(object):
             handle = p.register_hook(functools.partial(hook_call, p))
             p._base_param_hook = handle

-    def remove_hooks(self):
+    def remove_hooks(self) -> None:
+        """
+        Remove hooks from model parameters.
+        """
         for p in self._param_list:
             if p.requires_grad and hasattr(p, '_base_param_hook'):
                 p._base_param_hook.remove()
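Note: the hook manager relies on the standard PyTorch tensor-hook API: Tensor.register_hook returns a removable handle, which is stashed on the parameter so remove_hooks can detach it later. A small stand-alone sketch of that register/remove cycle (the printing hook is just an example, not the ColossalAI hook):

import functools
import torch

def hook_call(param, grad):
    # Example hook: report which parameter produced a gradient.
    print(f'got grad of shape {tuple(grad.shape)} for a parameter of shape {tuple(param.shape)}')
    return grad

p = torch.nn.Parameter(torch.randn(3))
handle = p.register_hook(functools.partial(hook_call, p))   # same pattern as BaseParamHookMgr
p._base_param_hook = handle

(p * 2.0).sum().backward()      # hook fires during backward
p._base_param_hook.remove()     # mirrors remove_hooks(): detach the hook again
p.grad = None
(p * 2.0).sum().backward()      # no hook output this time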
colossalai/engine/schedule/_pipeline_schedule.py

@@ -81,6 +81,9 @@ class PipelineSchedule(BaseSchedule):
                  tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
                  scatter_gather_tensors: bool = False):
         super().__init__(batch_data_process_func=batch_data_process_func)
+
+        assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}'
+
         self.num_microbatches = num_microbatches
         self.dtype = torch.float
         self.tensor_shape = tensor_shape
...
@@ -150,7 +153,7 @@ class PipelineSchedule(BaseSchedule):
         else:
             return model(input_tensor, **batch_data)

-    def forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
+    def _forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.
...
@@ -186,7 +189,7 @@ class PipelineSchedule(BaseSchedule):
             )
         return output_tensor

-    def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
+    def _backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
         """Backward step through the passed-in output tensor. If it is the last stage, the
         output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
         Returns the gradients with respect to the input tensor (None if first stage).
...
@@ -267,11 +270,11 @@ class PipelineSchedule(BaseSchedule):
             input_tensor = comm.recv_forward(ft_shape,
                                              dtype=self.dtype,
                                              scatter_gather_tensors=self.scatter_gather_tensors)
-            output_tensor = self.forward_step(engine,
-                                              input_tensor,
-                                              return_tensors,
-                                              return_output_label=return_output_label,
-                                              accum_loss=accum_loss)
+            output_tensor = self._forward_step(engine,
+                                               input_tensor,
+                                               return_tensors,
+                                               return_output_label=return_output_label,
+                                               accum_loss=accum_loss)
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 bt_shape = output_tensor.shape
                 fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
...
@@ -295,11 +298,11 @@ class PipelineSchedule(BaseSchedule):
         for i in range(num_microbatches_remaining):
             last_iteration = (i == (num_microbatches_remaining - 1))

-            output_tensor = self.forward_step(engine,
-                                              input_tensor,
-                                              return_tensors,
-                                              return_output_label=return_output_label,
-                                              accum_loss=accum_loss)
+            output_tensor = self._forward_step(engine,
+                                               input_tensor,
+                                               return_tensors,
+                                               return_output_label=return_output_label,
+                                               accum_loss=accum_loss)
             if forward_only:
                 comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
...
@@ -323,7 +326,7 @@ class PipelineSchedule(BaseSchedule):
                 input_tensor = input_tensors.pop(0)
                 output_tensor = output_tensors.pop(0)

-                input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
+                input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)

                 if last_iteration:
                     input_tensor = None
...
@@ -344,7 +347,7 @@ class PipelineSchedule(BaseSchedule):
                                                   dtype=self.dtype,
                                                   scatter_gather_tensors=self.scatter_gather_tensors)

-                input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
+                input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)

                 comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
...
@@ -358,8 +361,8 @@ class PipelineSchedule(BaseSchedule):
 class InterleavedPipelineSchedule(PipelineSchedule):

     def __init__(self,
-                 num_microbatches,
-                 num_model_chunks,
+                 num_microbatches: int,
+                 num_model_chunks: int,
                  batch_data_process_func: Callable = None,
                  tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
                  scatter_gather_tensors: bool = False):
...
@@ -378,6 +381,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         """
         assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
             'num_microbatches must be an integer multiple of pipeline parallel world size'
+        assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
+            f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
         super().__init__(num_microbatches,
                          batch_data_process_func=batch_data_process_func,
                          tensor_shape=tensor_shape,
...
@@ -409,13 +414,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             self.microbatch_offset[model_chunk_id] += self.microbatch_size
         return self._move_to_device(data), self._move_to_device(label)

-    def forward_step(self,
-                     engine,
-                     model_chunk_id,
-                     input_tensor,
-                     return_tensors,
-                     return_output_label=True,
-                     accum_loss=None):
+    def _forward_step(self,
+                      engine,
+                      model_chunk_id,
+                      input_tensor,
+                      return_tensors,
+                      return_output_label=True,
+                      accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.
...
@@ -522,7 +527,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                 model_chunk_id = (num_model_chunks - model_chunk_id - 1)
             return model_chunk_id

-        def forward_step_helper(microbatch_id):
+        def _forward_step_helper(microbatch_id):
             """Helper method to run forward step with model split into chunks
             (run set_virtual_pipeline_model_parallel_rank() before calling
             forward_step())."""
...
@@ -535,12 +540,12 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     len(output_tensors[model_chunk_id]):
                 input_tensors[model_chunk_id].append(None)
             input_tensor = input_tensors[model_chunk_id][-1]
-            output_tensor = self.forward_step(engine,
-                                              model_chunk_id,
-                                              input_tensor,
-                                              return_tensors,
-                                              return_output_label=return_output_label,
-                                              accum_loss=accum_loss)
+            output_tensor = self._forward_step(engine,
+                                               model_chunk_id,
+                                               input_tensor,
+                                               return_tensors,
+                                               return_output_label=return_output_label,
+                                               accum_loss=accum_loss)
             output_tensors[model_chunk_id].append(output_tensor)

             # if forward-only, no need to save tensors for a backward pass
...
@@ -550,7 +555,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             return output_tensor

-        def backward_step_helper(microbatch_id):
+        def _backward_step_helper(microbatch_id):
             """Helper method to run backward step with model split into chunks
             (run set_virtual_pipeline_model_parallel_rank() before calling
             backward_step())."""
...
@@ -563,7 +568,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             input_tensor = input_tensors[model_chunk_id].pop(0)
             output_tensor = output_tensors[model_chunk_id].pop(0)
             output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
-            input_tensor_grad = self.backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
+            input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)

             return input_tensor_grad
...
@@ -578,7 +583,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         for k in range(num_warmup_microbatches):
             model_chunk_id = get_model_chunk_id(k, forward=True)
-            output_tensor = forward_step_helper(k)
+            output_tensor = _forward_step_helper(k)
             if not gpc.is_pipeline_last_stage():
                 output_tensor_shapes[model_chunk_id] = output_tensor.shape
                 send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor,
...
@@ -633,11 +638,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         for k in range(num_microbatches_remaining):
             # Forward pass.
             forward_k = k + num_warmup_microbatches
-            output_tensor = forward_step_helper(forward_k)
+            output_tensor = _forward_step_helper(forward_k)

             # Backward pass.
             backward_k = k
-            input_tensor_grad = backward_step_helper(backward_k)
+            input_tensor_grad = _backward_step_helper(backward_k)

             # Send output_tensor and input_tensor_grad, receive input_tensor
             # and output_tensor_grad.
...
@@ -708,7 +713,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     comm.recv_backward(output_tensor_shapes[num_model_chunks - 1],
                                        scatter_gather_tensors=self.scatter_gather_tensors))

         for k in range(num_microbatches_remaining, num_microbatches):
-            input_tensor_grad = backward_step_helper(k)
+            input_tensor_grad = _backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
             recv_next = True
             if gpc.is_pipeline_last_stage(ignore_virtual=True):
...
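Note: the assertions added here guard the schedule configuration: the number of microbatches must be positive (and, for the interleaved schedule, divisible by the pipeline parallel world size), and the number of model chunks must be a positive integer. A minimal stand-alone sketch of the same style of validation plus the microbatch slicing it protects (the helper below is illustrative, not part of ColossalAI):

import torch

def split_into_microbatches(batch: torch.Tensor, num_microbatches: int, pipeline_world_size: int):
    # Same style of checks as the commit adds to PipelineSchedule / InterleavedPipelineSchedule.
    assert num_microbatches > 0, \
        f'expected num_microbatches to be larger than 0, but got {num_microbatches}'
    assert num_microbatches % pipeline_world_size == 0, \
        'num_microbatches must be an integer multiple of pipeline parallel world size'
    assert batch.size(0) % num_microbatches == 0, \
        f'batch size {batch.size(0)} is not divisible by num_microbatches {num_microbatches}'
    microbatch_size = batch.size(0) // num_microbatches
    return [batch[i * microbatch_size:(i + 1) * microbatch_size] for i in range(num_microbatches)]


batch = torch.randn(16, 8)
microbatches = split_into_microbatches(batch, num_microbatches=4, pipeline_world_size=2)
print(len(microbatches), microbatches[0].shape)   # 4 torch.Size([4, 8])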