OpenDAS / ColossalAI · Commits

Commit 508ca36f (unverified)
Authored Sep 01, 2023 by Hongxin Liu; committed by GitHub on Sep 01, 2023
[pipeline] 1f1b schedule receive microbatch size (#4589)
Parent: 38ccb8b1
Changes: 3 changed files with 30 additions and 7 deletions (+30 −7)

- colossalai/booster/plugin/hybrid_parallel_plugin.py (+7 −1)
- colossalai/pipeline/schedule/one_f_one_b.py (+22 −5)
- tests/test_pipeline/test_schedule/test_oneF_oneB.py (+1 −1)
colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -247,6 +247,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
         initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
         min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
         growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
@@ -278,6 +281,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                  enable_jit_fused: bool = False,
                  enable_sequence_parallelism: bool = False,
                  num_microbatches: Optional[int] = None,
+                 microbatch_size: Optional[int] = None,
                  initial_scale: float = 2**16,
                  min_scale: float = 1,
                  growth_factor: float = 2,
@@ -324,7 +328,9 @@ class HybridParallelPlugin(PipelinePluginBase):
             assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
             assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
             self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
-            self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
+            self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
+                                                          num_microbatches=num_microbatches,
+                                                          microbatch_size=microbatch_size)
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
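With this change, HybridParallelPlugin forwards both knobs to the schedule, so callers can pin either the microbatch count or the microbatch size. A minimal usage sketch, not part of the commit; the `tp_size`/`pp_size` values are illustrative:

```python
# Minimal sketch, assuming a standard ColossalAI booster setup; the
# tp_size/pp_size values are illustrative, not from this commit.
from colossalai.booster.plugin import HybridParallelPlugin

# Existing style: pin the number of microbatches per batch.
plugin = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=4)

# Style enabled by this commit: pin the microbatch size instead; the
# schedule derives num_microbatches = batch_size // microbatch_size
# when each batch is loaded.
plugin = HybridParallelPlugin(tp_size=1, pp_size=2, microbatch_size=2)
```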
colossalai/pipeline/schedule/one_f_one_b.py

@@ -17,14 +17,26 @@ from .base import PipelineSchedule

 class OneForwardOneBackwardSchedule(PipelineSchedule):

-    def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None:
+    def __init__(self,
+                 stage_manager: PipelineStageManager,
+                 num_microbatches: Optional[int] = None,
+                 microbatch_size: Optional[int] = None) -> None:
+        """1F1B pipeline schedule.
+
+        Args:
+            stage_manager (PipelineStageManager): Pipeline stage manager
+            num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
+            microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
+        """
         super().__init__(stage_manager)
+        assert num_microbatches is not None or microbatch_size is not None, \
+            "Either num_microbatches or microbatch_size should be provided"
         self.comm = PipelineP2PCommunication(stage_manager)
         self.num_microbatches = num_microbatches
+        self.microbatch_size = microbatch_size
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
-        self.microbatch_size: Optional[int] = None

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -39,9 +51,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = 0
-        assert self.batch_size % self.num_microbatches == 0, \
-            "Batch size should divided by the number of microbatches"
-        self.microbatch_size = self.batch_size // self.num_microbatches
+        if self.num_microbatches is not None:
+            assert self.batch_size % self.num_microbatches == 0, \
+                "Batch size should divided by the number of microbatches"
+            self.microbatch_size = self.batch_size // self.num_microbatches
+        else:
+            assert self.batch_size % self.microbatch_size == 0, \
+                "Batch size should divided by the microbatch size"
+            self.num_microbatches = self.batch_size // self.microbatch_size

     def load_micro_batch(self) -> Any:
         """Load a micro batch from the current batch.
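The branch added to `load_batch` derives whichever quantity was omitted. A standalone sketch of that arithmetic, with plain ints standing in for the schedule's attributes:

```python
# Standalone sketch of the derivation added to load_batch(); plain ints
# stand in for self.batch_size / self.num_microbatches / self.microbatch_size.
def derive(batch_size, num_microbatches=None, microbatch_size=None):
    if num_microbatches is not None:
        # Count given: the batch must split evenly into that many microbatches.
        assert batch_size % num_microbatches == 0
        microbatch_size = batch_size // num_microbatches
    else:
        # Size given: derive the count instead.
        assert batch_size % microbatch_size == 0
        num_microbatches = batch_size // microbatch_size
    return num_microbatches, microbatch_size

# A 32-sample batch with microbatch_size=4 yields 8 microbatches;
# with num_microbatches=4 it yields microbatches of 8 samples each.
assert derive(32, microbatch_size=4) == (8, 4)
assert derive(32, num_microbatches=4) == (4, 8)
```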
tests/test_pipeline/test_schedule/test_oneF_oneB.py

@@ -61,7 +61,7 @@ def examine_pp():
     DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
     pg_mesh = ProcessGroupMesh(1, world_size, 1)
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager)
+    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)

     for idx, (_, sub_model) in enumerate(pp_model.named_children()):
         if idx % (world_size) == local_rank:
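The test still exercises the count-based path, now via the keyword. Inside the same test body, the size-based path could be exercised analogously (a hypothetical variant; `MICRO_BATCH_SIZE` is an assumed constant, not defined in the shown context):

```python
# Hypothetical variant, reusing the stage_manager from the test above;
# MICRO_BATCH_SIZE is an assumed constant, not part of this commit.
schedule = OneForwardOneBackwardSchedule(stage_manager, microbatch_size=MICRO_BATCH_SIZE)
```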