OpenDAS / ColossalAI

Commit d0ae0f22: [zero] update sharded optim v2 (#334)

Authored Mar 09, 2022 by ver217; committed by Frank Lee on Mar 11, 2022.
Parent: 2b8cddd4
Changes: showing 5 changed files with 115 additions and 68 deletions (+115 / -68).
colossalai/zero/sharded_model/sharded_model_v2.py       (+15 / -7)
colossalai/zero/sharded_optim/sharded_optim_v2.py       (+48 / -25)
tests/test_zero_data_parallel/common.py                 (+3 / -3)
tests/test_zero_data_parallel/test_shard_model_v2.py    (+2 / -5)
tests/test_zero_data_parallel/test_sharded_optim_v2.py  (+47 / -28)
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -102,6 +102,11 @@ class ShardedModelV2(nn.Module):
         # Wait for the non-blocking GPU -> CPU grad transfers to finish.
         torch.cuda.current_stream().synchronize()
         self.reducer.free()
+        # In case some post bwd hook is not fired
+        if self.shard_param:
+            for p in self.module.parameters():
+                if not p.col_attr.param_is_sharded:
+                    self.shard_strategy.shard([p.col_attr.data])
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
@@ -113,13 +118,12 @@
             if not self._require_backward_grad_sync:
                 continue
             # Write grad back to p.grad and set p.col_attr.grad to None
-            p.grad.data = p.col_attr.grad
+            # We have to make sure grad and param have the same shape
+            # If world size > 1 and the param is sharded, `.view()` may not be needed
+            # If world size == 1 and the param is sharded, `data` is a flattened tensor
+            # but the shape of `grad` is the same as the unsharded param
+            p.grad.data = p.col_attr.grad.view(p.col_attr.data.shape)
             p.col_attr.grad = None
-        # In case some post bwd hook is not fired
-        if self.shard_param:
-            for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.data])

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
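The `.view()` fix above matters because `p.grad` must always carry the parameter's logical shape, while the buffer accumulated in `p.col_attr.grad` may be flattened. A minimal standalone sketch of the write-back (plain torch, hypothetical shapes):

import torch

# Gradient accumulated in a flattened buffer, as for a sharded param at world size 1.
param = torch.nn.Parameter(torch.zeros(2, 3))
flat_grad = torch.arange(6, dtype=torch.float32)

# p.grad must match p's shape, so restore it before assignment.
param.grad = flat_grad.view(param.shape)
assert param.grad.shape == param.shape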
@@ -180,7 +184,11 @@ class ShardedModelV2(nn.Module):
         if param.col_attr.grad is None:
             param.col_attr.grad = reduced_grad.data
         else:
-            param.col_attr.grad.add_(reduced_grad.data)
+            # When dp size = 1,
+            # param.col_attr.grad is the local accumulated grad shard (full but flattened)
+            # but reduced_grad here is the full grad,
+            # so we should call `view_as`
+            param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
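In the `_grad_post_backward_hook` change, the freshly reduced gradient still has the full parameter shape while the local accumulator is flattened, so `view_as` is needed before the in-place `add_`. A toy illustration:

import torch

accum = torch.zeros(6)        # local accumulated grad shard (full but flattened)
reduced = torch.ones(2, 3)    # reduced grad arriving with the unsharded shape

# add_ requires matching shapes, so view the reduced grad as the accumulator.
accum.add_(reduced.view_as(accum))
assert torch.equal(accum, torch.ones(6))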
colossalai/zero/sharded_optim/sharded_optim_v2.py

 from enum import Enum
-from typing import Dict, Optional, Union
+from typing import Dict, Optional

 import torch
 import torch.distributed as dist
@@ -8,7 +8,9 @@ from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -26,7 +28,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def __init__(self,
                  optimizer: Optimizer,
-                 sharded_model: Union[nn.Module, ShardedModelV2],
+                 sharded_model: ShardedModelV2,
+                 shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2 ** 32,
                  min_scale: float = 1,
@@ -37,9 +40,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
max_scale
:
int
=
2
**
32
,
dp_process_group
:
Optional
[
ProcessGroup
]
=
None
,
mp_process_group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
assert
isinstance
(
sharded_model
,
ShardedModelV2
),
'model must be wrapped with ShardedModel'
super
().
__init__
(
optimizer
)
self
.
model
:
Union
[
nn
.
Module
,
ShardedModelV2
]
=
sharded_model
self
.
model
_is_s
harded
=
isinstance
(
sharded_model
,
ShardedModelV2
)
self
.
shard_strategy
=
shard_strategy
self
.
model
:
S
harded
ModelV2
=
sharded_model
self
.
device
=
torch
.
cuda
.
current_device
()
if
not
cpu_offload
else
torch
.
device
(
'cpu'
)
self
.
optim_state
:
OptimState
=
OptimState
.
UNSCALED
self
.
dp_process_group
=
dp_process_group
or
gpc
.
get_group
(
ParallelMode
.
DATA
)
...
...
@@ -52,20 +56,25 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                              growth_interval=growth_interval,
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
-        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())

-        # Store fp32 params
+        # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
         for group in optimizer.param_groups:
             for p in group['params']:
-                if hasattr(p, 'ca_attr'):
-                    assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    self.master_params[p] = p.ca_attr.payload(self.device)
-                else:
-                    self.master_params[p] = p.data.to(device=self.device)
-                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
-                    self.master_params[p] = self.master_params[p].to(torch.float)
+                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.col_attr.data.is_sharded
+                if not is_param_sharded:
+                    # TODO (ver217): we may not use shard / gather here
+                    # Param is not sharded, which means we use ZeRO-2 here
+                    # As we only store param shards, we shard it here
+                    self.shard_strategy.shard([p.col_attr.data])
+                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
+                if not is_param_sharded:
+                    # In this branch, there's no need to keep the param sharded,
+                    # so we gather it here
+                    self.shard_strategy.gather([p.col_attr.data])

     def step(self, *args, **kwargs):
         # unscale grads if scaled
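The new constructor keeps only fp32 master shards. For a ZeRO-2 style (unsharded) parameter it temporarily shards the fp16 tensor, casts the local shard to fp32 for the master copy, then gathers the fp16 param back. A minimal sketch of that pattern, with `sharded_tensor` standing in for `p.col_attr.data`, a hypothetical strategy object, and `.to(torch.float)` standing in for `cast_tensor_to_fp32`:

import torch

def build_master_shard(sharded_tensor, shard_strategy, device):
    # Hypothetical helper mirroring the constructor loop above.
    was_sharded = sharded_tensor.is_sharded
    if not was_sharded:
        # ZeRO-2: the param normally stays full, but only its local shard is stored.
        shard_strategy.shard([sharded_tensor])
    master = sharded_tensor.payload.to(torch.float).to(device)
    if not was_sharded:
        # Restore the full fp16 param for the next forward/backward.
        shard_strategy.gather([sharded_tensor])
    return master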
@@ -83,28 +92,36 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 p.data = self.master_params[p]
+        # Now p.data is sharded,
+        # so optimizer states are sharded naturally

         ret = self.optim.step(*args, **kwargs)

         # Write master param to payload
         for group in self.optim.param_groups:
             for p in group['params']:
-                if hasattr(p, 'ca_attr'):
-                    p.ca_attr.set_payload(p.data)
-                    p.data = p.ca_attr.payload()
+                is_param_sharded = p.col_attr.data.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here:
+                    # `p.col_attr.data` saves the full fp16 param,
+                    # but we only have the updated fp32 param shard here,
+                    # so we first shard the full fp16 param and copy the fp32 param shard to it,
+                    # then we gather them
+                    self.shard_strategy.shard([p.col_attr.data])
+                # We have to use `copy_payload` instead of `reset_payload`,
+                # since p.data is fp32 and p.col_attr.data is fp16
+                p.col_attr.data.copy_payload(p.data)
+                if not is_param_sharded:
+                    # We gather the full fp16 param here
+                    self.shard_strategy.gather([p.col_attr.data])
+                p.data = p.col_attr.data.payload
         return ret

     def backward(self, loss: Tensor) -> None:
         loss = self.loss_scale * loss
         self.optim_state = OptimState.SCALED
-        if self.model_is_sharded:
-            self.model.backward(loss)
-        else:
-            super().backward(loss)
+        self.model.backward(loss)

     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
-        if self.model_is_sharded:
-            self.model.backward_by_grad(tensor, grad)
-        else:
-            super().backward_by_grad(tensor, grad)
+        self.model.backward_by_grad(tensor, grad)

     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         if self.optim_state == OptimState.SCALED:
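`step()` now works entirely on shards: each `p.data` is pointed at its fp32 master shard so the wrapped optimizer's states are sharded for free, and the updated values are copied back into the fp16 payload afterwards. A condensed, hypothetical view of that round trip (not the actual implementation):

def sharded_step(optim, master_params):
    # Hypothetical condensation of ShardedOptimizerV2.step() above.
    for group in optim.param_groups:
        for p in group['params']:
            p.data = master_params[p]              # optimizer sees fp32 shards only
    ret = optim.step()
    for group in optim.param_groups:
        for p in group['params']:
            p.col_attr.data.copy_payload(p.data)   # cast fp32 shard back into fp16 storage
            p.data = p.col_attr.data.payload       # params point at fp16 again
    return ret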
@@ -113,7 +130,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     @property
     def loss_scale(self):
-        return self.grad_scaler.scale
+        return self.grad_scaler.scale.item()

     def _check_overflow(self):
         # clear previous overflow record
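Since the scaler presumably stores its scale as a one-element tensor (compare the `torch.FloatTensor([0])` pattern above), returning `.item()` makes `loss_scale` a plain Python float, so `self.loss_scale * loss` and `grad.div_(self.loss_scale)` do scalar math rather than tensor broadcasting. For example:

import torch

scale = torch.FloatTensor([2 ** 5])   # tensor-valued scale, as a grad scaler might keep it
loss = torch.tensor(1.5)

assert isinstance(scale.item(), float)
scaled = scale.item() * loss          # float * tensor: no extra device tensor involved
assert scaled.item() == 48.0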
@@ -141,3 +158,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             if p.grad is not None:
                 p.grad.data.div_(self.loss_scale)
         self.optim_state = OptimState.UNSCALED
+
+    def zero_grad(self, *args, **kwargs):
+        # We must set grads to None,
+        # because we judge whether local grad accumulation
+        # is enabled by whether grad is None
+        self.optim.zero_grad(set_to_none=True)
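The comment above is the key invariant: grads are reset to None rather than zeroed in place, so later code can detect local gradient accumulation simply by checking `p.grad is not None`. A small sketch of that check:

import torch

p = torch.nn.Parameter(torch.zeros(3))

# After zero_grad(set_to_none=True) the grad is gone entirely.
assert p.grad is None

# A backward pass repopulates it; a non-None grad means accumulation is in progress.
(p * 2).sum().backward()
assert p.grad is not None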
tests/test_zero_data_parallel/common.py

@@ -95,12 +95,12 @@ def check_params_padding(model, zero_model, loose=False):
 def check_sharded_params_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.ca_attr.payload(p.device)
+        zero_p = zero_p.col_attr.data.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
-        p = chunks[rank]
+        p = chunks[rank].float()
         if zero_p.size(0) > p.size(0):
             zero_p = zero_p[:p.size(0)]
         assert p.dtype == zero_p.dtype
-        assert allclose(p, zero_p, loose=loose)
+        assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
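The updated check compares each rank's fp32 shard against the matching chunk of the flattened reference parameter, trimming any padding the shard strategy appended. The core logic in isolation (rank and world size passed explicitly; `atol` is an illustrative stand-in for the `loose` tolerance):

import torch

def shard_matches(full_param, shard, rank, world_size, atol=1e-5):
    chunks = torch.flatten(full_param).chunk(world_size)
    if rank >= len(chunks):
        return True                      # this rank's shard is pure padding
    ref = chunks[rank].float()
    if shard.size(0) > ref.size(0):
        shard = shard[:ref.size(0)]      # drop trailing pad elements
    return torch.allclose(ref, shard.float(), atol=atol)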
tests/test_zero_data_parallel/test_shard_model_v2.py

@@ -17,7 +17,7 @@ from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

-from common import CONFIG, check_grads, check_grads_padding
+from common import CONFIG, check_grads_padding


 def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):

@@ -69,10 +69,7 @@ def run_dist(rank, world_size, port):
     run_fwd_bwd(model, data, label, criterion, False)
     run_fwd_bwd(zero_model, data, label, criterion, False)

-    if dist.get_world_size() > 1:
-        check_grads_padding(model, zero_model, loose=True)
-    else:
-        check_grads(model, zero_model, loose=True)
+    check_grads_padding(model, zero_model, loose=True)


 @pytest.mark.dist
tests/test_zero_data_parallel/test_sharded_optim_v2.py

@@ -9,22 +9,23 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
+from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
+from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Adam

-from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_sharded_params_padding)
+from common import CONFIG, check_sharded_params_padding


-def run_step(model, optimizer, x, enable_autocast=False):
+def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     model.train()
     optimizer.zero_grad()
     with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
+        y = model(data)
+        loss = criterion(y, label)
+        loss = loss.float()
     if isinstance(model, ShardedModelV2):
         optimizer.backward(loss)
@@ -33,35 +34,53 @@ def run_step(model, optimizer, x, enable_autocast=False):
     optimizer.step()


+def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
+    model.train()
+    optimizer.zero_grad()
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        loss = model(data, label)
+    if isinstance(model, ShardedModelV2):
+        optimizer.backward(loss)
+    else:
+        loss.backward()
+    optimizer.step()
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    model = Net(checkpoint=True).cuda()
-    zero_model = copy.deepcopy(model)
-    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
-    for n, p in zero_model.named_parameters():
-        p._name = n
-    optim = Adam(model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3), zero_model)
-    for _ in range(2):
-        x = torch.rand(2, 5).cuda()
-        run_step(zero_model, sharded_optim, x, False)
-        run_step(model, optim, x, False)
-        if dist.get_world_size() > 1:
-            check_grads_padding(model, zero_model)
-            check_sharded_params_padding(model, zero_model)
-        else:
-            check_grads(model, zero_model)
-            check_params(model, zero_model)
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        shard_strategy = TensorShardStrategy()
+        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model = model(checkpoint=True).cuda()
+        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
+        model = DDP(model)
+        optim = Adam(model.parameters(), lr=1e-3)
+        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3), zero_model, shard_strategy,
+                                           initial_scale=2 ** 5)
+        for i, (data, label) in enumerate(train_dataloader):
+            if i > 2:
+                break
+            data, label = data.cuda(), label.cuda()
+            if criterion is None:
+                run_step_no_criterion(model, optim, data, label, False)
+                run_step_no_criterion(zero_model, sharded_optim, data, label, False)
+            else:
+                run_step(model, optim, data, label, criterion, False)
+                run_step(zero_model, sharded_optim, data, label, criterion, False)
+            check_sharded_params_padding(model, zero_model, loose=True)


-@pytest.mark.skip
-def test_sharded_optim_v2():
-    world_size = 2
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 2, 4])
+def test_sharded_optim_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2()
+    test_sharded_optim_v2(world_size=2)