OpenDAS / ColossalAI · Commits · 2bed0968

Unverified commit 2bed0968, authored Sep 06, 2022 by ver217, committed by GitHub on Sep 06, 2022.
[utils] optimize partition_tensor_parallel_state_dict (#1546)
parent d8a5aded

Showing 1 changed file with 31 additions and 20 deletions.
colossalai/utils/checkpointing.py (+31 -20)
@@ -29,26 +29,37 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
                                          partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
-
-    if gpc.get_local_rank(parallel_mode) == 0:
-
-        partitioned_state_list = [dict() for _ in range(depth)]
-
-        for key in list(state_dict.keys()):
-            param = state_dict.pop(key)
-            dim = dims.get(key, 0)
-            do_partition = partition_states.get(key, True)
-            if do_partition:
-                param = torch.chunk(param, depth, dim=dim)
-            for i, p in enumerate(partitioned_state_list):
-                p[key] = param[i] if do_partition else param
-
-    else:
-        partitioned_state_list = [None for _ in range(depth)]
-
-    partitioned_state = [None]
-    scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
-    return partitioned_state[0]
+    group = gpc.get_cpu_group(parallel_mode)
+    is_rank0 = gpc.get_local_rank(parallel_mode) == 0
+    partition_info = [None]
+    if is_rank0:
+        partition_info_dict = OrderedDict()
+        for key, param in state_dict.items():
+            dim = dims[key]
+            is_partitioned = partition_states[key]
+            shape = list(param.shape)
+            if is_partitioned:
+                shape[dim] = shape[dim] // depth
+            partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim)
+        partition_info[0] = partition_info_dict
+    dist.broadcast_object_list(partition_info, src_rank, group=group)
+    partitioned_state = OrderedDict()
+    for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items():
+        if is_partitioned:
+            output = torch.empty(shape, dtype=dtype)
+            if is_rank0:
+                scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)]
+            else:
+                scatter_list = None
+            dist.scatter(output, scatter_list, src_rank, group=group)
+        else:
+            if is_rank0:
+                output = state_dict[key]
+            else:
+                output = torch.empty(shape, dtype=dtype)
+            dist.broadcast(output, src_rank, group=group)
+        partitioned_state[key] = output
+    return partitioned_state
 
 
 def gather_tensor_parallel_state_dict(
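For context on what the diff changes: the old code built a complete partitioned copy of the state dict on rank 0 and shipped it with scatter_object_list, which pickles every tensor; the new code first broadcasts only per-tensor metadata (partition flag, dtype, shape, dim) with dist.broadcast_object_list, then moves each tensor with dist.scatter (partitioned parameters) or dist.broadcast (replicated ones). Below is a minimal standalone sketch of that broadcast-metadata-then-scatter pattern; it is not ColossalAI code, and the two-process gloo setup, run_worker, WORLD_SIZE, and the toy one-entry state dict are assumptions made purely for the demo.

# Minimal sketch of the communication pattern used by the new code, on CPU with gloo.
# run_worker, WORLD_SIZE and the toy state dict are illustrative assumptions.
import os
from collections import OrderedDict

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 2  # assumed tensor-parallel degree for the demo


def run_worker(rank: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=WORLD_SIZE)
    src_rank = 0
    is_rank0 = rank == src_rank

    # Only the source rank holds the full (unpartitioned) state dict.
    full_state = OrderedDict(weight=torch.arange(8.0).reshape(4, 2)) if is_rank0 else None

    # Step 1: broadcast lightweight per-tensor metadata (dtype, per-rank chunk shape,
    # partition dim) instead of pickling whole tensors.
    partition_info = [None]
    if is_rank0:
        partition_info[0] = {"weight": (torch.float32, [2, 2], 0)}
    dist.broadcast_object_list(partition_info, src=src_rank)

    # Step 2: every rank allocates its own chunk and receives it via dist.scatter;
    # only the source rank supplies the scatter_list.
    partitioned = OrderedDict()
    for key, (dtype, shape, dim) in partition_info[0].items():
        output = torch.empty(shape, dtype=dtype)
        scatter_list = [t.contiguous() for t in full_state[key].chunk(WORLD_SIZE, dim)] if is_rank0 else None
        dist.scatter(output, scatter_list, src=src_rank)
        partitioned[key] = output

    print(f"rank {rank}: {partitioned['weight'].tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, nprocs=WORLD_SIZE)

Each process should print its own 2x2 chunk of weight. As far as the diff shows, the point of the pattern is that only small metadata tuples travel through the object-pickling collective, while the tensor bytes move through the regular tensor collectives (scatter/broadcast) on the CPU process group.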