OpenDAS / ColossalAI · Commit 36f9a74a

fix sharded param hook and unit test

Authored Mar 04, 2022 by ver217; committed by Frank Lee on Mar 11, 2022.
Parent commit: 001ca624
Showing 6 changed files with 49 additions and 66 deletions (+49, -66):

  colossalai/engine/ophooks/_shard_param_ophook.py         +4   -9
  colossalai/zero/sharded_model/sharded_model_v2.py         +5   -2
  colossalai/zero/sharded_optim/sharded_optim_v2.py         +1  -12
  colossalai/zero/sharded_param/sharded_param.py            +9   -5
  tests/test_zero_data_parallel/common.py                  +27  -34
  tests/test_zero_data_parallel/test_sharded_optim_v2.py    +3   -4
colossalai/engine/ophooks/_shard_param_ophook.py  (+4, -9)

 import torch
-import torch.distributed as dist
 from colossalai.registry import OPHOOKS
 from . import BaseOpHook
...
@@ -21,29 +20,25 @@ class ShardParamHook(BaseOpHook):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.gather()
-            if dist.get_rank() == 0:
-                print(f'{param._name} pre fwd shape {param.ca_attr.payload("cpu").shape}')
+            param.data = param.ca_attr.payload()

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.shard()
-            if dist.get_rank() == 0:
-                print(f'{param._name} post fwd shape {param.ca_attr.payload("cpu").shape}')
+            param.data = param.ca_attr.payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.gather()
-            if dist.get_rank() == 0:
-                print(f'{param._name} pre bwd shape {param.ca_attr.payload("cpu").shape}')
+            param.data = param.ca_attr.payload()

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.shard()
-            if dist.get_rank() == 0:
-                print(f'{param._name} post bwd shape {param.ca_attr.payload("cpu").shape}')
+            param.data = param.ca_attr.payload()

     def pre_iter(self):
         pass
...
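The change in every hook is the same: the rank-0 debug prints are dropped and param.data is re-pointed at the current ca_attr payload, so after gather() the module computes with the full tensor and after shard() it holds only the local shard. Below is a minimal, self-contained sketch of that bookkeeping; the class and attribute names are invented stand-ins for illustration, not the real ShardedParam.

    import torch

    class ToyShardedAttr:
        """Toy stand-in for the ca_attr object the hook relies on."""

        def __init__(self, tensor: torch.Tensor, world_size: int, rank: int):
            self._full = tensor.clone()
            self._shard = torch.flatten(tensor).chunk(world_size)[rank].clone()
            self.is_sharded = True

        def gather(self):
            self.is_sharded = False

        def shard(self):
            self.is_sharded = True

        def payload(self, target_device=None):
            t = self._shard if self.is_sharded else self._full
            return t.to(target_device) if target_device is not None else t

    # What the fixed hook now does around each forward/backward of a module:
    p = torch.nn.Parameter(torch.randn(4, 4))
    attr = ToyShardedAttr(p.data, world_size=2, rank=0)

    attr.gather()                 # pre-forward: materialise the full tensor
    p.data = attr.payload()       # param.data aliases the gathered payload
    assert p.data.shape == (4, 4)

    attr.shard()                  # post-forward: keep only the local shard
    p.data = attr.payload()
    assert p.data.dim() == 1      # the shard is a flat chunk in this toy version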
colossalai/zero/sharded_model/sharded_model_v2.py  (+5, -2)

@@ -6,8 +6,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
-                                       register_ophooks_recursively)
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
...
@@ -109,6 +108,10 @@ class ShardedModelV2(nn.Module):
             if not self._require_backward_grad_sync:
                 continue
             p._sharded_grad.write_back()
+        # In case some post bwd hook is not fired
+        for p in self.module.parameters():
+            if not p.ca_attr.is_sharded:
+                p.ca_attr.shard()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
...
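The four added lines act as a safety net: after the backward pass finishes, any parameter whose post-backward hook did not fire is re-sharded, so every parameter ends the iteration in the sharded state. Sketched against the ca_attr interface used above (an illustrative helper, not the full ShardedModelV2 method):

    def reshard_missed_params(module):
        # Restore the sharding invariant for any parameter whose post-bwd hook was skipped.
        for p in module.parameters():
            if not p.ca_attr.is_sharded:
                p.ca_attr.shard()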
colossalai/zero/sharded_optim/sharded_optim_v2.py  (+1, -12)

@@ -14,7 +14,6 @@ from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-from ..sharded_model._zero3_utils import free_storage
 from ._utils import has_inf_or_nan
...
@@ -63,8 +62,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             if hasattr(p, 'ca_attr'):
                 assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
                 self.master_params[p] = p.ca_attr.payload(self.device)
-                if dist.get_rank() == 0:
-                    print(f'load payload {p._name} {self.master_params[p].shape}')
             else:
                 self.master_params[p] = p.data.to(device=self.device)
             if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
...
@@ -91,23 +88,15 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
-                    if dist.get_rank() == 0:
-                        print(f'write {p._name} {p.shape} orig_shape {p.ca_attr._origin_shape} \
-                              payload shape {p.ca_attr._param_payload.shape} sharded {p.ca_attr.is_sharded}')
                     p.ca_attr.set_payload(p.data)
-                    # We cannot set p.data to None directly, so we free storage
-                    free_storage(p.data)
+                    p.data = p.ca_attr.payload()

         return ret

     def backward(self, loss: Tensor) -> None:
         loss = self.loss_scale * loss
         self.optim_state = OptimState.SCALED
         if self.model_is_sharded:
-            if dist.get_rank() == 0:
-                print('sharded model backward')
             self.model.backward(loss)
-            if dist.get_rank() == 0:
-                print('sharded model backward done')
         else:
             super().backward(loss)
...
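In step(), the updated master weights are now written back into the sharded payload with set_payload() and p.data is simply re-pointed at that payload; the free_storage() workaround and the rank-0 prints are gone, which is why the _zero3_utils import could be dropped. A rough plain-tensor sketch of the write-back idea (illustrative only; the real code goes through ShardedParam and ca_attr):

    import torch

    fp16_payload = torch.randn(8, dtype=torch.half)   # stands in for p.ca_attr's sharded payload
    master = fp16_payload.float()                      # fp32 master copy kept by the optimizer

    master.add_(0.01)                                  # pretend self.optim.step() updated it
    fp16_payload.copy_(master.half())                  # set_payload(...): copy the result back
    p_data = fp16_payload                              # p.data = p.ca_attr.payload(): alias, nothing freed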
colossalai/zero/sharded_param/sharded_param.py  (+9, -5)

+from typing import Optional, Tuple, Union
+
+import numpy
 import torch
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
-from typing import Union, Tuple, Optional
-import numpy


 class ShardedParam(object):
...
@@ -28,6 +29,7 @@ class ShardedParam(object):
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
         self.is_sharded = False
+        self.device = device
         # Hijack the data payload of param
         if isinstance(other, torch.nn.Parameter):
...
@@ -50,17 +52,19 @@ class ShardedParam(object):
         self._payload_numel = None

-    def payload(self, target_device: torch.device):
+    def payload(self, target_device: Optional[torch.device] = None):
         r"""
         get the payload and move it to target device
         """
-        return self._param_payload.to(target_device)
+        if target_device is not None:
+            return self._param_payload.to(target_device)
+        return self._param_payload

     def set_payload(self, data: torch.Tensor):
         r"""
         set payload as data
         """
-        assert self._param_payload.numel() == data.numel()
+        assert self._param_payload.shape == data.shape
         self._param_payload.copy_(data)

     def shard(self):
...
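payload() now takes an optional target device: with an argument it returns a copy moved to that device, and without one it returns the payload tensor itself, which is what allows the hook and the optimizer to alias p.data to it. set_payload() also now asserts on matching shapes instead of matching element counts. A minimal sketch of that contract outside the real class (the holder class below is a stand-in):

    from typing import Optional

    import torch

    class PayloadHolder:
        def __init__(self, t: torch.Tensor):
            self._param_payload = t

        def payload(self, target_device: Optional[torch.device] = None):
            if target_device is not None:
                return self._param_payload.to(target_device)   # may copy across devices
            return self._param_payload                         # no argument: the tensor itself

    holder = PayloadHolder(torch.zeros(4))
    alias = holder.payload()
    alias.fill_(1.0)
    assert holder._param_payload.sum() == 4.0   # same storage, so in-place edits are visible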
tests/test_zero_data_parallel/common.py  (+27, -34)

@@ -3,37 +3,21 @@ from functools import partial
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint

 LOGGER = get_dist_logger()

-CONFIG = dict(
-    fp16=dict(
-        mode=None,
-    ),
-    zero=dict(
-        level=3,
-        verbose=False,
-        offload_optimizer_config=dict(
-            device='cpu',
-            pin_memory=True,
-            buffer_count=5,
-            fast_init=False),
-        offload_param_config=dict(
-            device='cpu',
-            pin_memory=True,
-            buffer_count=5,
-            buffer_size=1e8,
-            max_in_cpu=1e9)
-    ),
-    parallel=dict(
-        pipeline=dict(size=1),
-        tensor=dict(size=1, mode=None)
-    )
-)
+CONFIG = dict(fp16=dict(mode=None,),
+              zero=dict(level=3,
+                        verbose=False,
+                        offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False),
+                        offload_param_config=dict(device='cpu',
+                                                  pin_memory=True,
+                                                  buffer_count=5,
+                                                  buffer_size=1e8,
+                                                  max_in_cpu=1e9)),
+              parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))


 def checkpoint_wrapper(module, enable=True):
...
@@ -43,6 +27,7 @@ def checkpoint_wrapper(module, enable=True):
 class Net(nn.Module):

     def __init__(self, checkpoint=False) -> None:
         super().__init__()
         self.fc1 = nn.Linear(5, 5)
...
@@ -50,13 +35,7 @@ class Net(nn.Module):
         self.fc3 = nn.Linear(5, 1)
         if checkpoint:
             self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [
-            self.fc1,
-            self.fc2,
-            self.fc1,
-            self.fc2,
-            self.fc3
-        ]
+        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

     def forward(self, x):
         for layer in self.layers:
...
@@ -111,3 +90,17 @@ def check_params_padding(model, zero_model, loose=False):
             zero_p = zero_p[:p.size(0)]
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose)
+
+
+def check_sharded_params_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_p = zero_p.ca_attr.payload(p.device)
+        chunks = torch.flatten(p).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        p = chunks[rank]
+        if zero_p.size(0) > p.size(0):
+            zero_p = zero_p[:p.size(0)]
+        assert p.dtype == zero_p.dtype
+        assert allclose(p, zero_p, loose=loose)
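check_sharded_params_padding() compares each rank's slice of the reference model against the ZeRO model's payload: it flattens the full parameter, chunks it across ranks, skips ranks that received no chunk, and trims the payload's padding before comparing. On toy numbers, for a hypothetical 4-rank run outside the test harness:

    import torch

    world_size, rank = 4, 3
    p = torch.arange(10, dtype=torch.float)        # full parameter with 10 elements
    chunks = torch.flatten(p).chunk(world_size)     # chunk sizes 3, 3, 3, 1 -- the last rank is short
    local = chunks[rank]                            # rank 3 holds a single element
    zero_p = torch.tensor([9.0, 0.0, 0.0])          # sharded copy padded up to the regular chunk size
    if zero_p.size(0) > local.size(0):
        zero_p = zero_p[:local.size(0)]             # drop the padding before comparing
    assert torch.allclose(local, zero_p)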
tests/test_zero_data_parallel/test_sharded_optim_v2.py  (+3, -4)

@@ -16,19 +16,18 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from torch.optim import Adam
-from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_params_padding)
+from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_sharded_params_padding)


 def run_step(model, optimizer, x, enable_autocast=False):
     model.train()
+    optimizer.zero_grad()
     with torch.cuda.amp.autocast(enabled=enable_autocast):
         y = model(x)
         loss = y.sum()
     loss = loss.float()
     if isinstance(model, ShardedModelV2):
         optimizer.backward(loss)
-        for p in model.parameters():
-            assert p.ca_attr.is_sharded
     else:
         loss.backward()
     optimizer.step()
...
@@ -51,7 +50,7 @@ def run_dist(rank, world_size, port):
         run_step(model, optim, x, False)
         if dist.get_world_size() > 1:
             check_grads_padding(model, zero_model)
-            check_params_padding(model, zero_model)
+            check_sharded_params_padding(model, zero_model)
         else:
             check_grads(model, zero_model)
             check_params(model, zero_model)
...
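run_step() now zeroes the gradients before the forward pass, the per-parameter is_sharded assertion after optimizer.backward() is dropped (ShardedModelV2 now re-shards any leftover parameters itself, per the change above), and the multi-GPU comparison switches to check_sharded_params_padding. A CPU-only sketch of the same control flow, with plain PyTorch objects standing in for the sharded classes (the sharded branch is shown but not exercised):

    import torch
    import torch.nn as nn
    from torch.optim import Adam

    def run_step(model, optimizer, x, enable_autocast=False, model_is_sharded=False):
        model.train()
        optimizer.zero_grad()                      # added by this commit before the forward pass
        with torch.cuda.amp.autocast(enabled=enable_autocast):
            y = model(x)
            loss = y.sum()
        loss = loss.float()
        if model_is_sharded:
            optimizer.backward(loss)               # ShardedOptimizerV2 owns loss scaling and backward
        else:
            loss.backward()
        optimizer.step()

    net = nn.Linear(5, 1)
    run_step(net, Adam(net.parameters()), torch.rand(4, 5))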