OpenDAS / ColossalAI · Commits

Commit 9afb5c8b
Authored Mar 02, 2022 by Frank Lee

fixed typo in ShardParam (#294)

Parent: 27155b85
1 changed file with 23 additions and 19 deletions.
colossalai/zero/shard_param/shard_param.py (+23 −19)

 from enum import Enum
+from optparse import Option
 import torch
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 import torch.distributed as dist


 class TensorType(Enum):
     GRAD = 1
     DATA = 2


 class ShardParam(object):
     r"""
     A wrapper to torch.nn.Parameter. Shard a param
     on different processes.
     """

     def __init__(self,
                  param: torch.nn.Parameter,
                  tensor_type: TensorType = TensorType.DATA,
                  process_group=None,
                  ) -> None:
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.world_size = dist.get_world_size(self.process_group)
...
@@ -27,25 +30,25 @@ class ShardParam(object):
         self._payload_numel = None
         self._origin_shape = param.shape
         self._origin_numel = param.numel()
-        self.is_shared = False
+        self.is_sharded = False

     def payload(self, target_device: torch.device):
         return self._param_payload.to(target_device)

     def shard(self):
         r"""
         Distributed the payload of param to all processes.
         """
-        if self.is_shared:
+        if self.is_sharded:
             return
         self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
-        self.is_shared = True
+        self.is_sharded = True

     def gather(self):
         r"""
         Collect the payload of param from different processes to process of local rank.
         """
-        if not self.is_shared:
+        if not self.is_sharded:
             return
         buffer_list = []
...
@@ -56,8 +59,9 @@ class ShardParam(object):
         else:
             buffer_list.append(torch.zeros(payload_numel).cuda())
-        torch.distributed.all_gather(buffer_list,
-                                     buffer_list[self.local_rank],
-                                     group=self.process_group,
-                                     async_op=False)
+        torch.distributed.all_gather(buffer_list,
+        print(buffer_list)
+                                     buffer_list[self.local_rank],
+                                     group=self.process_group,
+                                     async_op=False)
         self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
-        self.is_shared = False
+        self.is_sharded = False
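
For context, a minimal sketch of how the ShardParam lifecycle shown in this diff might be exercised. The constructor and the shard/gather/payload methods come from the file above; the initialization boilerplate, tensor shapes, and device handling are illustrative assumptions, not ColossalAI's documented usage.

# Hypothetical usage sketch of the ShardParam API visible in this diff.
# Assumes torch.distributed and the ColossalAI global context are already
# initialized, and that CUDA is available (gather() calls .cuda()).
import torch
import torch.nn as nn
from colossalai.zero.shard_param.shard_param import ShardParam

param = nn.Parameter(torch.randn(1024, 1024))
sharded = ShardParam(param)   # wraps the parameter; defaults to TensorType.DATA

sharded.shard()               # keep only this rank's slice of the payload
# ... memory-light phase: each rank holds roughly 1/world_size of the data ...

sharded.gather()              # all_gather rebuilds the full, original-shape payload
full = sharded.payload(torch.device('cuda'))  # move the payload to a target device

The is_sharded flag this commit renames is what makes shard() and gather() idempotent: shard() returns early when the payload is already sharded, and gather() returns early when it is not.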
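The reassembly step in gather() can also be checked in isolation. Below is a single-process sketch of the torch.cat / torch.narrow / view pattern from the last hunk; the even-chunk padding is an assumption about what get_shard does, used only to make the example self-contained and runnable without a process group.

# Single-process sketch of gather()'s reassembly: concatenate the per-rank
# buffers, drop any padding, and restore the original shape. The padded
# sharding below stands in for get_shard, whose exact behavior is assumed.
import torch

world_size = 4
param = torch.arange(10, dtype=torch.float32).view(2, 5)   # toy "parameter"
origin_shape, origin_numel = param.shape, param.numel()

# Assumed shard step: pad the flattened payload so it splits evenly.
chunk = (origin_numel + world_size - 1) // world_size
flat = torch.cat([param.flatten(),
                  torch.zeros(chunk * world_size - origin_numel)])
buffer_list = list(flat.chunk(world_size))   # buffer_list[i] ~ rank i's shard

# The reassembly line from the diff, minus the distributed all_gather:
restored = torch.narrow(torch.cat(buffer_list), 0, 0, origin_numel).view(origin_shape)
assert torch.equal(restored, param)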