OpenDAS / ColossalAI · Commit 74f77e31

[zero] a shard strategy in granularity of tensor (#307)

Authored by Jiarui Fang on Mar 04, 2022; committed by Frank Lee on Mar 11, 2022
Parent: 80364c76
Changes: 4 changed files with 56 additions and 10 deletions (+56 -10)
colossalai/zero/shard_utils/__init__.py                +4   -0
colossalai/zero/shard_utils/base_shard_strategy.py     +27  -0
colossalai/zero/shard_utils/tensor_shard_strategy.py   +18  -0
tests/test_zero_data_parallel/test_shard_param.py      +7   -10
colossalai/zero/shard_utils/__init__.py (new file, mode 100644)

from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy

__all__ = ['BaseShardStrategy', 'TensorShardStrategy']
colossalai/zero/shard_utils/base_shard_strategy.py (new file, mode 100644)

from abc import ABC, abstractmethod
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
import torch.distributed as dist
from typing import List, Optional


class BaseShardStrategy(ABC):

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
        self.process_group = process_group
        self.world_size = dist.get_world_size(self.process_group)
        self.local_rank = dist.get_rank(self.process_group)
        super().__init__()

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor]):
        r"""Shard the memory of the tensors across multiple processes."""
        pass

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor]):
        r"""Duplicate the full tensor payload on each process."""
        pass
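BaseShardStrategy only fixes the interface: a strategy is built from an optional torch.distributed process group (recording world size and local rank) and must implement shard and gather over a list of ShardedTensor. As a rough, hypothetical illustration of that contract (not part of this commit; the class name NoopShardStrategy is invented for the example), a strategy that keeps every tensor fully replicated could look like:

# Hypothetical sketch, not part of this commit: a strategy that leaves
# every tensor unsharded by making shard()/gather() no-ops.
from typing import List, Optional

import torch.distributed as dist

from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor


class NoopShardStrategy(BaseShardStrategy):

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
        super().__init__(process_group)

    def shard(self, tensor_list: List[ShardedTensor]):
        # Intentionally do nothing: every rank keeps the full payload.
        pass

    def gather(self, tensor_list: List[ShardedTensor]):
        # Nothing to gather, since shard() never split the payloads.
        pass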
colossalai/zero/shard_utils/tensor_shard_strategy.py (new file, mode 100644)

from colossalai.zero.shard_utils import BaseShardStrategy
import torch.distributed as dist
from typing import List, Optional
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor


class TensorShardStrategy(BaseShardStrategy):

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
        super().__init__(process_group)

    def shard(self, tensor_list: List[ShardedTensor]):
        for t in tensor_list:
            t.shard()

    def gather(self, tensor_list: List[ShardedTensor]):
        for t in tensor_list:
            t.gather()
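TensorShardStrategy shards at the granularity of individual tensors: it simply delegates to ShardedTensor.shard() and ShardedTensor.gather() for each entry in the list. A minimal usage sketch, mirroring the updated test below and assuming torch.distributed has already been initialized (e.g. via colossalai.launch):

# Minimal usage sketch based on the test in this commit; assumes an
# initialized torch.distributed process group.
import torch
import torch.distributed as dist

from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_param import ShardedTensor

world_size = dist.get_world_size()

t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
shard_strategy = TensorShardStrategy(process_group=None)

shard_strategy.shard([t])   # each rank now holds a flattened 1-D shard of the payload
shard_strategy.gather([t])  # the full (world_size * 2, 3) payload is restored on every rank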
tests/test_zero_data_parallel/test_shard_param.py

...
@@ -7,6 +7,8 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
...
@@ -18,19 +20,14 @@ def run_shard_tensor(rank, world_size, port):
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.shape) == [world_size * 2, 3]
-    t.shard()
-    # The shape is flattened
-    assert list(t.shape) == [6]
-    # Do nothing
-    t.shard()
-    assert list(t.shape) == [6]
-    t.gather()
+    shard_strategy = TensorShardStrategy(process_group=None)
+    # test shard strategy
+    shard_strategy.shard([t])
+    assert list(t.shape) == [6]
+    shard_strategy.gather([t])
     assert list(t.shape) == [world_size * 2, 3]
-    t.payload = torch.zeros(world_size * 2, 3)
-    assert torch.sum(t.payload).cpu() == 0

 @pytest.mark.dist
 def test_shard_tensor():
...
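The diff above shows only the changed hunks of run_shard_tensor; the body of test_shard_tensor() itself is outside the visible context. Judging from the imports in the test file (pytest, torch.multiprocessing, free_port), the per-rank function is presumably driven by a spawn wrapper along these lines (a sketch under that assumption, not code from this commit):

# Assumed launch pattern for the test; the real test_shard_tensor() body is
# not shown in this diff, so details here are illustrative only.
from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.utils import free_port


@pytest.mark.dist
def test_shard_tensor():
    world_size = 2
    run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)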