OpenDAS / ColossalAI · Commits

Commit a65cbb7e (unverified)
Authored Apr 15, 2022 by HELSON, committed by GitHub on Apr 15, 2022
Parent commit: 5a1a095b

[zero] refactor shard and gather operation (#773)
Showing 2 changed files with 10 additions and 10 deletions:

    colossalai/zero/shard_utils/commons.py                 +5 -3
    colossalai/zero/shard_utils/tensor_shard_strategy.py   +5 -7
colossalai/zero/shard_utils/commons.py

```diff
@@ -14,7 +14,9 @@ def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
     num_to_pad = chunks[0].numel() - chunks[rank].numel()
     assert num_to_pad >= 0, num_to_pad
-    shard = chunks[rank].clone()
-    if num_to_pad > 0:
-        shard = F.pad(shard, [0, num_to_pad])
+    shard = torch.zeros_like(chunks[0])
+    length = chunks[rank].size(0)
+    shard_temp = shard[:length]
+    shard_temp.copy_(chunks[rank])
     return shard, num_to_pad
```
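For context, the new shard construction replaces clone-and-`F.pad` with a zero-filled buffer of uniform size into which the local chunk is copied. Below is a minimal standalone sketch of the same pattern; the function name `get_shard_sketch` and the example values are illustrative, not part of the commit, and it assumes the tensor is large enough for `torch.chunk` to return `world_size` chunks.

```python
import torch

def get_shard_sketch(tensor: torch.Tensor, rank: int, world_size: int):
    """Return rank's chunk of a flattened tensor, zero-padded to the size of chunk 0."""
    chunks = torch.flatten(tensor).chunk(world_size)
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    assert num_to_pad >= 0, num_to_pad

    # Allocate a uniform-size shard and copy the (possibly shorter) local chunk
    # into its prefix; the tail stays zero, so no separate padding op is needed.
    shard = torch.zeros_like(chunks[0])
    shard[:chunks[rank].size(0)].copy_(chunks[rank])
    return shard, num_to_pad

if __name__ == "__main__":
    t = torch.arange(10, dtype=torch.float32)
    shard, pad = get_shard_sketch(t, rank=3, world_size=4)
    print(shard, pad)  # tensor([9., 0., 0.]) 2 -- the last rank's shard is zero-padded
```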
colossalai/zero/shard_utils/tensor_shard_strategy.py

```diff
@@ -43,18 +43,16 @@ class TensorShardStrategy(BaseShardStrategy):
         if not t.is_sharded:
             return
         target_device = t.device
-        buffer_list = []
         payload_numel = t.payload.numel()
         world_size = dist.get_world_size(process_group)
         rank = dist.get_rank(process_group)
-        for i in range(world_size):
-            if i == rank:
-                buffer_list.append(t.payload.cuda(get_current_device()))
-            else:
-                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
+        buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
+        buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
+        buffer_list[rank].copy_(t.payload)
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
-        gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
+        gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
         t.reset_payload(gathered_payload)
         colo_model_data_tensor_move_inline(t, target_device)
         t.is_sharded = False
```
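The gather path now allocates one contiguous buffer per gathered tensor, views it as per-rank chunks, and lets `dist.all_gather` write into those views, so the trailing `torch.cat` copy of the old code is no longer needed. A minimal standalone sketch of that pattern follows; it assumes an already-initialized process group, and the function name `gather_into_buffer` and its `origin_shape`/`origin_numel` parameters are illustrative, not the repository's API.

```python
import torch
import torch.distributed as dist

def gather_into_buffer(payload: torch.Tensor, origin_shape, origin_numel: int,
                       process_group=None) -> torch.Tensor:
    """All-gather equal-size shards into one preallocated buffer and trim the padding."""
    world_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)

    # One contiguous allocation for the whole gathered tensor; torch.chunk returns
    # views into it, so all_gather writes each rank's shard directly into place.
    buffer = torch.empty(payload.numel() * world_size,
                         dtype=payload.dtype, device=payload.device)
    buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
    buffer_list[rank].copy_(payload)
    dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)

    # Drop the zero padding added when sharding and restore the original shape,
    # without an extra concatenation copy.
    return torch.narrow(buffer, 0, 0, origin_numel).reshape(origin_shape)
```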