OpenDAS / ColossalAI

Commit 90d3aef6
[zero] yet an improved sharded param (#311)

Authored Mar 04, 2022 by Jiarui Fang; committed by Frank Lee on Mar 11, 2022
Parent: c9e7d958
Showing 3 changed files, with 82 additions and 21 deletions (+82 / -21):

colossalai/zero/sharded_param/__init__.py           +2   -2
colossalai/zero/sharded_param/sharded_param.py      +34  -0
tests/test_zero_data_parallel/test_shard_param.py   +46  -19
colossalai/zero/sharded_param/__init__.py

-from colossalai.zero.sharded_param.sharded_param import ShardedParam
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.zero.sharded_param.sharded_param import ShardedParam, ShardedParamV2

-__all__ = ['ShardedParam', 'ShardedTensor']
+__all__ = ['ShardedParam', 'ShardedTensor', 'ShardedParamV2']
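With the package exports updated, the new wrapper is importable from either level. A minimal sketch of the intended import paths (assuming a ColossalAI checkout that contains this commit):

# The package __init__ now re-exports ShardedParamV2 ...
from colossalai.zero.sharded_param import ShardedParam, ShardedTensor, ShardedParamV2
# ... and, as the test file below does, it can also be imported from the module directly.
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2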
colossalai/zero/sharded_param/sharded_param.py

...
@@ -6,6 +6,40 @@ import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param import ShardedTensor
 from typing import Union, Tuple, Optional
 import numpy
+
+
+class ShardedParamV2(object):
+
+    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        self._data_sharded_tensor = ShardedTensor(param.data, process_group)
+        if param.requires_grad and param.grad is not None:
+            self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
+            param.grad = None
+        else:
+            self._grad_sharded_tensor = None
+
+        # make sure the shared param is the only owner of payload
+        param.data = torch.empty([], dtype=param.dtype, device=param.device)
+
+    @property
+    def data(self):
+        return self._data_sharded_tensor.payload
+
+    @data.setter
+    def data(self, t: torch.Tensor):
+        self._data_sharded_tensor.payload = t
+
+    @property
+    def grad(self):
+        return self._grad_sharded_tensor.payload
+
+    @grad.setter
+    def grad(self, t: torch.Tensor):
+        self._grad_sharded_tensor.payload = t
+
+
 class ShardedParam(object):
...
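The central idea of ShardedParamV2, visible in the constructor above, is ownership transfer: the parameter's real payload moves into a ShardedTensor (and the gradient into a second one, when present), while param.data is replaced with an empty zero-dimensional placeholder so that the wrapper holds the only live copy. The sketch below re-creates that pattern with plain PyTorch; PayloadHolder and ParamWrapperSketch are hypothetical stand-ins for ShardedTensor and ShardedParamV2, used only so the example runs without ColossalAI installed.

import torch


class PayloadHolder:
    """Hypothetical stand-in for ShardedTensor: it simply owns a tensor payload."""

    def __init__(self, tensor: torch.Tensor):
        self.payload = tensor


class ParamWrapperSketch:
    """Hypothetical stand-in for ShardedParamV2: take over param.data, leave a placeholder."""

    def __init__(self, param: torch.nn.Parameter):
        self._data = PayloadHolder(param.data)
        # Make the wrapper the only owner of the payload, mirroring the commit's
        # `param.data = torch.empty([], ...)` trick.
        param.data = torch.empty([], dtype=param.dtype, device=param.device)

    @property
    def data(self) -> torch.Tensor:
        return self._data.payload


param = torch.nn.Parameter(torch.randn(2, 3))
reference = param.data.clone()
wrapped = ParamWrapperSketch(param)

assert param.data.numel() == 1                  # only a 0-dim placeholder remains on the parameter
assert torch.allclose(wrapped.data, reference)  # the payload is preserved inside the wrapper

This is exactly what the new test below checks: allclose(sparam.data, param_ref.data) and param.data.numel() == 1.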
tests/test_zero_data_parallel/test_shard_param.py

 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from copy import deepcopy
 from functools import partial

 import colossalai
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 import pytest
 import torch
 import torch.multiprocessing as mp
...
@@ -11,7 +13,7 @@ from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
-from tests.test_zero_data_parallel.common import Net, CONFIG
+from tests.test_zero_data_parallel.common import Net, CONFIG, allclose


 def run_shard_tensor(rank, world_size, port):
...
@@ -36,28 +38,33 @@ def test_shard_tensor():
     mp.spawn(run_func, nprocs=world_size)


-def run_init_shard_param(rank, world_size, port):
+def _run_shard_param_v2(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(data=torch.rand(2, 3))
-    sparam = ShardedParam(param, None, True)
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-    del sparam
-
-    param_shape = (2, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
+
+    param = torch.nn.Parameter(torch.randn(2, 3))
+    param_ref = deepcopy(param)
+    sparam = ShardedParamV2(param=param, process_group=None)
+
+    allclose(sparam.data, param_ref.data)
+    assert (param.data.numel() == 1)
-
-    param_shape = (2, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [2, 3])


 @pytest.mark.dist
+def test_shard_param_v2():
     world_size = 2
+    run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


-def run_shard_param_check(rank, world_size, port):
+def _run_test_shard_param(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    param = torch.nn.Parameter(torch.randn(2, 3))
+    param_ref = deepcopy(param)
+    sparam = ShardedParamV2(param=param, process_group=None)
+
+    print(sparam.data)
+    print(param_ref.data)
+
     logger = get_dist_logger()
     model = Net()
...
@@ -77,12 +84,31 @@ def run_shard_param_check(rank, world_size, port):
 @pytest.mark.dist
-def test_shard_shape():
+def test_shard_param():
     world_size = 2
-    run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
+    run_func = partial(_run_test_shard_param, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


+def run_init_shard_param(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    param = torch.nn.Parameter(data=torch.rand(2, 3))
+    sparam = ShardedParam(param, None, True)
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+    del sparam
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [2, 3])
+
+
 @pytest.mark.dist
 def test_init_shard_param():
     world_size = 2
...
@@ -92,5 +118,6 @@ def test_init_shard_param():
 if __name__ == '__main__':
     test_shard_tensor()
-    test_shard_shape()
+    test_shard_param()
+    test_shard_param_v2()
     test_init_shard_param()
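A usage note, inferred from the test scaffolding rather than stated in the commit: each of these tests spawns world_size = 2 worker processes with mp.spawn and launches ColossalAI over the 'nccl' backend, so running them presumably requires at least two CUDA devices. With that hardware available, running pytest tests/test_zero_data_parallel/test_shard_param.py (or executing the module directly to hit the __main__ block above) should exercise the new test_shard_param_v2 alongside the existing tests.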