OpenDAS / ColossalAI / Commits

Unverified commit 86d22581, authored Sep 05, 2023 by Bin Jia, committed by GitHub on Sep 05, 2023

[shardformer] Add overlap optional for HybridParallelPlugin (#4615)

* add optional overlap for plugin
* remove fixed todo
Parent: a39a5c66

Showing 2 changed files with 3 additions and 3 deletions:

    colossalai/booster/plugin/hybrid_parallel_plugin.py  (+3, -1)
    colossalai/shardformer/layer/_operation.py  (+0, -2)
colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -280,6 +280,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        enable_sequence_overlap: bool = False,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         initial_scale: float = 2**16,

@@ -341,7 +342,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_fused_normalization=self.enable_fused_normalization,
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
-            enable_sequence_parallelism=enable_sequence_parallelism)
+            enable_sequence_parallelism=enable_sequence_parallelism,
+            enable_sequence_overlap=enable_sequence_overlap)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
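For context on how the new option is consumed, here is a minimal usage sketch (not part of this commit). It assumes the public HybridParallelPlugin/Booster API of ColossalAI at the time; the tp_size/pp_size values and the launch/boost calls are illustrative, and only the enable_sequence_parallelism / enable_sequence_overlap keyword arguments come from the diff above.

```python
# Hypothetical usage sketch: pass the new enable_sequence_overlap flag
# when constructing the plugin so it is forwarded into the shard config.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})  # assumes a torchrun-style launch

plugin = HybridParallelPlugin(
    tp_size=2,                          # illustrative parallel sizes
    pp_size=1,
    enable_sequence_parallelism=True,   # overlap only matters with SP enabled
    enable_sequence_overlap=True,       # flag added by this commit
)
booster = Booster(plugin=plugin)

# model, optimizer, criterion, dataloader, lr_scheduler come from user code;
# booster.boost(...) wraps them so the shard config (including the overlap
# flag) takes effect during training.
```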
colossalai/shardformer/layer/_operation.py

@@ -180,7 +180,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         overlap = ctx.overlap

         if not overlap:
-            # TODO: overlap SP input with gradient computation
             input_parallel = _gather(input_, dim, process_group)
             total_input = input_parallel

@@ -191,7 +190,6 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             grad_output = grad_output.view(-1, grad_output.shape[-1])
             total_input = total_input.view(-1, total_input.shape[-1])
-            # TODO: overlap SP input with gradient computation
             if ctx.async_grad_reduce_scatter:
                 # Asynchronous reduce-scatter
                 input_list = [
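The TODOs deleted here pointed at overlapping the gather of the sequence-parallel input with gradient computation, which the commit message ("remove fixed todo") indicates is now handled by the overlap path of this function. Below is a generic, hypothetical sketch of that pattern, not ColossalAI's actual implementation; the function name and arguments are invented, and only the idea of running an asynchronous collective while computing grad_input is taken from the code above.

```python
# Generic sketch: overlap an asynchronous all-gather of the sequence-parallel
# input with the grad_input matmul, then wait before computing grad_weight.
import torch
import torch.distributed as dist


def backward_with_overlap(grad_output, input_shard, weight, group=None):
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(input_shard) for _ in range(world_size)]

    # async_op=True returns a work handle instead of blocking.
    handle = dist.all_gather(gathered, input_shard, group=group, async_op=True)

    # grad w.r.t. the input needs only grad_output and weight, so it can run
    # while the all-gather is still in flight.
    grad_input = grad_output.matmul(weight)

    # The weight gradient needs the full (gathered) input, so wait first.
    handle.wait()
    total_input = torch.cat(gathered, dim=0)  # sequence dim assumed to be 0
    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
        total_input.reshape(-1, total_input.shape[-1]))
    return grad_input, grad_weight
```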