OpenDAS / ColossalAI / Commits

Commit 9afb5c8b, authored Mar 02, 2022 by Frank Lee

fixed typo in ShardParam (#294)

parent 27155b85

Showing 1 changed file with 23 additions and 19 deletions.

colossalai/zero/shard_param/shard_param.py (+23, -19)
 from enum import Enum
 from optparse import Option
 import torch
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 import torch.distributed as dist


 class TensorType(Enum):
     GRAD = 1
     DATA = 2


 class ShardParam(object):
     r"""
     A wrapper to torch.nn.Parameter. Shard a param
     on different processes.
     """

     def __init__(self,
                  param: torch.nn.Parameter,
                  tensor_type: TensorType = TensorType.DATA,
                  process_group=None,
                  ) -> None:
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
 ...
@@ -27,27 +30,27 @@ class ShardParam(object):
         self._payload_numel = None
         self._origin_shape = param.shape
         self._origin_numel = param.numel()
-        self.is_shared = False
+        self.is_sharded = False

     def payload(self, target_device: torch.device):
         return self._param_payload.to(target_device)

     def shard(self):
         r"""
         Distributed the payload of param to all processes.
         """
-        if self.is_shared:
+        if self.is_sharded:
             return
         self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
-        self.is_shared = True
+        self.is_sharded = True

     def gather(self):
         r"""
         Collect the payload of param from different processes to process of local rank.
         """
-        if not self.is_shared:
+        if not self.is_sharded:
             return
         buffer_list = []
         payload_numel = self._param_payload.numel()
         for i in range(self.world_size):
 ...
@@ -55,9 +58,10 @@ class ShardParam(object):
             buffer_list.append(self._param_payload.cuda())
         else:
             buffer_list.append(torch.zeros(payload_numel).cuda())
         torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group, async_op=False)
-        print(buffer_list)
         self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
-        self.is_shared = False
+        self.is_sharded = False
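
For context, here is a minimal sketch of how the ShardParam API touched by this commit might be driven. It is hypothetical and not part of the commit: it assumes the distributed backend has already been initialised on every rank (for example via colossalai.launch or torch.distributed.init_process_group) and that a CUDA device is available, since gather() builds its buffers with .cuda().

import torch
import torch.distributed as dist

from colossalai.zero.shard_param.shard_param import ShardParam


def shard_param_demo() -> None:
    # Hypothetical driver; assumes dist.init_process_group(...) has already run on every rank.
    group = dist.new_group(list(range(dist.get_world_size())))
    param = torch.nn.Parameter(torch.randn(4, 8))

    wrapped = ShardParam(param, process_group=group)

    wrapped.shard()                               # keep only this rank's slice of the flattened payload
    local = wrapped.payload(torch.device('cuda'))
    print(local.numel())                          # roughly origin_numel / world_size, plus any padding

    wrapped.gather()                              # all-gather the slices and restore the original shape
    full = wrapped.payload(torch.device('cpu'))
    assert full.shape == param.shape

Note that tensor_type defaults to TensorType.DATA, so this sketch shards the parameter's data payload rather than its gradient.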