OpenDAS / ColossalAI
Commit c008d4ad, authored Feb 20, 2023 by Michelle; committed by GitHub, Feb 20, 2023
[NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744)
parent 58abde28
Showing 1 changed file with 8 additions and 6 deletions
colossalai/engine/schedule/_pipeline_schedule.py (+8, -6)
...
...
@@ -4,8 +4,9 @@
 import inspect
 from typing import Callable, List, Tuple, Union
-import colossalai.communication as comm
 import torch.cuda
+import colossalai.communication as comm
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
...
...
@@ -72,9 +73,9 @@ class PipelineSchedule(BaseSchedule):
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
scatter_gather_tensors (bool, optional):
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
Example:
# this shows an example of customized data_process_func
def data_process_func(stage_output, dataloader_output):
output1, output2 = stage_output
...
...
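As a rough illustration of the `data_process_func` hook described in the docstring above, the sketch below combines the previous stage's output with a dataloader batch and returns a `(data, label)` pair. The three-item batch layout, the choice of which items to keep, and the return shape are assumptions made up for this example; the elided part of the docstring may differ.

```python
def data_process_func(stage_output, dataloader_output):
    # Output of the previous pipeline stage (two values, as in the docstring).
    output1, output2 = stage_output
    # Hypothetical batch layout: three items, the second is unused here.
    item1, _unused, item3 = dataloader_output
    # Assumed convention: forward inputs go in `data`, the target in `label`.
    data = (output1, output2, item1)
    label = item3
    return data, label


# Placeholder strings stand in for tensors.
data, label = data_process_func(("h1", "h2"), ("x", "mask", "y"))
```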
@@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
         if isinstance(model, NaiveAMPModel):
...
...
@@ -229,7 +231,7 @@ class PipelineSchedule(BaseSchedule):
         return data, label
     def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
-        """Forward step for passed-in model. If it is the first stage, the input tensor
+        """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_obj is used.
         Returns output tensor. This is a helper function and can be ignored by users.
...
...
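The input-selection rule in the `_forward_step` docstring can be sketched in plain Python. This is a toy model, not the ColossalAI implementation: `forward_step_sketch`, `stage_fn`, and `is_first_stage` are hypothetical names, and integers stand in for tensors.

```python
from typing import Any, Callable, Iterator, Optional


def forward_step_sketch(stage_fn: Callable[[Any], Any],
                        data_iterator: Iterator[Any],
                        input_obj: Optional[Any],
                        is_first_stage: bool) -> Any:
    # First stage: nothing arrives from upstream, so pull a micro-batch locally.
    if is_first_stage:
        input_obj = next(data_iterator)
    # Every stage then runs its own slice of the model on the chosen input.
    return stage_fn(input_obj)


# Two toy stages chained by hand; real stages would exchange tensors.
first = forward_step_sketch(lambda x: x + 1, iter([10, 20]), None, True)
later = forward_step_sketch(lambda x: x * 2, iter([]), first, False)
```

Only the first stage touches the data iterator; later stages ignore it and consume whatever the previous stage produced.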
@@ -266,7 +268,7 @@ class PipelineSchedule(BaseSchedule):
         return output_obj
     def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
-        """Backward step through the passed-in output tensor. If it is the last stage, the
+        """Backward step through the passed-in output tensor. If it is the last stage, the
         output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
         Returns the gradients with respect to the input tensor (None if first stage).
         This is a helper function and can be ignored by users.
...
...
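The gradient hand-off described in the `_backward_step` docstring can be illustrated with a scalar toy example. It assumes each stage is a linear map `y = weight * x`, so the chain rule reduces to one multiplication; `backward_step_sketch` and its parameters are hypothetical and do not mirror the real method's signature.

```python
from typing import Optional


def backward_step_sketch(weight: float,
                         output_obj_grad: Optional[float],
                         is_first_stage: bool) -> Optional[float]:
    # Last stage: the output is the loss itself, so the seed gradient is 1.
    if output_obj_grad is None:
        output_obj_grad = 1.0
    # Chain rule for y = weight * x: d(loss)/dx = weight * d(loss)/dy.
    input_grad = weight * output_obj_grad
    # The first stage has no upstream stage to send a gradient to.
    return None if is_first_stage else input_grad


# Toy two-stage pipeline: y1 = 3 * x on the first stage, loss = 2 * y1 on the last.
grad_from_last = backward_step_sketch(2.0, None, is_first_stage=False)
grad_from_first = backward_step_sketch(3.0, grad_from_last, is_first_stage=True)
```

The last stage starts from `None` and sends its input gradient upstream; the first stage consumes that gradient and returns `None`, matching the two edge cases named in the docstring.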
@@ -511,7 +513,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                       return_tensors, return_output_label=True, accum_loss=None):
-        """Forward step for passed-in model. If it is the first stage, the input tensor
+        """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_obj is used.
         Returns output tensor. This is a helper function and can be ignored by users.
...
...