OpenDAS / ColossalAI · Commit 13886716

[zero] Update sharded model v2 using sharded param v2 (#323)

Authored Mar 08, 2022 by ver217; committed by Frank Lee, Mar 11, 2022. Parent: 799d105b.

Showing 16 changed files with 404 additions and 203 deletions (+404 −203).
colossalai/engine/ophooks/__init__.py                          +9  −11
colossalai/engine/ophooks/zero_hook.py                         +58 −0
colossalai/zero/init_ctx/init_context.py                       +9  −8
colossalai/zero/shard_utils/tensor_shard_strategy.py           +4  −5
colossalai/zero/sharded_model/_zero3_utils.py                  +9  −17
colossalai/zero/sharded_model/sharded_model.py                 +64 −60
colossalai/zero/sharded_model/sharded_model_v2.py              +68 −43
colossalai/zero/sharded_param/sharded_param.py                 +16 −17
tests/__init__.py                                              +0  −0
tests/test_zero_data_parallel/common.py                        +6  −6
tests/test_zero_data_parallel/test_init_context.py             +10 −8
tests/test_zero_data_parallel/test_shard_model_v2.py           +27 −18
tests/test_zero_data_parallel/test_shard_param.py              +7  −9
tests/test_zero_data_parallel/test_sharded_model_with_ctx.py   +73 −0
tests/test_zero_data_parallel/test_sharded_optim_v2.py         +1  −1
tests/test_zero_data_parallel/test_state_dict.py               +43 −0
colossalai/engine/ophooks/__init__.py

@@ -15,8 +15,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
     if type(outputs) is tuple:
         touched_outputs = []
         for output in outputs:
             touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
             touched_outputs.append(touched_output)
         return tuple(touched_outputs)
     elif type(outputs) is torch.Tensor:

@@ -26,6 +25,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
 class PreBackwardFunction(torch.autograd.Function):

     @staticmethod
     def forward(ctx, module, pre_backward_function, outputs):
         ctx.module = module

@@ -41,6 +41,7 @@ class PreBackwardFunction(torch.autograd.Function):
 class PostBackwardFunction(torch.autograd.Function):

     @staticmethod
     def forward(ctx, module, pre_backward_function, output):
         ctx.module = module

@@ -60,9 +61,7 @@ class PostBackwardFunction(torch.autograd.Function):
         return (None, None) + args


 def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
     r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)
     has_children = False

@@ -72,8 +71,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
     # Early return on modules with no parameters or buffers that
     # are not in their children.
     if (len(list(module.named_parameters(recurse=False))) == 0
             and len(list(module.named_buffers(recurse=False))) == 0):
         return

     # return if the module has no children

@@ -95,22 +93,22 @@ def register_ophooks_recursively(module: torch.nn.Module,
             hook.post_fwd_exec(submodule, *args)

     def _pre_backward_module_hook(submodule, inputs, output):

         def _run_before_backward_function(submodule):
             for hook in ophook_list:
                 assert isinstance(submodule, torch.nn.Module)
                 hook.pre_bwd_exec(submodule, inputs, output)

         return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

     def _post_backward_module_hook(submodule, inputs):

         def _run_after_backward_function(submodule):
             for hook in ophook_list:
                 assert isinstance(submodule, torch.nn.Module)
                 hook.post_bwd_exec(submodule, inputs)

         return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)

     module.register_forward_pre_hook(_pre_forward_module_hook)
     module.register_forward_hook(_post_forward_module_hook)
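The machinery above hinges on `_apply_to_tensors_only` wrapping each tensor output in an identity `autograd.Function` whose `backward` fires a callback before gradients flow to earlier ops. A self-contained toy of that trick (the `PreBackward` class and its callback are invented for illustration, not part of this commit):

import torch


class PreBackward(torch.autograd.Function):
    """Identity in forward; runs a callback at the start of backward."""

    @staticmethod
    def forward(ctx, callback, x):
        ctx.callback = callback
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        ctx.callback()    # e.g. re-gather sharded params before grads arrive
        return None, grad_output    # no grad for the callback argument


x = torch.ones(3, requires_grad=True)
y = PreBackward.apply(lambda: print("pre-backward fired"), x * 2)
y.sum().backward()
print(x.grad)    # tensor([2., 2., 2.])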
colossalai/engine/ophooks/zero_hook.py (new file, 0 → 100644)

import torch
from colossalai.registry import OPHOOKS
from colossalai.zero.shard_utils import BaseShardStrategy

from ._base_ophook import BaseOpHook


@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
    """
    A hook to process sharded param for ZeRO method.
    """

    def __init__(self, shard_strategy: BaseShardStrategy):
        super().__init__()
        self.shard_strategy = shard_strategy

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.gather([param.col_attr.data])
            param.data = param.col_attr.data.payload

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.shard([param.col_attr.data])
            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.gather([param.col_attr.data])
            param.data = param.col_attr.data.payload
            # Store local accumulated grad shard
            if param.grad is not None:
                if param.col_attr.bwd_count == 0:
                    # We haven't stored local accumulated grad yet
                    assert param.col_attr.grad is None
                    param.col_attr.grad = param.grad.data
                    param.grad = None
                else:
                    # We have stored local accumulated grad
                    # The grad here must be locally computed full grad in this backward pass
                    assert param.grad.shape == param.col_attr.data.origin_shape
            param.col_attr.bwd_count += 1

    def post_bwd_exec(self, module: torch.nn.Module, input):
        for param in module.parameters():
            assert hasattr(param, 'col_attr')
            self.shard_strategy.shard([param.col_attr.data])
            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)

    def pre_iter(self):
        pass

    def post_iter(self):
        pass
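ZeroHook's lifecycle is: gather the sharded payload and point `param.data` at it before each submodule runs, then re-shard and leave a size-0 placeholder afterwards, so full parameters exist only while their module computes. A hypothetical single-process sketch of that gather → compute → reshard cycle using plain forward hooks (`ToyShardedTensor`, `ToyStrategy`, and `toy_attr` are invented stand-ins; a real strategy would `all_gather` shards across ranks):

import torch
import torch.nn as nn


class ToyShardedTensor:
    """Hypothetical stand-in for ShardedTensor: holds a full payload or a shard."""

    def __init__(self, tensor: torch.Tensor):
        self.payload = tensor
        self.origin_shape = tensor.shape
        self.origin_numel = tensor.numel()
        self.dtype = tensor.dtype
        self.is_sharded = False


class ToyStrategy:
    """Hypothetical stand-in for a shard strategy: keeps the first half locally."""

    def shard(self, tensors):
        for t in tensors:
            if not t.is_sharded:
                t.payload = t.payload.flatten()[:t.origin_numel // 2].clone()
                t.is_sharded = True

    def gather(self, tensors):
        for t in tensors:
            if t.is_sharded:
                # single-process stand-in for an all_gather across ranks
                full = torch.cat([t.payload, t.payload])[:t.origin_numel]
                t.payload = full.reshape(t.origin_shape)
                t.is_sharded = False


strategy = ToyStrategy()
linear = nn.Linear(4, 4, bias=False)
for p in linear.parameters():
    p.toy_attr = ToyShardedTensor(p.data.clone())
    strategy.shard([p.toy_attr])
    p.data = torch.empty([], dtype=p.toy_attr.dtype)    # drop the full payload


def pre_fwd(module, args):    # plays the role of ZeroHook.pre_fwd_exec
    for p in module.parameters():
        strategy.gather([p.toy_attr])
        p.data = p.toy_attr.payload


def post_fwd(module, args, output):    # plays the role of ZeroHook.post_fwd_exec
    for p in module.parameters():
        strategy.shard([p.toy_attr])
        p.data = torch.empty([], dtype=p.toy_attr.dtype)


linear.register_forward_pre_hook(pre_fwd)
linear.register_forward_hook(post_fwd)

with torch.no_grad():
    print(linear(torch.randn(1, 4)).shape)    # full weight exists only in here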
colossalai/zero/init_ctx/init_context.py

 import functools
-from colossalai.utils.cuda import get_current_device

 import torch
+from colossalai.utils.cuda import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2

@@ -103,8 +104,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
-                assert hasattr(param, 'ca_attr')
-                param.ca_attr.remove_torch_payload()
+                assert hasattr(param, 'col_attr')
+                param.col_attr.remove_torch_payload()
         del self.initialized_param_list

@@ -113,7 +114,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         for param in module.parameters():
             # avoid adapting a param to ShardedParam twice
-            if hasattr(param, 'ca_attr'):
+            if hasattr(param, 'col_attr'):
                 continue
             if self.convert_cuda:

@@ -127,11 +128,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(torch.half).to(target_device)
-            param.ca_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
             self.initialized_param_list.append(param)
             if self.shard_param:
-                self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
-                if param.ca_attr.grad and self.shard_grad:
-                    self.shard_strategy.shard(tensor_list=[param.ca_attr._grad_sharded_tensor])
+                self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
+                if param.col_attr.grad and self.shard_grad:
+                    self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
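A minimal usage sketch of the context, mirroring the arguments exercised by this commit's tests (assumes `colossalai.launch(...)` has initialized the process group and a CUDA device is available):

import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy

shard_strategy = TensorShardStrategy()
with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
    # each parameter gets a `col_attr` (ShardedParamV2) as it is constructed,
    # so the full fp32 model never materializes on one device
    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

for param in model.parameters():
    assert hasattr(param, 'col_attr')
    assert param.col_attr.data.is_sharded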
colossalai/zero/shard_utils/tensor_shard_strategy.py

-import torch
-import torch.distributed as dist
 from typing import List, Optional
+
+import torch
+import torch.distributed as dist
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.zero.sharded_model._zero3_utils import get_shard
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor


 class TensorShardStrategy(BaseShardStrategy):

@@ -38,7 +37,7 @@ class TensorShardStrategy(BaseShardStrategy):
             if i == self.local_rank:
                 buffer_list.append(t.payload.cuda())
             else:
-                buffer_list.append(torch.zeros(payload_numel).cuda())
+                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
         torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank],
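The one functional change here is that the zero-filled receive buffers now inherit the shard's dtype; with fp16 payloads, a bare `torch.zeros(...)` would be fp32 and no longer match the gathered shards. A standalone illustration of the buffer construction:

import torch

shard = torch.randn(6, dtype=torch.float16)
payload_numel = shard.numel()
# Receive buffers must match the shard's dtype for the collective to line up.
buffer_list = [torch.zeros(payload_numel, dtype=shard.dtype) for _ in range(4)]
assert all(b.dtype == torch.float16 for b in buffer_list)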
colossalai/zero/sharded_model/_zero3_utils.py

from collections import OrderedDict
from typing import Any, Callable, Dict, List, Tuple, Union

@@ -42,27 +41,21 @@ def free_storage(data: torch.Tensor) -> None:
 @torch.no_grad()
 def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
     """Allocate storage for a tensor."""
     if data.storage().size() == size.numel():    # no need to reallocate
         return
     assert data.storage().size() == 0
     data.storage().resize_(size.numel())


-def cast_trensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
-    if tensor.dtype is torch.float32:
-        out = tensor.half()
-        if tensor.is_leaf:
-            out.requires_grad = tensor.requires_grad
-        return out
+def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
+    if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
+        return tensor.half()
     return tensor


-def cast_trensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
-    if tensor.dtype is torch.float16:
-        out = tensor.float()
-        if tensor.is_leaf:
-            out.requires_grad = tensor.requires_grad
-        return out
+def cast_tensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
+    if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
+        return tensor.float()
     return tensor

@@ -102,9 +95,8 @@ def assert_in_engine(cond: Any, s: Any) -> None:
         raise AssertionError


 def replace_state_dict_prefix(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
                               old_prefix: str, new_prefix: str) -> None:
     """
     Replace all keys that match a given old_prefix with a new_prefix (in-place).
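The renamed helpers also change behavior slightly: non-floating-point tensors now always pass through untouched, and the old `requires_grad` propagation for leaf tensors is dropped. A quick check using the function body exactly as it appears in this diff:

import torch


def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    # copied verbatim from the hunk above
    if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
        return tensor.half()
    return tensor


assert cast_tensor_to_fp16(torch.ones(2)).dtype is torch.float16                      # fp32 -> fp16
assert cast_tensor_to_fp16(torch.ones(2, dtype=torch.int64)).dtype is torch.int64     # ints untouched
assert cast_tensor_to_fp16(torch.ones(2, dtype=torch.float16)).dtype is torch.float16 # no-op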
colossalai/zero/sharded_model/sharded_model.py

@@ -5,8 +5,7 @@ import os
 import traceback
 from collections import OrderedDict
 from enum import Enum, auto
 from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Union)

 import torch
 import torch.distributed as dist

@@ -15,16 +14,14 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
-from .param_manager import Zero3ParameterManager
 from torch.autograd import Variable
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

-from ._zero3_utils import (apply_to_tensors, assert_in_engine, cast_float_arguments, cast_trensor_to_fp16,
-                           cast_trensor_to_fp32, chunk_and_pad, free_storage, get_gradient_predivide_factor, get_shard,
+from ._zero3_utils import (apply_to_tensors, assert_in_engine, cast_float_arguments, cast_tensor_to_fp16,
+                           cast_tensor_to_fp32, chunk_and_pad, free_storage, get_gradient_predivide_factor, get_shard,
                            replace_state_dict_prefix)
+from .param_manager import Zero3ParameterManager
 from .reduce_scatter import ReduceScatterBucketer

 # TODO: Remove the toggle-enable_nccl_base_collectives in the future

@@ -41,11 +38,13 @@ class TrainingState(Enum):
     POST_BACKWARD = auto()
     GATHER_FULL_PARAMS = auto()


 # TODO: Add clip_grad_norm_
 # TODO: Add gather_full_optim_state_dict and get_shard_from_optim_state_dict
 class ShardedModel(nn.Module):

     def __init__(self,
                  module: nn.Module,
                  process_group: Optional[ProcessGroup] = None,

@@ -96,8 +95,10 @@ class ShardedModel(nn.Module):
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
-        # However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
-        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+        # However, if you set gradient_predivide_factor to None
+        # we will set gradient_predivide_factor to a value >= 1.0 automatically
+        self.gradient_predivide_factor: float = gradient_predivide_factor if \
+            gradient_predivide_factor is not None else \
             get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

@@ -111,8 +112,12 @@ class ShardedModel(nn.Module):
         self.module = module
-        self.param_manager = Zero3ParameterManager(module, process_group=self.process_group, mixed_precision=self.mixed_precision, flatten_parameters=flatten_parameters, compute_dtype=self.compute_dtype, compute_device=self.compute_device,
+        self.param_manager = Zero3ParameterManager(module,
+                                                   process_group=self.process_group,
+                                                   mixed_precision=self.mixed_precision,
+                                                   flatten_parameters=flatten_parameters,
+                                                   compute_dtype=self.compute_dtype,
+                                                   compute_device=self.compute_device,
                                                    offload_config=offload_config)

         self._reset_lazy_init_info()

@@ -145,13 +150,13 @@ class ShardedModel(nn.Module):
         # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
         # the conversion).
         if self._is_root and self.mixed_precision:
-            args, kwargs = cast_float_arguments(cast_trensor_to_fp16, *args, **kwargs)
+            args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)

         # If enabled, convert the input to FP32 if we are in full precision.
         # no_grad is not used because the input might be for a non-root instance,
         # which means autograd needs to go through the conversion.
         if self.force_input_to_fp32 and not self.mixed_precision:
-            args, kwargs = cast_float_arguments(cast_trensor_to_fp32, *args, **kwargs)
+            args, kwargs = cast_float_arguments(cast_tensor_to_fp32, *args, **kwargs)

         # All-gather full parameters. This will also transfer FP32 parameters to
         # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).

@@ -201,10 +206,9 @@ class ShardedModel(nn.Module):
         input_tensor = torch.ones(1).to(self.compute_device)
         output = list(torch.zeros(self.world_size).to(self.compute_device).chunk(self.world_size))
         dist.all_gather(output, input_tensor, group=self.process_group)
         assert torch.cat(output).sum() == float(self.world_size), (
             f"found {torch.cat(output).sum()} devices in process group but "
             f"world_size={self.world_size}. Check torch.cuda.set_device is called properly")

     def _reset_lazy_init_info(self) -> None:
         self._is_root: Optional[bool] = None

@@ -277,9 +281,10 @@ class ShardedModel(nn.Module):
             # if child instance in its own (smaller) world, that was probably an attempt to avoid OOM.
             # Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
             m.no_broadcast_optim_state = m.no_broadcast_optim_state or \
                 ((m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group))

     def _setup_streams(self) -> None:
         """Create streams to overlap data transfer and computation."""

@@ -330,9 +335,10 @@ class ShardedModel(nn.Module):
         else:
             self._streams["all_gather"].wait_stream(torch.cuda.current_stream())

     def _cast_buffers(self,
                       device: Optional[torch.device] = None,
                       dtype: Optional[torch.dtype] = None,
                       memo: Optional[Set] = None) -> None:
         """Move all buffers to the given *device* and *dtype*.

         If *device* or *dtype* are not given, then they will default to

@@ -398,7 +404,7 @@ class ShardedModel(nn.Module):
             outputs: new outputs with hooks registered if they require gradient.
         """
         if not torch.is_grad_enabled():
             return outputs    # don't register hooks if grad isn't enabled
         if self._is_root:
             # This actually means that only root instance has

@@ -523,7 +529,7 @@ class ShardedModel(nn.Module):
         a new hook, which is needed for a new forward pass.
         """
         if not torch.is_grad_enabled():
             return    # don't register grad hooks if grad isn't enabled
         for p in self.params:
             if p.requires_grad:
                 if hasattr(p, "zero_shard_bwd_hook"):

@@ -612,7 +618,8 @@ class ShardedModel(nn.Module):
         if param.zero_is_sharded:
             assert self._reducer is not None
             # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
-            # param.zero_saved_grad_shard. If this ShardedModel module was called multiple times it's possible that multiple
+            # param.zero_saved_grad_shard. If this ShardedModel module was called multiple times
+            # it's possible that multiple
             # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
             # matter, neglecting rounding.
             # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.

@@ -628,9 +635,9 @@ class ShardedModel(nn.Module):
             # unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
             callback_fn = functools.partial(self._reduce_scatter_callback, param)
             grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
             self._reducer.reduce_scatter_async(grad_chunks, group=self.reduce_scatter_process_group, callback_fn=callback_fn)
         else:
             # Currently the only way for _is_sharded to be False is if
             # world_size == 1. This could be relaxed in the future, in which

@@ -667,8 +674,9 @@ class ShardedModel(nn.Module):
                 param.zero_saved_grad_shard = reduced_grad.data
             else:
                 assert (param.zero_saved_grad_shard.shape == reduced_grad.shape), \
                     f"{param.zero_saved_grad_shard.shape} vs {reduced_grad.shape}"
                 param.zero_saved_grad_shard.data += reduced_grad.data
             reduced_grad = param.zero_saved_grad_shard.data
         else:

@@ -717,7 +725,7 @@ class ShardedModel(nn.Module):
         # Flush any unreduced buckets in the post_backward stream.
         with torch.cuda.stream(self._streams["post_backward"]):
             assert_in_engine(self._reducer is not None, "FinalBackwardHook: reducer is None")
             assert self._reducer is not None    # make mypy happy
             self._reducer.flush()
         torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
         if self._cpu_offload:

@@ -753,7 +761,8 @@ class ShardedModel(nn.Module):
             elif hasattr(p, "zero_saved_grad_shard"):
                 assert_in_engine(
                     p.device == p.zero_saved_grad_shard.device,
                     f"FinalBackwardHook: incorrect saved_grad_shard device "
                     f"{p.device} vs {p.zero_saved_grad_shard.device}",
                 )
                 p.grad = p.zero_saved_grad_shard
             elif hasattr(p, 'zero_saved_grad'):

@@ -765,7 +774,7 @@ class ShardedModel(nn.Module):
             delattr(p, "zero_saved_grad")

         # Update root and nested ShardedModel's hooks and flags.
         for m in self.modules():    # includes self
             if isinstance(m, ShardedModel):
                 _finalize_parameters(m)
                 m._pre_backward_hook_has_run = False

@@ -796,7 +805,7 @@ class ShardedModel(nn.Module):
             self._output_pre_backward_hook_registered is not None,
             "FinalBackwardHook: self._output_pre_backward_hook_registered should not be None",
         )
         assert self._output_pre_backward_hook_registered is not None    # make mypy happy
         self._output_pre_backward_hook_registered.clear()

     @contextlib.contextmanager

@@ -908,9 +917,9 @@ class ShardedModel(nn.Module):
         state["is_sharded"] = [p.zero_is_sharded for p in self.params]
         state["orig_sizes"] = [p.zero_orig_size for p in self.params]
         if state["process_group"] is not None:
             state["process_group"] = "MISSING"    # process_group isn't pickleable
         if state["process_group_reduce_scatter"] is not None:
             state["process_group_reduce_scatter"] = "MISSING"    # process_group_reduce_scatter isn't pickleable
         self._reset_lazy_init_info()
         return state

@@ -920,7 +929,7 @@ class ShardedModel(nn.Module):
         def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
             assert isinstance(p, Parameter)
             p.data = p.data.clone()    # move tensors out of shared memory
             p.zero_is_sharded = is_sharded
             p.zero_orig_size = size
             return p

@@ -958,7 +967,7 @@ class ShardedModel(nn.Module):
         # This instance may wrap other ShardedModel instances and we
         # need to set all of them to accumulate gradients.
         old_flags = []
         for m in self.modules():    # includes self
             if isinstance(m, ShardedModel):
                 old_flags.append((m, m._require_backward_grad_sync))
                 m._require_backward_grad_sync = False

@@ -986,22 +995,18 @@ class ShardedModel(nn.Module):
             raise ValueError(msg)

     def extra_repr(self) -> str:
         repr = (f"world_size={self.world_size}, "
                 f"mixed_precision={self.mixed_precision}, ")
         if self.verbose:
             repr = (f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
                     f"compute_dtype={self.compute_dtype}, "
                     f"buffer_dtype={self.buffer_dtype}, "
                     f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
                     f"compute_device={self.compute_device}"
                     f"reduce_scatter_bucket_size_mb={self.reduce_scatter_bucket_size_mb}, "
                     f"clear_autocast_cache={self.clear_autocast_cache}"
                     f"force_input_to_fp32={self.force_input_to_fp32}"
                     f"offload_config={self.offload_config}")
         return repr

     def state_dict(self, destination=None, prefix='', keep_vars=False):

@@ -1039,9 +1044,9 @@ class ShardedModel(nn.Module):
             maybe_cast_buffers()
         return state_dict

     def load_state_dict(self,
                         state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
                         strict: bool = True) -> NamedTuple:
         """
         Load a whole (unsharded) state_dict.

@@ -1094,7 +1099,6 @@ def _post_state_dict_hook(
     return state_dict


 def _pre_load_state_dict_hook(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
                               prefix: str, *args: Any) -> None:
     replace_state_dict_prefix(state_dict, prefix, prefix + "_zero3_module.")
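For the divide factors referenced in the comments above: dividing by `gradient_predivide_factor` before the reduce-scatter and by `gradient_postdivide_factor` after composes to a plain world-size average, split into two steps to reduce the risk of fp16 overflow during reduction. The arithmetic (the predivide value is assumed for illustration):

world_size = 8
gradient_predivide_factor = 4.0    # assumed value for illustration
gradient_postdivide_factor = world_size / gradient_predivide_factor

# grad / pre, reduce-scatter, then / post  ==  reduced grad / world_size
assert gradient_predivide_factor * gradient_postdivide_factor == world_size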
colossalai/zero/sharded_model/sharded_model_v2.py

import functools
from collections import OrderedDict
from typing import Any, Optional

import torch

@@ -6,32 +7,32 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
+from colossalai.engine.ophooks import register_ophooks_recursively
+from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
+from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
-from colossalai.zero.sharded_param import ShardedParam
+from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

-from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
+                           get_gradient_predivide_factor)


 class ShardedModelV2(nn.Module):

-    def __init__(self,
-                 module: nn.Module,
-                 process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_bucket_size_mb: int = 25,
-                 reshard_after_forward: bool = True,
-                 mixed_precision: bool = False,
-                 fp32_reduce_scatter: bool = False,
-                 offload_config: Optional[dict] = None,
-                 gradient_predivide_factor: Optional[float] = 1.0):
+    def __init__(self,
+                 module: nn.Module,
+                 shard_strategy: BaseShardStrategy,
+                 process_group: Optional[ProcessGroup] = None,
+                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
+                 reduce_scatter_bucket_size_mb: int = 25,
+                 fp32_reduce_scatter: bool = False,
+                 offload_config: Optional[dict] = None,
+                 gradient_predivide_factor: Optional[float] = 1.0,
+                 shard_param: bool = True):
         r"""
         A demo to reconfigure the ZeRO-1 sharded model.
         Optimizer states are not considered for now.

@@ -44,22 +45,24 @@ class ShardedModelV2(nn.Module):
         self.world_size = dist.get_world_size(self.process_group)
         self.rank = dist.get_rank(self.process_group)

-        # The module has to be placed on GPU
-        self.module = module.cuda()
+        # Cast module to fp16 and cuda, in case user didn't use ZeroInitContext
+        self.module = module.half().cuda()

-        # Shard the parameters at first
-        for _, param in self.module.named_parameters():
-            param.ca_attr = ShardedParam(param)
-            param.ca_attr.shard()
-            param._sharded_grad = ShardedGradient(param, self, offload_config)
+        self.shard_strategy = shard_strategy
+        self.shard_param = shard_param
+        # In case user didn't use ZeroInitContext
+        for param in self.module.parameters():
+            if not hasattr(param, 'col_attr'):
+                param.col_attr = ShardedParamV2(param, process_group)
+                if self.shard_param:
+                    self.shard_strategy.shard([param.col_attr.data])

         # Register hooks
-        register_ophooks_recursively(self.module, [ShardParamHook(), ShardGradHook()])
+        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy)])
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

-        self.reshard_after_forward = reshard_after_forward
-        self.mixed_precision = mixed_precision
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem

@@ -76,6 +79,7 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs

@@ -99,6 +103,7 @@ class ShardedModelV2(nn.Module):
         torch.cuda.current_stream().synchronize()
         self.reducer.free()
         for p in self.module.parameters():
+            p.col_attr.bwd_count = 0
             if not p.requires_grad:
                 continue
             # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad

@@ -107,11 +112,14 @@ class ShardedModelV2(nn.Module):
             # sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            p._sharded_grad.write_back()
+            # Write grad back to p.grad and set p.col_attr.grad to None
+            p.grad.data = p.col_attr.grad
+            p.col_attr.grad = None

         # In case some post bwd hook is not fired
-        for p in self.module.parameters():
-            if not p.ca_attr.is_sharded:
-                p.ca_attr.shard()
+        if self.shard_param:
+            for p in self.module.parameters():
+                if not p.col_attr.param_is_sharded:
+                    self.shard_strategy.shard([p.col_attr.data])

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:

@@ -119,7 +127,7 @@ class ShardedModelV2(nn.Module):
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
         full gradient for the local batch. The reduce-scatter op will save
-        a single shard of the summed gradient across all GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
+        a single shard of the summed gradient across all GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::

             before reduce_scatter:
                 param.grad (GPU #0): [1, 2, 3, 4]

@@ -131,7 +139,7 @@ class ShardedModelV2(nn.Module):
         The local GPU's ``optim.step`` is responsible for updating a single
         shard of params, also corresponding to the current GPU's rank. This
-        alignment is created by `param._sharded_grad`, which ensures that
+        alignment is created by `param.col_attr.grad`, which ensures that
         the local optimizer only sees the relevant parameter shard.
         """
         if grad is None:

@@ -142,7 +150,7 @@ class ShardedModelV2(nn.Module):
         self.comm_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self.comm_stream):
             new_grad = grad.clone()
-            if self.mixed_precision and self.fp32_reduce_scatter:
+            if self.fp32_reduce_scatter:
                 new_grad.data = new_grad.data.to(param.dtype)
             if self.gradient_predivide_factor > 1.0:
                 # Average grad by world_size for consistency with PyTorch DDP.

@@ -161,13 +169,30 @@ class ShardedModelV2(nn.Module):
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
-        # Cast grad to param's dtype (typically FP32). Note: we do this
-        # before the cpu offload step so that this entire hook remains
-        # non-blocking. The downside is a bit more D2H transfer in that case.
-        if self.mixed_precision:
-            orig_param_grad_data = reduced_grad.data
-            reduced_grad.data = reduced_grad.data.to(dtype=param.ca_attr.origin_dtype)
-            # Don't let this memory get reused until after the transfer.
-            orig_param_grad_data.record_stream(torch.cuda.current_stream())
-        param._sharded_grad.reduce_scatter_callback(reduced_grad)
+        # Make sure we store fp32 grad
+        reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
+        # Maybe offload
+        if self._cpu_offload:
+            reduced_grad.data = reduced_grad.data.cpu()
+        if param.col_attr.grad is None:
+            param.col_attr.grad = reduced_grad.data
+        else:
+            param.col_attr.grad.add_(reduced_grad.data)

+    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
+        self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
+        prev_params = {}
+        for p in self.module.parameters():
+            prev_params[p] = p.data
+            p.data = p.col_attr.data.payload
+        gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
+        self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()])
+        for p in self.module.parameters():
+            p.data = prev_params[p]
+        return gathered_state_dict

     def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
         raise NotImplementedError
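A minimal usage sketch of the new constructor and backward path, condensed from this commit's tests (assumes `colossalai.launch(...)` has set up the process group and a CUDA device):

import copy

import torch
import torch.nn as nn
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

# assumes colossalai.launch(...) has already initialized the process group
model = nn.Linear(8, 8).half().cuda()
zero_model = ShardedModelV2(copy.deepcopy(model), TensorShardStrategy())

data = torch.randn(4, 8).half().cuda()
loss = zero_model(data).sum().float()    # forward itself casts float args to fp16
zero_model.backward(loss)                # reduced fp32 grads land in p.col_attr.grad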
colossalai/zero/sharded_param/sharded_param.py

from typing import Optional, Tuple, Union

import numpy
import torch
import torch.distributed as dist

@@ -5,7 +7,6 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param import ShardedTensor
-from typing import Union, Tuple, Optional


 class ShardedParamV2(object):

@@ -14,12 +15,8 @@ class ShardedParamV2(object):
                  param: torch.nn.Parameter,
                  process_group: Optional[dist.ProcessGroup] = None,
                  rm_torch_payload=False) -> None:
-        self._data_sharded_tensor = ShardedTensor(param.data, process_group)
-        if param.requires_grad and param.grad is not None:
-            self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
-            param.grad = None
-        else:
-            self._grad_sharded_tensor = None
+        self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
+        self._grad_sharded_tensor: Optional[torch.Tensor] = None

         # make sure the shared param is the only owner of payload
         # The param.data may be used to init the other part of the model.

@@ -30,27 +27,29 @@ class ShardedParamV2(object):
         if rm_torch_payload:
             self.remove_torch_payload()

+        # Backward count for handling local grad accumulation
+        # This value will increment by 1 in every pre-bwd hook
+        # And will be reset to 0 in every final-bwd hook
+        self.bwd_count = 0

     def remove_torch_payload(self):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

     @property
     def data(self):
-        return self._data_sharded_tensor.payload
-
-    @data.setter
-    def data(self, t: torch.Tensor):
-        self._data_sharded_tensor.payload = t
+        return self._data_sharded_tensor

     @property
     def grad(self):
-        if self._grad_sharded_tensor:
-            return self._grad_sharded_tensor.payload
-        else:
-            return None
+        return self._grad_sharded_tensor

     @grad.setter
     def grad(self, t: torch.Tensor):
-        self._grad_sharded_tensor.payload = t
+        self._grad_sharded_tensor = t

+    @property
+    def param_is_sharded(self):
+        return self._data_sharded_tensor.is_sharded


 class ShardedParam(object):
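Note the behavioral change in the `data` property: it now returns the `ShardedTensor` wrapper rather than its payload, so callers must go through `.payload` (the adjustment made in `test_shard_param.py` below). A sketch of the new access pattern:

from copy import deepcopy

import torch
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

param = torch.nn.Parameter(torch.randn(4, 4))
param_ref = deepcopy(param)
sparam = ShardedParamV2(param=param, process_group=None)
assert torch.allclose(sparam.data.payload, param_ref.data)    # was: sparam.data
assert not sparam.param_is_sharded    # nothing is sharded until a strategy runs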
tests/__init__.py (new file, 0 → 100644, empty)
tests/test_zero_data_parallel/common.py

@@ -45,16 +45,16 @@ class Net(nn.Module):
 def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
     if loose:
-        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
+        return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
     return torch.allclose(tensor_a, tensor_b)


 def check_grads(model, zero_model, loose=False):
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
         zero_grad = zero_p.grad.clone().to(p.device)
-        assert p.grad.dtype == zero_grad.dtype
-        assert allclose(p.grad, zero_grad, loose=loose)
-        LOGGER.info(torch.sum(p.grad - zero_grad))
+        grad = p.grad.float()
+        assert grad.dtype == zero_grad.dtype
+        assert allclose(grad, zero_grad, loose=loose)


 def check_params(model, zero_model, loose=False):

@@ -71,11 +71,11 @@ def check_grads_padding(model, zero_model, loose=False):
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
-        grad = chunks[rank]
+        grad = chunks[rank].float()
         if zero_grad.size(0) > grad.size(0):
             zero_grad = zero_grad[:grad.size(0)]
         assert grad.dtype == zero_grad.dtype
-        assert allclose(grad, zero_grad, loose=loose)
+        assert allclose(grad, zero_grad, loose=loose), f'{grad} vs {zero_grad}'


 def check_params_padding(model, zero_model, loose=False):
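A rough feel for why the loose absolute tolerance widened to 1e-2: the ZeRO model now runs fp16 end to end while the baseline accumulates grads in fp32, and the accumulated error across a backward pass is larger than a single fp16 round-trip. A sketch of the single round-trip error alone:

import torch

a = torch.randn(1000)
b = a.half().float()          # one round-trip through fp16
print((a - b).abs().max())    # typically ~1e-4, before any accumulation error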
tests/test_zero_data_parallel/test_init_context.py

@@ -7,12 +7,14 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
-from colossalai.zero.init_ctx import ZeroInitContext
-from common import CONFIG
 from colossalai.utils import free_port
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils.tensor_shard_strategy import \
+    TensorShardStrategy
+from tests.components_to_test.registry import non_distributed_component_funcs

+from common import CONFIG, Net


 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

@@ -25,11 +27,11 @@ def run_dist(rank, world_size, port):
                          shard_param=True):
         model = model_builder(checkpoint=True)

     for param in model.parameters():
-        assert hasattr(param, 'ca_attr')
-        assert param.ca_attr.data.dtype == torch.half
-        assert param.ca_attr._data_sharded_tensor.is_sharded
-        assert param.ca_attr.data.device.type == 'cuda'
+        assert hasattr(param, 'col_attr')
+        assert param.col_attr.data.dtype == torch.half
+        assert param.col_attr.data.is_sharded
+        assert param.col_attr.data.payload.device.type == 'cuda'


 @pytest.mark.dist
tests/test_zero_data_parallel/test_shard_model_v2.py

@@ -9,19 +9,21 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
+from colossalai.zero.shard_utils.tensor_shard_strategy import \
+    TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP

-from common import CONFIG, Net, check_grads, check_grads_padding
+from common import CONFIG, check_grads, check_grads_padding


-def run_fwd_bwd(model, x, enable_autocast=False):
+def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
     model.train()
     with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
+        y = model(data)
+        loss = criterion(y, label)
+        loss = loss.float()
     if isinstance(model, ShardedModelV2):
         model.backward(loss)

@@ -31,19 +33,26 @@ def run_fwd_bwd(model, x, enable_autocast=False):
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = Net(checkpoint=True).cuda()
-    zero_model = copy.deepcopy(model)
-    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
-    for _ in range(2):
-        x = torch.rand(2, 5).cuda()
-        run_fwd_bwd(zero_model, x, False)
-        run_fwd_bwd(model, x, False)
-        if dist.get_world_size() > 1:
-            check_grads_padding(model, zero_model)
-        else:
-            check_grads(model, zero_model)
+    test_models = ['repeated_computed_layers', 'resnet18']
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        shard_strategy = TensorShardStrategy()
+        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model = model().half().cuda()
+        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
+        model = DDP(model)
+        for i, (data, label) in enumerate(train_dataloader):
+            if i > 2:
+                break
+            data, label = data.half().cuda(), label.cuda()
+            run_fwd_bwd(model, data, label, criterion, False)
+            run_fwd_bwd(zero_model, data, label, criterion, False)
+            if dist.get_world_size() > 1:
+                check_grads_padding(model, zero_model, loose=True)
+            else:
+                check_grads(model, zero_model, loose=True)


 @pytest.mark.dist
tests/test_zero_data_parallel/test_shard_param.py

@@ -4,18 +4,16 @@
 from copy import deepcopy
 from functools import partial

+import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-import colossalai
-from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from colossalai.zero.shard_utils import TensorShardStrategy
-from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
-from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.utils import free_port
-from colossalai.logging import get_dist_logger, disable_existing_loggers
-from tests.test_zero_data_parallel.common import Net, CONFIG, allclose
+from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
+from tests.test_zero_data_parallel.common import CONFIG, Net, allclose


 def _run_shard_tensor(rank, world_size, port):

@@ -47,7 +45,7 @@ def _run_shard_param_v2(rank, world_size, port):
     param_ref = deepcopy(param)
     sparam = ShardedParamV2(param=param, process_group=None)

-    allclose(sparam.data, param_ref.data)
+    allclose(sparam.data.payload, param_ref.data)

     sparam.remove_torch_payload()
     assert (param.data.numel() == 1)
tests/test_zero_data_parallel/test_sharded_model_with_ctx.py (new file, 0 → 100644)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import CONFIG, check_grads, check_grads_padding


def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(data)
        loss = criterion(y, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_models = ['repeated_computed_layers', 'resnet18']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        shard_strategy = TensorShardStrategy()
        with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
            zero_model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
            zero_model = zero_model()
        model = copy.deepcopy(zero_model)
        zero_model = ShardedModelV2(zero_model, shard_strategy)
        model_state_dict = zero_model.state_dict()
        for n, p in model.named_parameters():
            p.data = model_state_dict[n]
        model = model.half().cuda()
        if dist.get_world_size() > 1:
            model = DDP(model)

        for i, (data, label) in enumerate(train_dataloader):
            if i > 2:
                break
            data, label = data.half().cuda(), label.cuda()
            run_fwd_bwd(model, data, label, criterion, False)
            run_fwd_bwd(zero_model, data, label, criterion, False)
            if dist.get_world_size() > 1:
                check_grads_padding(model, zero_model, loose=True)
            else:
                check_grads(model, zero_model, loose=True)


@pytest.mark.dist
def test_shard_model_v2():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_model_v2()
tests/test_zero_data_parallel/test_sharded_optim_v2.py

@@ -56,7 +56,7 @@ def run_dist(rank, world_size, port):
         check_params(model, zero_model)


 @pytest.mark.dist
+@pytest.mark.skip
 def test_sharded_optim_v2():
     world_size = 2
     run_func = partial(run_dist, world_size=world_size, port=free_port())
tests/test_zero_data_parallel/test_state_dict.py (new file, 0 → 100644)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from copy import deepcopy
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs

from common import CONFIG


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_models = ['repeated_computed_layers', 'resnet18']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
        model = model()
        shard_strategy = TensorShardStrategy()
        model = model.half().cuda()
        zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
        zero_state_dict = zero_model.state_dict()
        for key, val in model.state_dict().items():
            assert torch.equal(val, zero_state_dict[key])


@pytest.mark.dist
def test_zero_state_dict():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_state_dict()