OpenDAS / ColossalAI · Commits

Commit 88804aee, authored Mar 14, 2022 by ver217

add bucket tensor shard strategy

parent aaead33c

Showing 3 changed files with 62 additions and 10 deletions:

    colossalai/engine/ophooks/zero_hook.py                         +20  -7
    colossalai/zero/shard_utils/__init__.py                         +4   -3
    colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py    +38   -0
colossalai/engine/ophooks/zero_hook.py

 import torch
 from colossalai.registry import OPHOOKS
-from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.utils import get_current_device
+from colossalai.zero.shard_utils import BaseShardStrategy
+
 from ._base_ophook import BaseOpHook
...
@@ -18,23 +19,32 @@ class ZeroHook(BaseOpHook):
         self.computing_device = torch.device(f'cuda:{get_current_device()}')

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload

     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload
...
@@ -52,10 +62,13 @@ class ZeroHook(BaseOpHook):
             param.col_attr.bwd_count += 1

     def post_bwd_exec(self, module: torch.nn.Module, input):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()

     def pre_iter(self):
         pass
...
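The change in all four hooks follows one pattern: instead of calling the strategy once per parameter, the hook first collects every parameter's ShardedTensor into tensor_list and then issues a single gather or shard per module, which is what lets the bucket strategy below fuse the work into one collective. A minimal sketch of the effect; the counting strategy and integer stand-ins are hypothetical, purely for illustration:

# Hypothetical mock, not ColossalAI code: each gather() call stands in for
# one collective round-trip on the interconnect.
class CountingStrategy:

    def __init__(self):
        self.calls = 0

    def gather(self, tensors):
        self.calls += 1    # one call, regardless of how many tensors it batches

params = list(range(8))    # stand-ins for module.parameters()

per_tensor = CountingStrategy()
for p in params:
    per_tensor.gather([p])    # old hook body: one call per parameter

batched = CountingStrategy()
batched.gather(params)        # new hook body: one call for the whole module

print(per_tensor.calls, batched.calls)    # -> 8 1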
colossalai/zero/shard_utils/__init__.py

-from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
-from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
+from .base_shard_strategy import BaseShardStrategy
+from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
+from .tensor_shard_strategy import TensorShardStrategy

-__all__ = ['BaseShardStrategy', 'TensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
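After this change the new strategy is importable from the package root alongside the existing ones. A quick sanity check, assuming a ColossalAI checkout containing this commit is on the path:

# Assumes colossalai at this commit is importable; the subclass relation is
# taken from the new file below, which derives from TensorShardStrategy.
from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy

assert issubclass(BucketTensorShardStrategy, TensorShardStrategy)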
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py  0 → 100644

from typing import List

import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten

from .tensor_shard_strategy import TensorShardStrategy


class BucketTensorShardStrategy(TensorShardStrategy):

    def gather(self, tensor_list: List[ShardedTensor]):
        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
        if len(tensor_list) == 0:
            return
        target_device = tensor_list[0].device
        dtype = tensor_list[0].dtype
        buffer_list: List[torch.Tensor] = []
        tensor_numels = [t.payload.numel() for t in tensor_list]
        buffer_size = sum(tensor_numels)
        for i in range(self.world_size):
            if i == self.local_rank:
                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
            else:
                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
        dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
        # Move to target device before splitting buffer
        # Ensure we utilize maximum PCIE bandwidth
        buffer_list = [buffer.to(target_device) for buffer in buffer_list]
        offset = 0
        for i, t in enumerate(tensor_list):
            gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
            gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
            t.reset_payload(gathered_payload)
            t.is_sharded = False
            offset += tensor_numels[i]
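gather() above packs all local shards into one flat buffer with _flatten_dense_tensors, gathers every rank's buffer in a single dist.all_gather, then slices the same offset range out of each rank's buffer and concatenates across ranks to rebuild every full tensor (trimming the padding added during sharding via origin_numel and restoring origin_shape). A self-contained sketch of that flatten → all_gather → split pattern, runnable on one process with the gloo backend; all names and sizes here are illustrative and independent of ColossalAI:

import os

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten

# A single-process gloo group keeps the collective runnable without GPUs
# or a distributed launcher.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

shards = [torch.randn(4), torch.randn(6)]    # this rank's shards of two tensors
numels = [t.numel() for t in shards]

# Bucket: one flat send buffer, one receive buffer per rank, one collective.
bucket = flatten(shards)
buffer_list = [torch.empty_like(bucket) for _ in range(dist.get_world_size())]
dist.all_gather(buffer_list, bucket)

# Unbucket: slice the same offsets out of every rank's buffer and concatenate
# across ranks to rebuild each full tensor.
offset = 0
for numel in numels:
    full = torch.cat([buf[offset:offset + numel] for buf in buffer_list])
    offset += numel
    print(full.shape)    # torch.Size([4]) then torch.Size([6]) at world_size=1

dist.destroy_process_group()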