Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8d3250d7
Unverified
Commit
8d3250d7
authored
Mar 21, 2022
by
ver217
Committed by
GitHub
Mar 21, 2022
Browse files
[zero] ZeRO supports pipeline parallel (#477)
parent
7f5e4592
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
113 additions
and
95 deletions
+113
-95
colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
...e/gradient_handler/_pipeline_parallel_gradient_handler.py
+6
-4
colossalai/engine/schedule/_pipeline_schedule.py
colossalai/engine/schedule/_pipeline_schedule.py
+95
-91
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+12
-0
No files found.
colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
View file @
8d3250d7
#!/usr/bin/env python
#!/usr/bin/env python
import
torch.distributed
as
dist
from
collections
import
defaultdict
from
torch._utils
import
_flatten_dense_tensors
,
_unflatten_dense_tensors
import
torch
import
torch.distributed
as
dist
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.registry
import
GRADIENT_HANDLER
from
colossalai.registry
import
GRADIENT_HANDLER
from
torch._utils
import
_flatten_dense_tensors
,
_unflatten_dense_tensors
from
._base_gradient_handler
import
BaseGradientHandler
from
._base_gradient_handler
import
BaseGradientHandler
from
collections
import
defaultdict
@
GRADIENT_HANDLER
.
register_module
@
GRADIENT_HANDLER
.
register_module
...
@@ -35,7 +37,7 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
...
@@ -35,7 +37,7 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
for
group
,
group_buckets
in
buckets
.
items
():
for
group
,
group_buckets
in
buckets
.
items
():
for
tp
,
bucket
in
group_buckets
.
items
():
for
tp
,
bucket
in
group_buckets
.
items
():
grads
=
[
param
.
grad
.
data
for
param
in
bucket
]
grads
=
[
param
.
grad
.
data
for
param
in
bucket
]
coalesced
=
_flatten_dense_tensors
(
grads
)
coalesced
=
_flatten_dense_tensors
(
grads
)
.
to
(
torch
.
cuda
.
current_device
())
dist
.
all_reduce
(
coalesced
,
op
=
dist
.
ReduceOp
.
SUM
,
group
=
group
)
dist
.
all_reduce
(
coalesced
,
op
=
dist
.
ReduceOp
.
SUM
,
group
=
group
)
for
buf
,
synced
in
zip
(
grads
,
_unflatten_dense_tensors
(
coalesced
,
grads
)):
for
buf
,
synced
in
zip
(
grads
,
_unflatten_dense_tensors
(
coalesced
,
grads
)):
buf
.
copy_
(
synced
)
buf
.
copy_
(
synced
)
colossalai/engine/schedule/_pipeline_schedule.py
View file @
8d3250d7
This diff is collapsed.
Click to expand it.
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
8d3250d7
...
@@ -262,3 +262,15 @@ class ShardedModelV2(nn.Module):
...
@@ -262,3 +262,15 @@ class ShardedModelV2(nn.Module):
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
    """Loading a state dict into the sharded model is not implemented.

    Args:
        state_dict: mapping of parameter names to tensors (unused).
        strict: whether to enforce exact key matching (unused).

    Raises:
        NotImplementedError: always; no load path exists for this class yet.
    """
    raise NotImplementedError
def __getitem__(self, idx: int):
    """Delegate indexing to the wrapped module.

    Only valid when the wrapped module is an ``nn.ModuleList``;
    otherwise the assertion fails.
    """
    wrapped = self.module
    assert isinstance(wrapped, nn.ModuleList)
    return wrapped[idx]
def __len__(self):
    """Return the number of submodules in the wrapped ``nn.ModuleList``.

    Asserts that the wrapped module actually is an ``nn.ModuleList``.
    """
    wrapped = self.module
    assert isinstance(wrapped, nn.ModuleList)
    return len(wrapped)
def __iter__(self):
    """Iterate over the submodules of the wrapped ``nn.ModuleList``.

    Asserts that the wrapped module actually is an ``nn.ModuleList``.
    """
    wrapped = self.module
    assert isinstance(wrapped, nn.ModuleList)
    return iter(wrapped)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment