OpenDAS / ColossalAI · Commit 7aef75ca

[zero] add sharded grad and refactor grad hooks for ShardedModel (#287)

Authored Mar 02, 2022 by ver217; committed Mar 11, 2022 by Frank Lee
Parent: 9afb5c8b

Showing 9 changed files with 303 additions and 73 deletions (+303, -73)
colossalai/engine/ophooks/__init__.py                  +6   -3
colossalai/engine/ophooks/_shard_grad_ophook.py        +31  -0
colossalai/zero/shard_param/shard_param.py             +9   -2
colossalai/zero/sharded_model/reduce_scatter.py        +0   -4
colossalai/zero/sharded_model/sharded_grad.py          +85  -0
colossalai/zero/sharded_model/sharded_model_v2.py      +119 -21
tests/test_zero_data_parallel/common.py                +35  -4
tests/test_zero_data_parallel/test_shard_model_v2.py   +15  -10
tests/test_zero_data_parallel/test_zero_dev_3_mp4.py   +3   -29
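Taken together, the changes wire gradient sharding into the existing op-hook machinery: ShardGradHook stashes each parameter's locally accumulated gradient into a new ShardedGradient object before backward, a per-parameter post-backward hook reduce-scatters gradients through the ReduceScatterBucketer, and a final backward hook writes the sharded result back to param.grad. A minimal usage sketch, adapted from run_dist() in tests/test_zero_data_parallel/test_shard_model_v2.py below (it assumes colossalai.launch has already initialized the distributed context):

```python
import copy

import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model import ShardedModelV2


def demo_step(model: torch.nn.Module) -> None:
    # Wrapping shards every parameter and registers ShardParamHook/ShardGradHook.
    zero_model = ShardedModelV2(copy.deepcopy(model), process_group=gpc.get_group(ParallelMode.DATA))
    x = torch.rand(2, 5).cuda()
    loss = zero_model(x).sum().float()
    # backward() runs autograd, then the final backward hook flushes the
    # reduce-scatter bucketer and writes each sharded grad back to param.grad.
    zero_model.backward(loss)
```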
colossalai/engine/ophooks/__init__.py (view file @ 7aef75ca)

```diff
 from typing import List

 import torch

 from ._base_ophook import BaseOpHook
 from ._memtracer_ophook import MemTracerOpHook
+from ._shard_grad_ophook import ShardGradHook
 from ._shard_param_ophook import ShardParamHook
-import torch
-from typing import List

-__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook"]
+__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook", "ShardGradHook"]

 # apply torch.autograd.Function that calls a backward_function to tensors in output
 ...
```
colossalai/engine/ophooks/_shard_grad_ophook.py (new file, 0 → 100644, view file @ 7aef75ca)

```python
import torch

from colossalai.registry import OPHOOKS

from . import BaseOpHook


@OPHOOKS.register_module
class ShardGradHook(BaseOpHook):
    """
    A hook to process sharded params before and after FWD and BWD operators execute.
    """

    def __init__(self):
        super().__init__()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        # Before each module's backward pass, stash the locally accumulated
        # gradient of every sharded parameter into its ShardedGradient.
        for param in module.parameters():
            assert hasattr(param, '_sharded_grad')
            param._sharded_grad.setup()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    def post_iter(self):
        pass
```
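For context (this registration lives in sharded_model_v2.py, changed in this same commit, not in this file): the hook is attached recursively to every submodule, so pre_bwd_exec runs just before each module's backward:

```python
# Excerpt from ShardedModelV2.__init__ below: both op hooks are registered on
# every submodule of the wrapped model.
register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
```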
colossalai/zero/shard_param/shard_param.py (view file @ 7aef75ca)

```python
from enum import Enum

import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model._zero3_utils import get_shard


class TensorType(Enum):
    ...
```

@@ -27,9 +28,11 @@ class ShardParam(object):

```python
        self.world_size = dist.get_world_size(self.process_group)
        self.local_rank = dist.get_rank(self.process_group)
        self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
        self._payload_shape = None
        self._payload_numel = None
        self._origin_shape = param.shape
        self._origin_numel = param.numel()
        self._origin_dtype = param.dtype
        self.is_sharded = False

    def payload(self, target_device: torch.device):
        ...
```

@@ -65,3 +68,7 @@ class ShardParam(object):

```python
                    async_op=False)
        self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
        self.is_sharded = False

    @property
    def origin_dtype(self):
        return self._origin_dtype
```
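The gather path above reassembles the padded flat shards with torch.cat, trims back to the original element count with torch.narrow, and restores the shape with view. A self-contained sketch of that idiom (sizes here are illustrative, not from the commit):

```python
import torch
import torch.nn.functional as F

origin = torch.arange(5.)                                     # original parameter data, 5 elements
world_size = 2
shard_size = (origin.numel() + world_size - 1) // world_size  # 3 elements per shard
chunks = list(origin.chunk(world_size))                       # chunk sizes: 3 and 2
buffer_list = [F.pad(c, (0, shard_size - c.numel())) for c in chunks]  # pad the short shard
# cat -> 6 padded elements; narrow -> first 5 real elements; view -> original shape
restored = torch.narrow(torch.cat(buffer_list), 0, 0, origin.numel()).view(origin.shape)
assert torch.equal(restored, origin)
```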
colossalai/zero/sharded_model/reduce_scatter.py (view file @ 7aef75ca)

@@ -190,10 +190,6 @@ class ReduceScatterBucketer:

```diff
         return int(bucket_size // num_shards)

     def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
-        # TODO (Min): the `group` used here in the key is the object hash, not the content
-        # hash. That means if FSDP instances are initialized with different process groups,
-        # even when the group members are in fact the same, we end up creating different
-        # buckets here.
         key = (tensor.dtype, tensor.device, group)
         if key not in self.buckets:
             # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
             ...
```
colossalai/zero/sharded_model/sharded_grad.py (new file, 0 → 100644, view file @ 7aef75ca)

```python
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class ShardedGradient:
    def __init__(self,
                 param: Parameter,
                 sharded_module: nn.Module,
                 offload_config: Optional[dict] = None) -> None:
        assert hasattr(param, 'ca_attr') and param.ca_attr.is_sharded, \
            'ShardedGradient can only be initialized with a sharded parameter'

        self.param = param
        self.sharded_module = sharded_module
        self.offload_config = offload_config
        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        # _gpu_grad is either sharded or not
        # all saved grads are fp32
        self._gpu_grad: Optional[torch.Tensor] = None
        self._cpu_grad: Optional[torch.Tensor] = None

        if self._cpu_offload:
            # this buffer will be held and reused every iteration
            self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()

    @torch.no_grad()
    def setup(self) -> None:
        """This function will be called pre-backward. Save the locally accumulated gradient to _gpu_grad.
        When no_sync() is enabled (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad

        :raises AssertionError: Raised if grad shape is wrong
        """
        if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
            if self.param.grad.device != self.param.data.device:
                # TODO: offload?
                raise RuntimeError(
                    f'grad and param are on different device, grad {self.param.grad.device} vs. param {self.param.data.device}')
            else:
                self._gpu_grad = self.param.grad.data
            self.param.grad = None

    def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
        """This function will be called in the post-backward hook, so we cannot modify param.grad directly

        :param reduced_grad: the reduced grad
        :type reduced_grad: torch.Tensor
        """
        # Make sure we store fp32 grad
        if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
            reduced_grad.data = reduced_grad.data.to(torch.float)

        if self._gpu_grad is None:
            self._gpu_grad = reduced_grad.data
        else:
            self._gpu_grad += reduced_grad.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backwards pass completes, we will set `.grad` to the CPU copy.
        if self._cpu_offload:
            self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            self._gpu_grad.data.record_stream(torch.cuda.current_stream())

    @torch.no_grad()
    def write_back(self) -> None:
        """This function will be called in the final backward hook
        """
        if self._cpu_grad is not None:
            assert self.param.device == torch.device('cpu'), \
                f'Incorrect param device, expected CPU, got {self.param.device}'
            self.param.grad.data = self._cpu_grad
        elif self._gpu_grad is not None:
            assert self.param.device == self._gpu_grad.device, \
                f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
            self.param.grad.data = self._gpu_grad
        else:
            raise RuntimeError('No grad to write back')
        # If using CPU offload, _cpu_grad stores the CPU copy of _gpu_grad;
        # the GPU copy should be released here.
        self._gpu_grad = None
```
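The CPU-offload path combines three standard CUDA idioms: a pinned buffer allocated once, an asynchronous device-to-host copy, and record_stream to keep the caching allocator from recycling the source too early. A standalone sketch of that pattern (requires a CUDA device; tensor names here are illustrative, not from the commit):

```python
import torch

gpu_grad = torch.randn(1024, device='cuda')
# Reusable pinned-memory buffer, allocated once (mirrors self._cpu_grad).
cpu_grad = torch.zeros(1024, dtype=torch.float).pin_memory()

# Asynchronous D2H copy: returns immediately, transfer runs on the current stream.
cpu_grad.copy_(gpu_grad, non_blocking=True)
# Tell the allocator not to reuse gpu_grad's memory until the current stream
# reaches this point (mirrors _gpu_grad.data.record_stream(...) above).
gpu_grad.record_stream(torch.cuda.current_stream())

# Before reading cpu_grad on the host, wait for the transfer to finish
# (ShardedModelV2 does this in _final_backward_hook via synchronize()).
torch.cuda.current_stream().synchronize()
```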
colossalai/zero/sharded_model/sharded_model_v2.py (view file @ 7aef75ca)

```diff
-import contextlib
-import copy
 import functools
-import os
-import traceback
-from collections import OrderedDict
-from enum import Enum, auto
-from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Union)
+from typing import Any, Optional

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook, ShardParamHook
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
 from colossalai.zero.shard_param import ShardParam
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
+from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

 from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
```
```diff
 class ShardedModelV2(nn.Module):
     def __init__(self,
                  module: nn.Module,
                  process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None):
+                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
+                 reduce_scatter_bucket_size_mb: int = 25,
+                 reshard_after_forward: bool = True,
+                 mixed_precision: bool = False,
+                 fp32_reduce_scatter: bool = False,
+                 offload_config: Optional[dict] = None,
+                 gradient_predivide_factor: Optional[float] = 1.0,
+                 ):
         r"""
         A demo to reconfigure the ZeRO-1 sharded model.
         Optimizer states are not considered for now.
         ...
```
@@ -45,19 +51,111 @@ class ShardedModelV2(nn.Module):

```diff
         for _, param in self.module.named_parameters():
             param.ca_attr = ShardParam(param)
             param.ca_attr.shard()
+            param._sharded_grad = ShardedGradient(param, self, offload_config)

         # Register hooks
-        register_ophooks_recursively(self.module, [ShardParamHook()])
+        register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
+        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
+        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

+        self.reshard_after_forward = reshard_after_forward
+        self.mixed_precision = mixed_precision
+        self.fp32_reduce_scatter = fp32_reduce_scatter
+        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
+        # We find that if gradient_predivide_factor != 1.0, there may be a precision problem,
+        # so we use 1.0 as the default gradient_predivide_factor.
+        # However, if you set gradient_predivide_factor to None, we will set it to a value >= 1.0 automatically.
+        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+            get_gradient_predivide_factor(self.world_size)
+        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

+        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
+        self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
+        self._require_backward_grad_sync: bool = True
```
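The two divide factors compose to a single division by world_size, matching DDP's gradient averaging; gradient_predivide_factor only chooses how much of the scaling happens before versus after the reduce-scatter. A quick arithmetic check:

```python
# With predivide factor f, each rank divides its local grad by f before the
# reduce-scatter, and the reduced result is divided by world_size / f after.
# Net effect: sum(g) / f / (world_size / f) == sum(g) / world_size.
world_size, f = 4, 2.0
g_sum = 12.0                    # sum of local grads across ranks (illustrative value)
pre = g_sum / f                 # applied before reduce-scatter
post = pre / (world_size / f)   # gradient_postdivide_factor = world_size / f
assert post == g_sum / world_size
```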
```diff
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         outputs = self.module(*args, **kwargs)
         return outputs

     def backward(self, loss):
         if self.loss_scaler:
             self.loss_scaler.backward(loss)
         else:
-            loss.backward()
\ No newline at end of file
+            loss.backward()
+        self._final_backward_hook()
```
```python
    @torch.no_grad()
    def _final_backward_hook(self) -> None:
        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self.comm_stream):
                self.reducer.flush()
            torch.cuda.current_stream().wait_stream(self.comm_stream)
            if self._cpu_offload:
                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                torch.cuda.current_stream().synchronize()
        self.reducer.free()
        for p in self.module.parameters():
            if not p.requires_grad:
                continue
            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync
            # and sync passes, if desired.
            if not self._require_backward_grad_sync:
                continue
            p._sharded_grad.write_back()
```
```python
    @torch.no_grad()
    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
        """
        At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient
        across all GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]

            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8

        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by `param._sharded_grad`, which ensures that
        the local optimizer only sees the relevant parameter shard.
        """
        if grad is None:
            return
        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
        if not self._require_backward_grad_sync:
            return
        self.comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.comm_stream):
            new_grad = grad.clone()
            if self.mixed_precision and self.fp32_reduce_scatter:
                new_grad.data = new_grad.data.to(param.dtype)
            if self.gradient_predivide_factor > 1.0:
                # Average grad by world_size for consistency with PyTorch DDP.
                new_grad.data.div_(self.gradient_predivide_factor)
            orig_grad_data = new_grad.data
            if self.world_size > 1:
                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
                self.reducer.reduce_scatter_async(grad_chunks, group=self.reduce_scatter_process_group,
                                                  callback_fn=functools.partial(self._reduce_scatter_callback, param))
            else:
                self._reduce_scatter_callback(param, new_grad)
            orig_grad_data.record_stream(self.comm_stream)
```
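The docstring's two-GPU example corresponds directly to torch.distributed.reduce_scatter semantics. A hedged standalone sketch, assuming a 2-process group has already been initialized (for example via torch.multiprocessing.spawn with the gloo backend):

```python
import torch
import torch.distributed as dist


def demo_reduce_scatter(rank: int) -> torch.Tensor:
    # rank 0 holds [1, 2, 3, 4]; rank 1 holds [5, 6, 7, 8].
    full_grad = torch.arange(1., 5.) if rank == 0 else torch.arange(5., 9.)
    shard = torch.empty(2)
    # Each rank receives the elementwise sum of its aligned input chunk.
    dist.reduce_scatter(shard, list(full_grad.chunk(2)))
    return shard  # rank 0 -> [6., 8.], rank 1 -> [10., 12.]
```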
```python
    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the cpu offload step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
        if self.mixed_precision:
            orig_param_grad_data = reduced_grad.data
            reduced_grad.data = reduced_grad.data.to(dtype=param.ca_attr.origin_dtype)
            # Don't let this memory get reused until after the transfer.
            orig_param_grad_data.record_stream(torch.cuda.current_stream())
        param._sharded_grad.reduce_scatter_callback(reduced_grad)
```
tests/test_zero_data_parallel/common.py (view file @ 7aef75ca)

```diff
 from functools import partial
-from operator import imod

-from colossalai.utils import checkpoint
-import torch.nn as nn
 import torch
+import torch.distributed as dist
 import torch.nn as nn
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import checkpoint
+
+LOGGER = get_dist_logger()
 ...
```

@@ -34,6 +35,7 @@ CONFIG = dict(

```python
    )
)


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
    ...
```

@@ -61,6 +63,7 @@ class Net(nn.Module):

```python
            x = layer(x)
        return x


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    ...
```

@@ -72,7 +75,8 @@ def check_grads(model, zero_model, loose=False):

```python
        zero_grad = zero_p.grad.clone().to(p.device)
        assert p.grad.dtype == zero_grad.dtype
        assert allclose(p.grad, zero_grad, loose=loose)
        LOGGER.info(torch.sum(p.grad - zero_grad))


def check_params(model, zero_model, loose=False):
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        ...
```

@@ -80,3 +84,30 @@ def check_params(model, zero_model, loose=False):

```python
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)


def check_grads_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_grad.size(0) > grad.size(0):
            zero_grad = zero_grad[:grad.size(0)]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_p.size(0) > p.size(0):
            zero_p = zero_p[:p.size(0)]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)
```
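Note how the padding checks rely on torch.chunk, whose last chunk may be shorter than the shard held by the ZeRO model; the zero_grad[:grad.size(0)] slice trims the padded tail before comparing. For instance:

```python
import torch

flat = torch.arange(10.)
chunks = flat.chunk(4)              # sizes 3, 3, 3, 1: the last chunk is short
print([c.numel() for c in chunks])  # [3, 3, 3, 1]
# On the last rank the padded shard holds 3 elements, but only 1 is real
# data, so the reference chunk decides how many elements get compared.
```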
tests/test_zero_data_parallel/test_shard_model_v2.py (view file @ 7aef75ca)

@@ -3,19 +3,18 @@

```diff
 import copy
 from functools import partial
-from operator import mod
-from pyexpat import model

 import colossalai
 import pytest
 import torch
+import torch.distributed as dist
 import torch.multiprocessing as mp
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode

-from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads
+from common import CONFIG, Net, check_grads, check_grads_padding
```

@@ -24,8 +23,11 @@ def run_fwd_bwd(model, x, enable_autocast=False):

```diff
     y = model(x)
     loss = y.sum()
     loss = loss.float()
-    loss.backward()
+    if isinstance(model, ShardedModelV2):
+        model.backward(loss)
+    else:
+        loss.backward()


 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG,
     ...
```

@@ -34,7 +36,7 @@ def run_dist(rank, world_size, port):

```python
                      host='localhost',
                      port=port,
                      backend='nccl')
    model = Net(checkpoint=True).cuda()
    zero_model = copy.deepcopy(model)
    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
    ...
```

@@ -43,7 +45,10 @@ def run_dist(rank, world_size, port):

```diff
     x = torch.rand(2, 5).cuda()
     run_fwd_bwd(zero_model, x, False)
     run_fwd_bwd(model, x, False)
-    check_grads(model, zero_model)
+    if dist.get_world_size() > 1:
+        check_grads_padding(model, zero_model)
+    else:
+        check_grads(model, zero_model)


 @pytest.mark.dist
 ...
```
tests/test_zero_data_parallel/test_zero_dev_3_mp4.py (view file @ 7aef75ca)

@@ -14,7 +14,9 @@

```diff
 from colossalai.utils import checkpoint, free_port
 from colossalai.zero.sharded_model import ShardedModel
 from torch.nn.parallel import DistributedDataParallel as DDP

-from common import Net, allclose
+from common import Net, check_grads_padding, check_params_padding


 def run_step(model, optimizer, x, enable_autocast=False):
     model.train()
     ...
```

@@ -26,34 +28,6 @@ def run_step(model, optimizer, x, enable_autocast=False):

```diff
     loss.backward()
     optimizer.step()


-def check_grads_padding(model, zero_model, loose=False):
-    rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_grad = zero_p.grad.clone().to(p.device)
-        chunks = torch.flatten(p.grad).chunk(4)
-        if rank >= len(chunks):
-            continue
-        grad = chunks[rank]
-        if zero_p.zero_shard_padding > 0:
-            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
-        assert grad.dtype == zero_grad.dtype
-        assert allclose(grad, zero_grad, loose=loose)
-
-
-def check_params_padding(model, zero_model, loose=False):
-    rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_shard_padding = zero_p.zero_shard_padding
-        zero_p = zero_p.clone().to(p.device)
-        chunks = torch.flatten(p).chunk(4)
-        if rank >= len(chunks):
-            continue
-        p = chunks[rank]
-        if zero_shard_padding > 0:
-            zero_p = zero_p[:-zero_shard_padding]
-        assert p.dtype == zero_p.dtype
-        assert allclose(p, zero_p, loose=loose)


 def decode_booleans(intval, bits):
     res = []
     ...
```