Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
13886716
Commit
13886716
authored
Mar 08, 2022
by
ver217
Committed by
Frank Lee
Mar 11, 2022
Browse files
[zero] Update sharded model v2 using sharded param v2 (#323)
parent
799d105b
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
404 additions
and
203 deletions
+404
-203
colossalai/engine/ophooks/__init__.py
colossalai/engine/ophooks/__init__.py
+9
-11
colossalai/engine/ophooks/zero_hook.py
colossalai/engine/ophooks/zero_hook.py
+58
-0
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+9
-8
colossalai/zero/shard_utils/tensor_shard_strategy.py
colossalai/zero/shard_utils/tensor_shard_strategy.py
+4
-5
colossalai/zero/sharded_model/_zero3_utils.py
colossalai/zero/sharded_model/_zero3_utils.py
+9
-17
colossalai/zero/sharded_model/sharded_model.py
colossalai/zero/sharded_model/sharded_model.py
+64
-60
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+68
-43
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+16
-17
tests/__init__.py
tests/__init__.py
+0
-0
tests/test_zero_data_parallel/common.py
tests/test_zero_data_parallel/common.py
+6
-6
tests/test_zero_data_parallel/test_init_context.py
tests/test_zero_data_parallel/test_init_context.py
+10
-8
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+27
-18
tests/test_zero_data_parallel/test_shard_param.py
tests/test_zero_data_parallel/test_shard_param.py
+7
-9
tests/test_zero_data_parallel/test_sharded_model_with_ctx.py
tests/test_zero_data_parallel/test_sharded_model_with_ctx.py
+73
-0
tests/test_zero_data_parallel/test_sharded_optim_v2.py
tests/test_zero_data_parallel/test_sharded_optim_v2.py
+1
-1
tests/test_zero_data_parallel/test_state_dict.py
tests/test_zero_data_parallel/test_state_dict.py
+43
-0
No files found.
colossalai/engine/ophooks/__init__.py
View file @
13886716
...
@@ -15,8 +15,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
...
@@ -15,8 +15,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
if
type
(
outputs
)
is
tuple
:
if
type
(
outputs
)
is
tuple
:
touched_outputs
=
[]
touched_outputs
=
[]
for
output
in
outputs
:
for
output
in
outputs
:
touched_output
=
_apply_to_tensors_only
(
module
,
functional
,
touched_output
=
_apply_to_tensors_only
(
module
,
functional
,
backward_function
,
output
)
backward_function
,
output
)
touched_outputs
.
append
(
touched_output
)
touched_outputs
.
append
(
touched_output
)
return
tuple
(
touched_outputs
)
return
tuple
(
touched_outputs
)
elif
type
(
outputs
)
is
torch
.
Tensor
:
elif
type
(
outputs
)
is
torch
.
Tensor
:
...
@@ -26,6 +25,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
...
@@ -26,6 +25,7 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
class
PreBackwardFunction
(
torch
.
autograd
.
Function
):
class
PreBackwardFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
module
,
pre_backward_function
,
outputs
):
def
forward
(
ctx
,
module
,
pre_backward_function
,
outputs
):
ctx
.
module
=
module
ctx
.
module
=
module
...
@@ -41,6 +41,7 @@ class PreBackwardFunction(torch.autograd.Function):
...
@@ -41,6 +41,7 @@ class PreBackwardFunction(torch.autograd.Function):
class
PostBackwardFunction
(
torch
.
autograd
.
Function
):
class
PostBackwardFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
module
,
pre_backward_function
,
output
):
def
forward
(
ctx
,
module
,
pre_backward_function
,
output
):
ctx
.
module
=
module
ctx
.
module
=
module
...
@@ -60,9 +61,7 @@ class PostBackwardFunction(torch.autograd.Function):
...
@@ -60,9 +61,7 @@ class PostBackwardFunction(torch.autograd.Function):
return
(
None
,
None
)
+
args
return
(
None
,
None
)
+
args
def
register_ophooks_recursively
(
module
:
torch
.
nn
.
Module
,
def
register_ophooks_recursively
(
module
:
torch
.
nn
.
Module
,
ophook_list
:
List
[
BaseOpHook
]
=
None
,
name
:
str
=
""
):
ophook_list
:
List
[
BaseOpHook
]
=
None
,
name
:
str
=
""
):
r
"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
r
"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert
isinstance
(
module
,
torch
.
nn
.
Module
)
assert
isinstance
(
module
,
torch
.
nn
.
Module
)
has_children
=
False
has_children
=
False
...
@@ -72,8 +71,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
...
@@ -72,8 +71,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
# Early return on modules with no parameters or buffers that
# Early return on modules with no parameters or buffers that
# are not in their children.
# are not in their children.
if
(
len
(
list
(
module
.
named_parameters
(
recurse
=
False
)))
==
0
if
(
len
(
list
(
module
.
named_parameters
(
recurse
=
False
)))
==
0
and
len
(
list
(
module
.
named_buffers
(
recurse
=
False
)))
==
0
):
and
len
(
list
(
module
.
named_buffers
(
recurse
=
False
)))
==
0
):
return
return
# return if the module has not childern.
# return if the module has not childern.
...
@@ -95,22 +93,22 @@ def register_ophooks_recursively(module: torch.nn.Module,
...
@@ -95,22 +93,22 @@ def register_ophooks_recursively(module: torch.nn.Module,
hook
.
post_fwd_exec
(
submodule
,
*
args
)
hook
.
post_fwd_exec
(
submodule
,
*
args
)
def
_pre_backward_module_hook
(
submodule
,
inputs
,
output
):
def
_pre_backward_module_hook
(
submodule
,
inputs
,
output
):
def
_run_before_backward_function
(
submodule
):
def
_run_before_backward_function
(
submodule
):
for
hook
in
ophook_list
:
for
hook
in
ophook_list
:
assert
isinstance
(
submodule
,
torch
.
nn
.
Module
)
assert
isinstance
(
submodule
,
torch
.
nn
.
Module
)
hook
.
pre_bwd_exec
(
submodule
,
inputs
,
output
)
hook
.
pre_bwd_exec
(
submodule
,
inputs
,
output
)
return
_apply_to_tensors_only
(
submodule
,
PreBackwardFunction
,
return
_apply_to_tensors_only
(
submodule
,
PreBackwardFunction
,
_run_before_backward_function
,
output
)
_run_before_backward_function
,
output
)
def
_post_backward_module_hook
(
submodule
,
inputs
):
def
_post_backward_module_hook
(
submodule
,
inputs
):
def
_run_after_backward_function
(
submodule
):
def
_run_after_backward_function
(
submodule
):
for
hook
in
ophook_list
:
for
hook
in
ophook_list
:
assert
isinstance
(
submodule
,
torch
.
nn
.
Module
)
assert
isinstance
(
submodule
,
torch
.
nn
.
Module
)
hook
.
post_bwd_exec
(
submodule
,
inputs
)
hook
.
post_bwd_exec
(
submodule
,
inputs
)
return
_apply_to_tensors_only
(
submodule
,
PostBackwardFunction
,
return
_apply_to_tensors_only
(
submodule
,
PostBackwardFunction
,
_run_after_backward_function
,
inputs
)
_run_after_backward_function
,
inputs
)
module
.
register_forward_pre_hook
(
_pre_forward_module_hook
)
module
.
register_forward_pre_hook
(
_pre_forward_module_hook
)
module
.
register_forward_hook
(
_post_forward_module_hook
)
module
.
register_forward_hook
(
_post_forward_module_hook
)
...
...
colossalai/engine/ophooks/zero_hook.py
0 → 100644
View file @
13886716
import
torch
from
colossalai.registry
import
OPHOOKS
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
._base_ophook
import
BaseOpHook
@
OPHOOKS
.
register_module
class
ZeroHook
(
BaseOpHook
):
"""
A hook to process sharded param for ZeRO method.
"""
def
__init__
(
self
,
shard_strategy
:
BaseShardStrategy
):
super
().
__init__
()
self
.
shard_strategy
=
shard_strategy
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
for
param
in
module
.
parameters
():
assert
hasattr
(
param
,
'col_attr'
)
self
.
shard_strategy
.
gather
([
param
.
col_attr
.
data
])
param
.
data
=
param
.
col_attr
.
data
.
payload
def
post_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
for
param
in
module
.
parameters
():
assert
hasattr
(
param
,
'col_attr'
)
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
data
])
param
.
data
=
torch
.
empty
([],
dtype
=
param
.
col_attr
.
data
.
dtype
,
device
=
param
.
col_attr
.
data
.
payload
.
device
)
def
pre_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
,
output
):
for
param
in
module
.
parameters
():
assert
hasattr
(
param
,
'col_attr'
)
self
.
shard_strategy
.
gather
([
param
.
col_attr
.
data
])
param
.
data
=
param
.
col_attr
.
data
.
payload
# Store local accumulated grad shard
if
param
.
grad
is
not
None
:
if
param
.
col_attr
.
bwd_count
==
0
:
# We haven't stored local accumulated grad yet
assert
param
.
col_attr
.
grad
is
None
param
.
col_attr
.
grad
=
param
.
grad
.
data
param
.
grad
=
None
else
:
# We have stored local accumulated grad
# The grad here must be locally computed full grad in this backward pass
assert
param
.
grad
.
shape
==
param
.
col_attr
.
data
.
origin_shape
param
.
col_attr
.
bwd_count
+=
1
def
post_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
):
for
param
in
module
.
parameters
():
assert
hasattr
(
param
,
'col_attr'
)
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
data
])
param
.
data
=
torch
.
empty
([],
dtype
=
param
.
col_attr
.
data
.
dtype
,
device
=
param
.
col_attr
.
data
.
payload
.
device
)
def
pre_iter
(
self
):
pass
def
post_iter
(
self
):
pass
colossalai/zero/init_ctx/init_context.py
View file @
13886716
import
functools
import
functools
from
colossalai.utils.cuda
import
get_current_device
import
torch
import
torch
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param
import
ShardedParamV2
...
@@ -103,8 +104,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -103,8 +104,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""
"""
if
not
self
.
rm_torch_payload_on_the_fly
:
if
not
self
.
rm_torch_payload_on_the_fly
:
for
param
in
self
.
initialized_param_list
:
for
param
in
self
.
initialized_param_list
:
assert
hasattr
(
param
,
'c
a
_attr'
)
assert
hasattr
(
param
,
'c
ol
_attr'
)
param
.
c
a
_attr
.
remove_torch_payload
()
param
.
c
ol
_attr
.
remove_torch_payload
()
del
self
.
initialized_param_list
del
self
.
initialized_param_list
...
@@ -113,7 +114,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -113,7 +114,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""
"""
for
param
in
module
.
parameters
():
for
param
in
module
.
parameters
():
# avoid adapting a param to ShardedParam twice
# avoid adapting a param to ShardedParam twice
if
hasattr
(
param
,
'c
a
_attr'
):
if
hasattr
(
param
,
'c
ol
_attr'
):
continue
continue
if
self
.
convert_cuda
:
if
self
.
convert_cuda
:
...
@@ -127,11 +128,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -127,11 +128,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if
param
.
grad
is
not
None
:
if
param
.
grad
is
not
None
:
param
.
grad
=
param
.
grad
.
to
(
torch
.
half
).
to
(
target_device
)
param
.
grad
=
param
.
grad
.
to
(
torch
.
half
).
to
(
target_device
)
param
.
c
a
_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
self
.
rm_torch_payload_on_the_fly
)
param
.
c
ol
_attr
=
ShardedParamV2
(
param
,
rm_torch_payload
=
self
.
rm_torch_payload_on_the_fly
)
self
.
initialized_param_list
.
append
(
param
)
self
.
initialized_param_list
.
append
(
param
)
if
self
.
shard_param
:
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
c
a
_attr
.
_data_sharded_tensor
])
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
c
ol
_attr
.
_data_sharded_tensor
])
if
param
.
c
a
_attr
.
grad
and
self
.
shard_grad
:
if
param
.
c
ol
_attr
.
grad
and
self
.
shard_grad
:
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
c
a
_attr
.
_grad_sharded_tensor
])
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
c
ol
_attr
.
_grad_sharded_tensor
])
colossalai/zero/shard_utils/tensor_shard_strategy.py
View file @
13886716
import
torch
import
torch.distributed
as
dist
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch.distributed
as
dist
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.zero.sharded_model._zero3_utils
import
get_shard
from
colossalai.zero.sharded_model._zero3_utils
import
get_shard
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
class
TensorShardStrategy
(
BaseShardStrategy
):
class
TensorShardStrategy
(
BaseShardStrategy
):
...
@@ -38,7 +37,7 @@ class TensorShardStrategy(BaseShardStrategy):
...
@@ -38,7 +37,7 @@ class TensorShardStrategy(BaseShardStrategy):
if
i
==
self
.
local_rank
:
if
i
==
self
.
local_rank
:
buffer_list
.
append
(
t
.
payload
.
cuda
())
buffer_list
.
append
(
t
.
payload
.
cuda
())
else
:
else
:
buffer_list
.
append
(
torch
.
zeros
(
payload_numel
).
cuda
())
buffer_list
.
append
(
torch
.
zeros
(
payload_numel
,
dtype
=
t
.
dtype
).
cuda
())
torch
.
distributed
.
all_gather
(
buffer_list
,
torch
.
distributed
.
all_gather
(
buffer_list
,
buffer_list
[
self
.
local_rank
],
buffer_list
[
self
.
local_rank
],
...
...
colossalai/zero/sharded_model/_zero3_utils.py
View file @
13886716
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Tuple
,
Union
...
@@ -42,27 +41,21 @@ def free_storage(data: torch.Tensor) -> None:
...
@@ -42,27 +41,21 @@ def free_storage(data: torch.Tensor) -> None:
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
alloc_storage
(
data
:
torch
.
Tensor
,
size
:
torch
.
Size
)
->
None
:
def
alloc_storage
(
data
:
torch
.
Tensor
,
size
:
torch
.
Size
)
->
None
:
"""Allocate storage for a tensor."""
"""Allocate storage for a tensor."""
if
data
.
storage
().
size
()
==
size
.
numel
():
# no need to reallocate
if
data
.
storage
().
size
()
==
size
.
numel
():
# no need to reallocate
return
return
assert
data
.
storage
().
size
()
==
0
assert
data
.
storage
().
size
()
==
0
data
.
storage
().
resize_
(
size
.
numel
())
data
.
storage
().
resize_
(
size
.
numel
())
def
cast_trensor_to_fp16
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
cast_tensor_to_fp16
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
tensor
.
dtype
is
torch
.
float32
:
if
torch
.
is_floating_point
(
tensor
)
and
tensor
.
dtype
is
torch
.
float32
:
out
=
tensor
.
half
()
return
tensor
.
half
()
if
tensor
.
is_leaf
:
out
.
requires_grad
=
tensor
.
requires_grad
return
out
return
tensor
return
tensor
def
cast_trensor_to_fp32
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
cast_tensor_to_fp32
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
tensor
.
dtype
is
torch
.
float16
:
if
torch
.
is_floating_point
(
tensor
)
and
tensor
.
dtype
is
torch
.
float16
:
out
=
tensor
.
float
()
return
tensor
.
float
()
if
tensor
.
is_leaf
:
out
.
requires_grad
=
tensor
.
requires_grad
return
out
return
tensor
return
tensor
...
@@ -102,9 +95,8 @@ def assert_in_engine(cond: Any, s: Any) -> None:
...
@@ -102,9 +95,8 @@ def assert_in_engine(cond: Any, s: Any) -> None:
raise
AssertionError
raise
AssertionError
def
replace_state_dict_prefix
(
def
replace_state_dict_prefix
(
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
old_prefix
:
str
,
new_prefix
:
str
old_prefix
:
str
,
new_prefix
:
str
)
->
None
:
)
->
None
:
"""
"""
Replace all keys that match a given old_prefix with a new_prefix (in-place).
Replace all keys that match a given old_prefix with a new_prefix (in-place).
...
...
colossalai/zero/sharded_model/sharded_model.py
View file @
13886716
...
@@ -5,8 +5,7 @@ import os
...
@@ -5,8 +5,7 @@ import os
import
traceback
import
traceback
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
typing
import
(
Any
,
Callable
,
Dict
,
Generator
,
List
,
NamedTuple
,
Optional
,
from
typing
import
(
Any
,
Callable
,
Dict
,
Generator
,
List
,
NamedTuple
,
Optional
,
Set
,
Union
)
Set
,
Union
)
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -15,16 +14,14 @@ from colossalai.context.parallel_mode import ParallelMode
...
@@ -15,16 +14,14 @@ from colossalai.context.parallel_mode import ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
.param_manager
import
Zero3ParameterManager
from
torch.autograd
import
Variable
from
torch.autograd
import
Variable
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
._zero3_utils
import
(
apply_to_tensors
,
assert_in_engine
,
from
._zero3_utils
import
(
apply_to_tensors
,
assert_in_engine
,
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_float_arguments
,
cast_trensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
,
get_shard
,
cast_trensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
,
get_shard
,
replace_state_dict_prefix
)
replace_state_dict_prefix
)
from
.param_manager
import
Zero3ParameterManager
from
.reduce_scatter
import
ReduceScatterBucketer
from
.reduce_scatter
import
ReduceScatterBucketer
# TODO: Remove the toggle-enable_nccl_base_collectives in the future
# TODO: Remove the toggle-enable_nccl_base_collectives in the future
...
@@ -41,11 +38,13 @@ class TrainingState(Enum):
...
@@ -41,11 +38,13 @@ class TrainingState(Enum):
POST_BACKWARD
=
auto
()
POST_BACKWARD
=
auto
()
GATHER_FULL_PARAMS
=
auto
()
GATHER_FULL_PARAMS
=
auto
()
# TODO: Add clip_grad_norm_
# TODO: Add clip_grad_norm_
# TODO: Add gather_full_optim_state_dict and get_shard_from_optim_state_dict
# TODO: Add gather_full_optim_state_dict and get_shard_from_optim_state_dict
class
ShardedModel
(
nn
.
Module
):
class
ShardedModel
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
module
:
nn
.
Module
,
module
:
nn
.
Module
,
process_group
:
Optional
[
ProcessGroup
]
=
None
,
process_group
:
Optional
[
ProcessGroup
]
=
None
,
...
@@ -96,8 +95,10 @@ class ShardedModel(nn.Module):
...
@@ -96,8 +95,10 @@ class ShardedModel(nn.Module):
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
# So we use 1.0 as the default gradient_predivide_factor
# So we use 1.0 as the default gradient_predivide_factor
# However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
# However, if you set gradient_predivide_factor to None
self
.
gradient_predivide_factor
:
float
=
gradient_predivide_factor
if
gradient_predivide_factor
is
not
None
else
\
# we will set gradient_predivide_factor to a value >= 1.0 automatically
self
.
gradient_predivide_factor
:
float
=
gradient_predivide_factor
if
\
gradient_predivide_factor
is
not
None
else
\
get_gradient_predivide_factor
(
self
.
world_size
)
get_gradient_predivide_factor
(
self
.
world_size
)
self
.
gradient_postdivide_factor
:
float
=
self
.
world_size
/
self
.
gradient_predivide_factor
self
.
gradient_postdivide_factor
:
float
=
self
.
world_size
/
self
.
gradient_predivide_factor
...
@@ -111,8 +112,12 @@ class ShardedModel(nn.Module):
...
@@ -111,8 +112,12 @@ class ShardedModel(nn.Module):
self
.
module
=
module
self
.
module
=
module
self
.
param_manager
=
Zero3ParameterManager
(
module
,
process_group
=
self
.
process_group
,
mixed_precision
=
self
.
mixed_precision
,
self
.
param_manager
=
Zero3ParameterManager
(
module
,
flatten_parameters
=
flatten_parameters
,
compute_dtype
=
self
.
compute_dtype
,
compute_device
=
self
.
compute_device
,
process_group
=
self
.
process_group
,
mixed_precision
=
self
.
mixed_precision
,
flatten_parameters
=
flatten_parameters
,
compute_dtype
=
self
.
compute_dtype
,
compute_device
=
self
.
compute_device
,
offload_config
=
offload_config
)
offload_config
=
offload_config
)
self
.
_reset_lazy_init_info
()
self
.
_reset_lazy_init_info
()
...
@@ -145,13 +150,13 @@ class ShardedModel(nn.Module):
...
@@ -145,13 +150,13 @@ class ShardedModel(nn.Module):
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# the conversion).
# the conversion).
if
self
.
_is_root
and
self
.
mixed_precision
:
if
self
.
_is_root
and
self
.
mixed_precision
:
args
,
kwargs
=
cast_float_arguments
(
cast_t
r
ensor_to_fp16
,
*
args
,
**
kwargs
)
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp16
,
*
args
,
**
kwargs
)
# If enabled, convert the input to FP32 if we are in full precision.
# If enabled, convert the input to FP32 if we are in full precision.
# no_grad is not used because the input might be for a non-root instance,
# no_grad is not used because the input might be for a non-root instance,
# which mean autograd needs to go through the conversion.
# which mean autograd needs to go through the conversion.
if
self
.
force_input_to_fp32
and
not
self
.
mixed_precision
:
if
self
.
force_input_to_fp32
and
not
self
.
mixed_precision
:
args
,
kwargs
=
cast_float_arguments
(
cast_t
r
ensor_to_fp32
,
*
args
,
**
kwargs
)
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp32
,
*
args
,
**
kwargs
)
# All-gather full parameters. This will also transfer FP32 parameters to
# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
...
@@ -201,10 +206,9 @@ class ShardedModel(nn.Module):
...
@@ -201,10 +206,9 @@ class ShardedModel(nn.Module):
input_tensor
=
torch
.
ones
(
1
).
to
(
self
.
compute_device
)
input_tensor
=
torch
.
ones
(
1
).
to
(
self
.
compute_device
)
output
=
list
(
torch
.
zeros
(
self
.
world_size
).
to
(
self
.
compute_device
).
chunk
(
self
.
world_size
))
output
=
list
(
torch
.
zeros
(
self
.
world_size
).
to
(
self
.
compute_device
).
chunk
(
self
.
world_size
))
dist
.
all_gather
(
output
,
input_tensor
,
group
=
self
.
process_group
)
dist
.
all_gather
(
output
,
input_tensor
,
group
=
self
.
process_group
)
assert
torch
.
cat
(
output
).
sum
()
==
float
(
self
.
world_size
),
(
assert
torch
.
cat
(
output
).
sum
()
==
float
(
f
"found
{
torch
.
cat
(
output
).
sum
()
}
devices in process group but "
self
.
world_size
),
(
f
"found
{
torch
.
cat
(
output
).
sum
()
}
devices in process group but "
f
"world_size=
{
self
.
world_size
}
. Check torch.cuda.set_device is called properly"
f
"world_size=
{
self
.
world_size
}
. Check torch.cuda.set_device is called properly"
)
)
def
_reset_lazy_init_info
(
self
)
->
None
:
def
_reset_lazy_init_info
(
self
)
->
None
:
self
.
_is_root
:
Optional
[
bool
]
=
None
self
.
_is_root
:
Optional
[
bool
]
=
None
...
@@ -277,9 +281,10 @@ class ShardedModel(nn.Module):
...
@@ -277,9 +281,10 @@ class ShardedModel(nn.Module):
# if child instance in its own (smaller) world, that was probably an attempt to avoid OOM.
# if child instance in its own (smaller) world, that was probably an attempt to avoid OOM.
# Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
# Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
m
.
no_broadcast_optim_state
=
m
.
no_broadcast_optim_state
or
(
m
.
no_broadcast_optim_state
=
m
.
no_broadcast_optim_state
or
\
(
m
.
world_size
==
1
)
and
(
m
.
world_size
<
self
.
world_size
)
and
(
m
.
process_group
!=
self
.
process_group
)
((
m
.
world_size
==
1
)
)
and
(
m
.
world_size
<
self
.
world_size
)
and
(
m
.
process_group
!=
self
.
process_group
))
def
_setup_streams
(
self
)
->
None
:
def
_setup_streams
(
self
)
->
None
:
"""Create streams to overlap data transfer and computation."""
"""Create streams to overlap data transfer and computation."""
...
@@ -330,9 +335,10 @@ class ShardedModel(nn.Module):
...
@@ -330,9 +335,10 @@ class ShardedModel(nn.Module):
else
:
else
:
self
.
_streams
[
"all_gather"
].
wait_stream
(
torch
.
cuda
.
current_stream
())
self
.
_streams
[
"all_gather"
].
wait_stream
(
torch
.
cuda
.
current_stream
())
def
_cast_buffers
(
def
_cast_buffers
(
self
,
self
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
memo
:
Optional
[
Set
]
=
None
device
:
Optional
[
torch
.
device
]
=
None
,
)
->
None
:
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
memo
:
Optional
[
Set
]
=
None
)
->
None
:
"""Move all buffers to the given *device* and *dtype*.
"""Move all buffers to the given *device* and *dtype*.
If *device* or *dtype* are not given, then they will default to
If *device* or *dtype* are not given, then they will default to
...
@@ -398,7 +404,7 @@ class ShardedModel(nn.Module):
...
@@ -398,7 +404,7 @@ class ShardedModel(nn.Module):
outputs: new outputs with hooks registered if they requires gradient.
outputs: new outputs with hooks registered if they requires gradient.
"""
"""
if
not
torch
.
is_grad_enabled
():
if
not
torch
.
is_grad_enabled
():
return
outputs
# don't register hooks if grad isn't enabled
return
outputs
# don't register hooks if grad isn't enabled
if
self
.
_is_root
:
if
self
.
_is_root
:
# This actually means that only root instance has
# This actually means that only root instance has
...
@@ -523,7 +529,7 @@ class ShardedModel(nn.Module):
...
@@ -523,7 +529,7 @@ class ShardedModel(nn.Module):
a new hook, which is needed for a new forward pass.
a new hook, which is needed for a new forward pass.
"""
"""
if
not
torch
.
is_grad_enabled
():
if
not
torch
.
is_grad_enabled
():
return
# don't register grad hooks if grad isn't enabled
return
# don't register grad hooks if grad isn't enabled
for
p
in
self
.
params
:
for
p
in
self
.
params
:
if
p
.
requires_grad
:
if
p
.
requires_grad
:
if
hasattr
(
p
,
"zero_shard_bwd_hook"
):
if
hasattr
(
p
,
"zero_shard_bwd_hook"
):
...
@@ -612,7 +618,8 @@ class ShardedModel(nn.Module):
...
@@ -612,7 +618,8 @@ class ShardedModel(nn.Module):
if
param
.
zero_is_sharded
:
if
param
.
zero_is_sharded
:
assert
self
.
_reducer
is
not
None
assert
self
.
_reducer
is
not
None
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param.zero_saved_grad_shard. If this ShardedModel module was called multiple times it's possible that multiple
# param.zero_saved_grad_shard. If this ShardedModel module was called multiple times
# it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
# matter, neglecting rounding.
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
...
@@ -628,9 +635,9 @@ class ShardedModel(nn.Module):
...
@@ -628,9 +635,9 @@ class ShardedModel(nn.Module):
# unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
# unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
callback_fn
=
functools
.
partial
(
self
.
_reduce_scatter_callback
,
param
)
callback_fn
=
functools
.
partial
(
self
.
_reduce_scatter_callback
,
param
)
grad_chunks
=
chunk_and_pad
(
orig_grad_data
,
self
.
reduce_scatter_process_group
.
size
())
grad_chunks
=
chunk_and_pad
(
orig_grad_data
,
self
.
reduce_scatter_process_group
.
size
())
self
.
_reducer
.
reduce_scatter_async
(
self
.
_reducer
.
reduce_scatter_async
(
grad_chunks
,
grad_chunks
,
group
=
self
.
reduce_scatter_process_group
,
callback_fn
=
callback_fn
group
=
self
.
reduce_scatter_process_group
,
)
callback_fn
=
callback_fn
)
else
:
else
:
# Currently the only way for _is_sharded to be False is if
# Currently the only way for _is_sharded to be False is if
# world_size == 1. This could be relaxed in the future, in which
# world_size == 1. This could be relaxed in the future, in which
...
@@ -667,8 +674,9 @@ class ShardedModel(nn.Module):
...
@@ -667,8 +674,9 @@ class ShardedModel(nn.Module):
param
.
zero_saved_grad_shard
=
reduced_grad
.
data
param
.
zero_saved_grad_shard
=
reduced_grad
.
data
else
:
else
:
assert
(
assert
(
param
.
zero_saved_grad_shard
.
shape
==
reduced_grad
.
shape
param
.
zero_saved_grad_shard
.
shape
==
reduced_grad
.
shape
),
f
"
{
param
.
zero_saved_grad_shard
.
shape
}
\
),
f
"
{
param
.
zero_saved_grad_shard
.
shape
}
vs
{
reduced_grad
.
shape
}
"
vs
{
reduced_grad
.
shape
}
"
param
.
zero_saved_grad_shard
.
data
+=
reduced_grad
.
data
param
.
zero_saved_grad_shard
.
data
+=
reduced_grad
.
data
reduced_grad
=
param
.
zero_saved_grad_shard
.
data
reduced_grad
=
param
.
zero_saved_grad_shard
.
data
else
:
else
:
...
@@ -717,7 +725,7 @@ class ShardedModel(nn.Module):
...
@@ -717,7 +725,7 @@ class ShardedModel(nn.Module):
# Flush any unreduced buckets in the post_backward stream.
# Flush any unreduced buckets in the post_backward stream.
with
torch
.
cuda
.
stream
(
self
.
_streams
[
"post_backward"
]):
with
torch
.
cuda
.
stream
(
self
.
_streams
[
"post_backward"
]):
assert_in_engine
(
self
.
_reducer
is
not
None
,
"FinalBackwardHook: reducer is None"
)
assert_in_engine
(
self
.
_reducer
is
not
None
,
"FinalBackwardHook: reducer is None"
)
assert
self
.
_reducer
is
not
None
# make mypy happy
assert
self
.
_reducer
is
not
None
# make mypy happy
self
.
_reducer
.
flush
()
self
.
_reducer
.
flush
()
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
_streams
[
"post_backward"
])
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
_streams
[
"post_backward"
])
if
self
.
_cpu_offload
:
if
self
.
_cpu_offload
:
...
@@ -753,7 +761,8 @@ class ShardedModel(nn.Module):
...
@@ -753,7 +761,8 @@ class ShardedModel(nn.Module):
elif
hasattr
(
p
,
"zero_saved_grad_shard"
):
elif
hasattr
(
p
,
"zero_saved_grad_shard"
):
assert_in_engine
(
assert_in_engine
(
p
.
device
==
p
.
zero_saved_grad_shard
.
device
,
p
.
device
==
p
.
zero_saved_grad_shard
.
device
,
f
"FinalBackwardHook: incorrect saved_grad_shard device
{
p
.
device
}
vs
{
p
.
zero_saved_grad_shard
.
device
}
"
,
f
"FinalBackwardHook: incorrect saved_grad_shard device
\
{
p
.
device
}
vs
{
p
.
zero_saved_grad_shard
.
device
}
"
,
)
)
p
.
grad
=
p
.
zero_saved_grad_shard
p
.
grad
=
p
.
zero_saved_grad_shard
elif
hasattr
(
p
,
'zero_saved_grad'
):
elif
hasattr
(
p
,
'zero_saved_grad'
):
...
@@ -765,7 +774,7 @@ class ShardedModel(nn.Module):
...
@@ -765,7 +774,7 @@ class ShardedModel(nn.Module):
delattr
(
p
,
"zero_saved_grad"
)
delattr
(
p
,
"zero_saved_grad"
)
# Update root and nested ShardedModel's hooks and flags.
# Update root and nested ShardedModel's hooks and flags.
for
m
in
self
.
modules
():
# includes self
for
m
in
self
.
modules
():
# includes self
if
isinstance
(
m
,
ShardedModel
):
if
isinstance
(
m
,
ShardedModel
):
_finalize_parameters
(
m
)
_finalize_parameters
(
m
)
m
.
_pre_backward_hook_has_run
=
False
m
.
_pre_backward_hook_has_run
=
False
...
@@ -796,7 +805,7 @@ class ShardedModel(nn.Module):
...
@@ -796,7 +805,7 @@ class ShardedModel(nn.Module):
self
.
_output_pre_backward_hook_registered
is
not
None
,
self
.
_output_pre_backward_hook_registered
is
not
None
,
"FinalBackwardHook: self._output_pre_backward_hook_registered should not be None"
,
"FinalBackwardHook: self._output_pre_backward_hook_registered should not be None"
,
)
)
assert
self
.
_output_pre_backward_hook_registered
is
not
None
# make mypy happy
assert
self
.
_output_pre_backward_hook_registered
is
not
None
# make mypy happy
self
.
_output_pre_backward_hook_registered
.
clear
()
self
.
_output_pre_backward_hook_registered
.
clear
()
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
...
@@ -908,9 +917,9 @@ class ShardedModel(nn.Module):
...
@@ -908,9 +917,9 @@ class ShardedModel(nn.Module):
state
[
"is_sharded"
]
=
[
p
.
zero_is_sharded
for
p
in
self
.
params
]
state
[
"is_sharded"
]
=
[
p
.
zero_is_sharded
for
p
in
self
.
params
]
state
[
"orig_sizes"
]
=
[
p
.
zero_orig_size
for
p
in
self
.
params
]
state
[
"orig_sizes"
]
=
[
p
.
zero_orig_size
for
p
in
self
.
params
]
if
state
[
"process_group"
]
is
not
None
:
if
state
[
"process_group"
]
is
not
None
:
state
[
"process_group"
]
=
"MISSING"
# process_group isn't pickleable
state
[
"process_group"
]
=
"MISSING"
# process_group isn't pickleable
if
state
[
"process_group_reduce_scatter"
]
is
not
None
:
if
state
[
"process_group_reduce_scatter"
]
is
not
None
:
state
[
"process_group_reduce_scatter"
]
=
"MISSING"
# process_group_reduce_scatter isn't pickleable
state
[
"process_group_reduce_scatter"
]
=
"MISSING"
# process_group_reduce_scatter isn't pickleable
self
.
_reset_lazy_init_info
()
self
.
_reset_lazy_init_info
()
return
state
return
state
...
@@ -920,7 +929,7 @@ class ShardedModel(nn.Module):
...
@@ -920,7 +929,7 @@ class ShardedModel(nn.Module):
def
fixup
(
p
:
Parameter
,
is_sharded
:
bool
,
size
:
torch
.
Size
)
->
Parameter
:
def
fixup
(
p
:
Parameter
,
is_sharded
:
bool
,
size
:
torch
.
Size
)
->
Parameter
:
assert
isinstance
(
p
,
Parameter
)
assert
isinstance
(
p
,
Parameter
)
p
.
data
=
p
.
data
.
clone
()
# move tensors out of shared memory
p
.
data
=
p
.
data
.
clone
()
# move tensors out of shared memory
p
.
zero_is_sharded
=
is_sharded
p
.
zero_is_sharded
=
is_sharded
p
.
zero_orig_size
=
size
p
.
zero_orig_size
=
size
return
p
return
p
...
@@ -958,7 +967,7 @@ class ShardedModel(nn.Module):
...
@@ -958,7 +967,7 @@ class ShardedModel(nn.Module):
# This instance may wrap other ShardedModel instances and we
# This instance may wrap other ShardedModel instances and we
# need to set all of them to accumulate gradients.
# need to set all of them to accumulate gradients.
old_flags
=
[]
old_flags
=
[]
for
m
in
self
.
modules
():
# includes self
for
m
in
self
.
modules
():
# includes self
if
isinstance
(
m
,
ShardedModel
):
if
isinstance
(
m
,
ShardedModel
):
old_flags
.
append
((
m
,
m
.
_require_backward_grad_sync
))
old_flags
.
append
((
m
,
m
.
_require_backward_grad_sync
))
m
.
_require_backward_grad_sync
=
False
m
.
_require_backward_grad_sync
=
False
...
@@ -986,22 +995,18 @@ class ShardedModel(nn.Module):
...
@@ -986,22 +995,18 @@ class ShardedModel(nn.Module):
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
repr
=
(
repr
=
(
f
"world_size=
{
self
.
world_size
}
, "
f
"world_size=
{
self
.
world_size
}
, "
f
"mixed_precision=
{
self
.
mixed_precision
}
, "
)
f
"mixed_precision=
{
self
.
mixed_precision
}
, "
)
if
self
.
verbose
:
if
self
.
verbose
:
repr
=
(
repr
=
(
f
"rank=
{
self
.
rank
}
, "
+
repr
+
f
"reshard_after_forward=
{
self
.
reshard_after_forward
}
, "
f
"rank=
{
self
.
rank
}
, "
+
repr
+
f
"reshard_after_forward=
{
self
.
reshard_after_forward
}
, "
f
"compute_dtype=
{
self
.
compute_dtype
}
, "
f
"compute_dtype=
{
self
.
compute_dtype
}
, "
f
"buffer_dtype=
{
self
.
buffer_dtype
}
, "
f
"buffer_dtype=
{
self
.
buffer_dtype
}
, "
f
"fp32_reduce_scatter=
{
self
.
fp32_reduce_scatter
}
, "
f
"fp32_reduce_scatter=
{
self
.
fp32_reduce_scatter
}
, "
f
"compute_device=
{
self
.
compute_device
}
"
f
"compute_device=
{
self
.
compute_device
}
"
f
"reduce_scatter_bucket_size_mb=
{
self
.
reduce_scatter_bucket_size_mb
}
, "
f
"reduce_scatter_bucket_size_mb=
{
self
.
reduce_scatter_bucket_size_mb
}
, "
f
"clear_autocast_cache=
{
self
.
clear_autocast_cache
}
"
f
"clear_autocast_cache=
{
self
.
clear_autocast_cache
}
"
f
"force_input_to_fp32=
{
self
.
force_input_to_fp32
}
"
f
"force_input_to_fp32=
{
self
.
force_input_to_fp32
}
"
f
"offload_config=
{
self
.
offload_config
}
"
)
f
"offload_config=
{
self
.
offload_config
}
"
)
return
repr
return
repr
def
state_dict
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
):
def
state_dict
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
):
...
@@ -1039,9 +1044,9 @@ class ShardedModel(nn.Module):
...
@@ -1039,9 +1044,9 @@ class ShardedModel(nn.Module):
maybe_cast_buffers
()
maybe_cast_buffers
()
return
state_dict
return
state_dict
def
load_state_dict
(
def
load_state_dict
(
self
,
self
,
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
strict
:
bool
=
True
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
)
->
NamedTuple
:
strict
:
bool
=
True
)
->
NamedTuple
:
"""
"""
Load a whole (unsharded) state_dict.
Load a whole (unsharded) state_dict.
...
@@ -1094,7 +1099,6 @@ def _post_state_dict_hook(
...
@@ -1094,7 +1099,6 @@ def _post_state_dict_hook(
return
state_dict
return
state_dict
def
_pre_load_state_dict_hook
(
def
_pre_load_state_dict_hook
(
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
prefix
:
str
,
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
prefix
:
str
,
*
args
:
Any
*
args
:
Any
)
->
None
:
)
->
None
:
replace_state_dict_prefix
(
state_dict
,
prefix
,
prefix
+
"_zero3_module."
)
replace_state_dict_prefix
(
state_dict
,
prefix
,
prefix
+
"_zero3_module."
)
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
13886716
import
functools
import
functools
from
collections
import
OrderedDict
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
torch
import
torch
...
@@ -6,32 +7,32 @@ import torch.distributed as dist
...
@@ -6,32 +7,32 @@ import torch.distributed as dist
import
torch.nn
as
nn
import
torch.nn
as
nn
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.engine.ophooks
import
(
ShardGradHook
,
ShardParamHook
,
register_ophooks_recursively
)
from
colossalai.engine.ophooks
import
register_ophooks_recursively
from
colossalai.engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
from
colossalai.zero.sharded_model.sharded_grad
import
ShardedGradient
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param
import
ShardedParam
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
._zero3_utils
import
chunk_and_pad
,
get_gradient_predivide_factor
from
._zero3_utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
get_gradient_predivide_factor
)
class
ShardedModelV2
(
nn
.
Module
):
class
ShardedModelV2
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
module
:
nn
.
Module
,
module
:
nn
.
Module
,
shard_strategy
:
BaseShardStrategy
,
process_group
:
Optional
[
ProcessGroup
]
=
None
,
process_group
:
Optional
[
ProcessGroup
]
=
None
,
reduce_scatter_process_group
:
Optional
[
ProcessGroup
]
=
None
,
reduce_scatter_process_group
:
Optional
[
ProcessGroup
]
=
None
,
reduce_scatter_bucket_size_mb
:
int
=
25
,
reduce_scatter_bucket_size_mb
:
int
=
25
,
reshard_after_forward
:
bool
=
True
,
fp32_reduce_scatter
:
bool
=
False
,
mixed_precision
:
bool
=
False
,
offload_config
:
Optional
[
dict
]
=
None
,
fp32_reduce_scatter
:
bool
=
False
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
offload_config
:
Optional
[
dict
]
=
None
,
shard_param
:
bool
=
True
):
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
):
r
"""
r
"""
A demo to reconfigure zero1 shared_model.
A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States.
Currently do not consider the Optimizer States.
...
@@ -44,22 +45,24 @@ class ShardedModelV2(nn.Module):
...
@@ -44,22 +45,24 @@ class ShardedModelV2(nn.Module):
self
.
world_size
=
dist
.
get_world_size
(
self
.
process_group
)
self
.
world_size
=
dist
.
get_world_size
(
self
.
process_group
)
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
#
The
module
has to be placed on GPU
#
Cast
module
to fp16 and cuda, in case user didn't use ZeroInitContext
self
.
module
=
module
.
cuda
()
self
.
module
=
module
.
half
().
cuda
()
# Shard the parameters at first
self
.
shard_strategy
=
shard_strategy
for
_
,
param
in
self
.
module
.
named_parameters
():
self
.
shard_param
=
shard_param
param
.
ca_attr
=
ShardedParam
(
param
)
param
.
ca_attr
.
shard
()
# In case user didn't use ZeroInitContext
param
.
_sharded_grad
=
ShardedGradient
(
param
,
self
,
offload_config
)
for
param
in
self
.
module
.
parameters
():
if
not
hasattr
(
param
,
'col_attr'
):
param
.
col_attr
=
ShardedParamV2
(
param
,
process_group
)
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
data
])
# Register hooks
# Register hooks
register_ophooks_recursively
(
self
.
module
,
[
ShardParamHook
(),
ShardGradHook
(
)])
register_ophooks_recursively
(
self
.
module
,
[
ZeroHook
(
self
.
shard_strategy
)])
self
.
param_hook_mgr
=
BaseParamHookMgr
(
list
(
self
.
module
.
parameters
()))
self
.
param_hook_mgr
=
BaseParamHookMgr
(
list
(
self
.
module
.
parameters
()))
self
.
param_hook_mgr
.
register_backward_hooks
(
self
.
_grad_post_backward_hook
)
self
.
param_hook_mgr
.
register_backward_hooks
(
self
.
_grad_post_backward_hook
)
self
.
reshard_after_forward
=
reshard_after_forward
self
.
mixed_precision
=
mixed_precision
self
.
fp32_reduce_scatter
=
fp32_reduce_scatter
self
.
fp32_reduce_scatter
=
fp32_reduce_scatter
self
.
_cpu_offload
:
bool
=
offload_config
.
get
(
'device'
,
None
)
==
'cpu'
if
offload_config
else
False
self
.
_cpu_offload
:
bool
=
offload_config
.
get
(
'device'
,
None
)
==
'cpu'
if
offload_config
else
False
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
...
@@ -76,6 +79,7 @@ class ShardedModelV2(nn.Module):
...
@@ -76,6 +79,7 @@ class ShardedModelV2(nn.Module):
self
.
_require_backward_grad_sync
:
bool
=
True
self
.
_require_backward_grad_sync
:
bool
=
True
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp16
,
*
args
,
**
kwargs
)
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
return
outputs
return
outputs
...
@@ -99,6 +103,7 @@ class ShardedModelV2(nn.Module):
...
@@ -99,6 +103,7 @@ class ShardedModelV2(nn.Module):
torch
.
cuda
.
current_stream
().
synchronize
()
torch
.
cuda
.
current_stream
().
synchronize
()
self
.
reducer
.
free
()
self
.
reducer
.
free
()
for
p
in
self
.
module
.
parameters
():
for
p
in
self
.
module
.
parameters
():
p
.
col_attr
.
bwd_count
=
0
if
not
p
.
requires_grad
:
if
not
p
.
requires_grad
:
continue
continue
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
...
@@ -107,11 +112,14 @@ class ShardedModelV2(nn.Module):
...
@@ -107,11 +112,14 @@ class ShardedModelV2(nn.Module):
# sync passes, if desired.
# sync passes, if desired.
if
not
self
.
_require_backward_grad_sync
:
if
not
self
.
_require_backward_grad_sync
:
continue
continue
p
.
_sharded_grad
.
write_back
()
# Write grad back to p.grad and set p.col_attr.grad to None
p
.
grad
.
data
=
p
.
col_attr
.
grad
p
.
col_attr
.
grad
=
None
# In case some post bwd hook is not fired
# In case some post bwd hook is not fired
for
p
in
self
.
module
.
parameters
():
if
self
.
shard_param
:
if
not
p
.
ca_attr
.
is_sharded
:
for
p
in
self
.
module
.
parameters
():
p
.
ca_attr
.
shard
()
if
not
p
.
col_attr
.
param_is_sharded
:
self
.
shard_strategy
.
shard
([
p
.
col_attr
.
data
])
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_grad_post_backward_hook
(
self
,
param
:
Parameter
,
grad
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
def
_grad_post_backward_hook
(
self
,
param
:
Parameter
,
grad
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
...
@@ -119,7 +127,7 @@ class ShardedModelV2(nn.Module):
...
@@ -119,7 +127,7 @@ class ShardedModelV2(nn.Module):
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will save
full gradient for the local batch. The reduce-scatter op will save
a single shard of the summed gradient across all
a single shard of the summed gradient across all
GPUs to param.
_sharded_
grad. This shard will align with the current GPU rank. For example::
GPUs to param.
col_attr.
grad. This shard will align with the current GPU rank. For example::
before reduce_scatter:
before reduce_scatter:
param.grad (GPU #0): [1, 2, 3, 4]
param.grad (GPU #0): [1, 2, 3, 4]
...
@@ -131,7 +139,7 @@ class ShardedModelV2(nn.Module):
...
@@ -131,7 +139,7 @@ class ShardedModelV2(nn.Module):
The local GPU's ``optim.step`` is responsible for updating a single
The local GPU's ``optim.step`` is responsible for updating a single
shard of params, also corresponding to the current GPU's rank. This
shard of params, also corresponding to the current GPU's rank. This
alignment is created by `param.
_sharded_
grad`, which ensures that
alignment is created by `param.
col_attr.
grad`, which ensures that
the local optimizer only sees the relevant parameter shard.
the local optimizer only sees the relevant parameter shard.
"""
"""
if
grad
is
None
:
if
grad
is
None
:
...
@@ -142,7 +150,7 @@ class ShardedModelV2(nn.Module):
...
@@ -142,7 +150,7 @@ class ShardedModelV2(nn.Module):
self
.
comm_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
self
.
comm_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
self
.
comm_stream
):
with
torch
.
cuda
.
stream
(
self
.
comm_stream
):
new_grad
=
grad
.
clone
()
new_grad
=
grad
.
clone
()
if
self
.
mixed_precision
and
self
.
fp32_reduce_scatter
:
if
self
.
fp32_reduce_scatter
:
new_grad
.
data
=
new_grad
.
data
.
to
(
param
.
dtype
)
new_grad
.
data
=
new_grad
.
data
.
to
(
param
.
dtype
)
if
self
.
gradient_predivide_factor
>
1.0
:
if
self
.
gradient_predivide_factor
>
1.0
:
# Average grad by world_size for consistency with PyTorch DDP.
# Average grad by world_size for consistency with PyTorch DDP.
...
@@ -161,13 +169,30 @@ class ShardedModelV2(nn.Module):
...
@@ -161,13 +169,30 @@ class ShardedModelV2(nn.Module):
if
self
.
gradient_postdivide_factor
>
1
:
if
self
.
gradient_postdivide_factor
>
1
:
# Average grad by world_size for consistency with PyTorch DDP.
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad
.
data
.
div_
(
self
.
gradient_postdivide_factor
)
reduced_grad
.
data
.
div_
(
self
.
gradient_postdivide_factor
)
# Cast grad to param's dtype (typically FP32). Note: we do this
# before the cpu offload step so that this entire hook remains
# Make sure we store fp32 grad
# non-blocking. The downside is a bit more D2H transfer in that case.
reduced_grad
.
data
=
cast_tensor_to_fp32
(
reduced_grad
.
data
)
if
self
.
mixed_precision
:
orig_param_grad_data
=
reduced_grad
.
data
# Maybe offload
reduced_grad
.
data
=
reduced_grad
.
data
.
to
(
dtype
=
param
.
ca_attr
.
origin_dtype
)
if
self
.
_cpu_offload
:
# Don't let this memory get reused until after the transfer.
reduced_grad
.
data
=
reduced_grad
.
data
.
cpu
()
orig_param_grad_data
.
record_stream
(
torch
.
cuda
.
current_stream
())
if
param
.
col_attr
.
grad
is
None
:
param
.
_sharded_grad
.
reduce_scatter_callback
(
reduced_grad
)
param
.
col_attr
.
grad
=
reduced_grad
.
data
else
:
param
.
col_attr
.
grad
.
add_
(
reduced_grad
.
data
)
def
state_dict
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
)
->
'OrderedDict[str, torch.Tensor]'
:
self
.
shard_strategy
.
gather
([
p
.
col_attr
.
data
for
p
in
self
.
module
.
parameters
()])
prev_params
=
{}
for
p
in
self
.
module
.
parameters
():
prev_params
[
p
]
=
p
.
data
p
.
data
=
p
.
col_attr
.
data
.
payload
gathered_state_dict
=
self
.
module
.
state_dict
(
destination
,
prefix
,
keep_vars
)
self
.
shard_strategy
.
shard
([
p
.
col_attr
.
data
for
p
in
self
.
module
.
parameters
()])
for
p
in
self
.
module
.
parameters
():
p
.
data
=
prev_params
[
p
]
return
gathered_state_dict
def
load_state_dict
(
self
,
state_dict
:
'OrderedDict[str, torch.Tensor]'
,
strict
:
bool
=
True
):
raise
NotImplementedError
colossalai/zero/sharded_param/sharded_param.py
View file @
13886716
from
typing
import
Optional
,
Tuple
,
Union
import
numpy
import
numpy
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -5,7 +7,6 @@ from colossalai.context.parallel_mode import ParallelMode
...
@@ -5,7 +7,6 @@ from colossalai.context.parallel_mode import ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.zero.sharded_model._zero3_utils
import
get_shard
from
colossalai.zero.sharded_model._zero3_utils
import
get_shard
from
colossalai.zero.sharded_param
import
ShardedTensor
from
colossalai.zero.sharded_param
import
ShardedTensor
from
typing
import
Union
,
Tuple
,
Optional
class
ShardedParamV2
(
object
):
class
ShardedParamV2
(
object
):
...
@@ -14,12 +15,8 @@ class ShardedParamV2(object):
...
@@ -14,12 +15,8 @@ class ShardedParamV2(object):
param
:
torch
.
nn
.
Parameter
,
param
:
torch
.
nn
.
Parameter
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
,
rm_torch_payload
=
False
)
->
None
:
rm_torch_payload
=
False
)
->
None
:
self
.
_data_sharded_tensor
=
ShardedTensor
(
param
.
data
,
process_group
)
self
.
_data_sharded_tensor
:
ShardedTensor
=
ShardedTensor
(
param
.
data
,
process_group
)
if
param
.
requires_grad
and
param
.
grad
is
not
None
:
self
.
_grad_sharded_tensor
:
Optional
[
torch
.
Tensor
]
=
None
self
.
_grad_sharded_tensor
=
ShardedTensor
(
param
.
grad
,
process_group
)
param
.
grad
=
None
else
:
self
.
_grad_sharded_tensor
=
None
# make sure the shared param is the only owner of payload
# make sure the shared param is the only owner of payload
# The param.data maybe used to init the other part of the model.
# The param.data maybe used to init the other part of the model.
...
@@ -30,27 +27,29 @@ class ShardedParamV2(object):
...
@@ -30,27 +27,29 @@ class ShardedParamV2(object):
if
rm_torch_payload
:
if
rm_torch_payload
:
self
.
remove_torch_payload
()
self
.
remove_torch_payload
()
# Backward count for handle local grad accumulation
# This value will increment by 1 in every pre-bwd hook
# And will be reset to 0 in every final-bwd hook
self
.
bwd_count
=
0
def
remove_torch_payload
(
self
):
def
remove_torch_payload
(
self
):
self
.
param
.
data
=
torch
.
empty
([],
dtype
=
self
.
param
.
dtype
,
device
=
self
.
param
.
device
)
self
.
param
.
data
=
torch
.
empty
([],
dtype
=
self
.
param
.
dtype
,
device
=
self
.
param
.
device
)
@
property
@
property
def
data
(
self
):
def
data
(
self
):
return
self
.
_data_sharded_tensor
.
payload
return
self
.
_data_sharded_tensor
@
data
.
setter
def
data
(
self
,
t
:
torch
.
Tensor
):
self
.
_data_sharded_tensor
.
payload
=
t
@
property
@
property
def
grad
(
self
):
def
grad
(
self
):
if
self
.
_grad_sharded_tensor
:
return
self
.
_grad_sharded_tensor
return
self
.
_grad_sharded_tensor
.
payload
else
:
return
None
@
grad
.
setter
@
grad
.
setter
def
grad
(
self
,
t
:
torch
.
Tensor
):
def
grad
(
self
,
t
:
torch
.
Tensor
):
self
.
_grad_sharded_tensor
.
payload
=
t
self
.
_grad_sharded_tensor
=
t
@
property
def
param_is_sharded
(
self
):
return
self
.
_data_sharded_tensor
.
is_sharded
class
ShardedParam
(
object
):
class
ShardedParam
(
object
):
...
...
tests/__init__.py
0 → 100644
View file @
13886716
tests/test_zero_data_parallel/common.py
View file @
13886716
...
@@ -45,16 +45,16 @@ class Net(nn.Module):
...
@@ -45,16 +45,16 @@ class Net(nn.Module):
def
allclose
(
tensor_a
:
torch
.
Tensor
,
tensor_b
:
torch
.
Tensor
,
loose
=
False
)
->
bool
:
def
allclose
(
tensor_a
:
torch
.
Tensor
,
tensor_b
:
torch
.
Tensor
,
loose
=
False
)
->
bool
:
if
loose
:
if
loose
:
return
torch
.
allclose
(
tensor_a
,
tensor_b
,
atol
=
1e-
3
,
rtol
=
1e-3
)
return
torch
.
allclose
(
tensor_a
,
tensor_b
,
atol
=
1e-
2
,
rtol
=
1e-3
)
return
torch
.
allclose
(
tensor_a
,
tensor_b
)
return
torch
.
allclose
(
tensor_a
,
tensor_b
)
def
check_grads
(
model
,
zero_model
,
loose
=
False
):
def
check_grads
(
model
,
zero_model
,
loose
=
False
):
for
p
,
zero_p
in
zip
(
model
.
parameters
(),
zero_model
.
parameters
()):
for
p
,
zero_p
in
zip
(
model
.
parameters
(),
zero_model
.
parameters
()):
zero_grad
=
zero_p
.
grad
.
clone
().
to
(
p
.
device
)
zero_grad
=
zero_p
.
grad
.
clone
().
to
(
p
.
device
)
assert
p
.
grad
.
dtype
==
zero_grad
.
dtype
grad
=
p
.
grad
.
float
()
assert
allclose
(
p
.
grad
,
zero_grad
,
loose
=
loose
)
assert
grad
.
dtype
==
zero_grad
.
dtype
LOGGER
.
info
(
torch
.
sum
(
p
.
grad
-
zero_grad
)
)
assert
allclose
(
grad
,
zero_grad
,
loose
=
loose
)
def
check_params
(
model
,
zero_model
,
loose
=
False
):
def
check_params
(
model
,
zero_model
,
loose
=
False
):
...
@@ -71,11 +71,11 @@ def check_grads_padding(model, zero_model, loose=False):
...
@@ -71,11 +71,11 @@ def check_grads_padding(model, zero_model, loose=False):
chunks
=
torch
.
flatten
(
p
.
grad
).
chunk
(
dist
.
get_world_size
())
chunks
=
torch
.
flatten
(
p
.
grad
).
chunk
(
dist
.
get_world_size
())
if
rank
>=
len
(
chunks
):
if
rank
>=
len
(
chunks
):
continue
continue
grad
=
chunks
[
rank
]
grad
=
chunks
[
rank
]
.
float
()
if
zero_grad
.
size
(
0
)
>
grad
.
size
(
0
):
if
zero_grad
.
size
(
0
)
>
grad
.
size
(
0
):
zero_grad
=
zero_grad
[:
grad
.
size
(
0
)]
zero_grad
=
zero_grad
[:
grad
.
size
(
0
)]
assert
grad
.
dtype
==
zero_grad
.
dtype
assert
grad
.
dtype
==
zero_grad
.
dtype
assert
allclose
(
grad
,
zero_grad
,
loose
=
loose
)
assert
allclose
(
grad
,
zero_grad
,
loose
=
loose
)
,
f
'
{
grad
}
vs
{
zero_grad
}
'
def
check_params_padding
(
model
,
zero_model
,
loose
=
False
):
def
check_params_padding
(
model
,
zero_model
,
loose
=
False
):
...
...
tests/test_zero_data_parallel/test_init_context.py
View file @
13886716
...
@@ -7,12 +7,14 @@ import colossalai
...
@@ -7,12 +7,14 @@ import colossalai
import
pytest
import
pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.zero.shard_utils.tensor_shard_strategy
import
TensorShardStrategy
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
common
import
CONFIG
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils.tensor_shard_strategy
import
\
TensorShardStrategy
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
common
import
CONFIG
,
Net
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
...
@@ -25,11 +27,11 @@ def run_dist(rank, world_size, port):
...
@@ -25,11 +27,11 @@ def run_dist(rank, world_size, port):
shard_param
=
True
):
shard_param
=
True
):
model
=
model_builder
(
checkpoint
=
True
)
model
=
model_builder
(
checkpoint
=
True
)
for
param
in
model
.
parameters
():
for
param
in
model
.
parameters
():
assert
hasattr
(
param
,
'c
a
_attr'
)
assert
hasattr
(
param
,
'c
ol
_attr'
)
assert
param
.
c
a
_attr
.
data
.
dtype
==
torch
.
half
assert
param
.
c
ol
_attr
.
data
.
dtype
==
torch
.
half
assert
param
.
c
a
_attr
.
_
data
_sharded_tensor
.
is_sharded
assert
param
.
c
ol
_attr
.
data
.
is_sharded
assert
param
.
c
a
_attr
.
data
.
device
.
type
==
'cuda'
assert
param
.
c
ol
_attr
.
data
.
payload
.
device
.
type
==
'cuda'
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
13886716
...
@@ -9,19 +9,21 @@ import pytest
...
@@ -9,19 +9,21 @@ import pytest
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.zero.shard_utils.tensor_shard_strategy
import
\
TensorShardStrategy
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
common
import
CONFIG
,
Net
,
check_grads
,
check_grads_padding
from
common
import
CONFIG
,
check_grads
,
check_grads_padding
def
run_fwd_bwd
(
model
,
x
,
enable_autocast
=
False
):
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
enable_autocast
=
False
):
model
.
train
()
model
.
train
()
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
enable_autocast
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
enable_autocast
):
y
=
model
(
x
)
y
=
model
(
data
)
loss
=
y
.
sum
(
)
loss
=
criterion
(
y
,
label
)
loss
=
loss
.
float
()
loss
=
loss
.
float
()
if
isinstance
(
model
,
ShardedModelV2
):
if
isinstance
(
model
,
ShardedModelV2
):
model
.
backward
(
loss
)
model
.
backward
(
loss
)
...
@@ -31,19 +33,26 @@ def run_fwd_bwd(model, x, enable_autocast=False):
...
@@ -31,19 +33,26 @@ def run_fwd_bwd(model, x, enable_autocast=False):
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
]
model
=
Net
(
checkpoint
=
True
).
cuda
()
for
model_name
in
test_models
:
zero_model
=
copy
.
deepcopy
(
model
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
zero_model
=
ShardedModelV2
(
zero_model
,
process_group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
shard_strategy
=
TensorShardStrategy
()
model
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
for
_
in
range
(
2
):
model
=
model
().
half
().
cuda
()
x
=
torch
.
rand
(
2
,
5
).
cuda
()
zero_model
=
ShardedModelV2
(
copy
.
deepcopy
(
model
),
shard_strategy
)
run_fwd_bwd
(
zero_model
,
x
,
False
)
run_fwd_bwd
(
model
,
x
,
False
)
if
dist
.
get_world_size
()
>
1
:
if
dist
.
get_world_size
()
>
1
:
check_grads_padding
(
model
,
zero_model
)
model
=
DDP
(
model
)
else
:
check_grads
(
model
,
zero_model
)
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
2
:
break
data
,
label
=
data
.
half
().
cuda
(),
label
.
cuda
()
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
False
)
run_fwd_bwd
(
zero_model
,
data
,
label
,
criterion
,
False
)
if
dist
.
get_world_size
()
>
1
:
check_grads_padding
(
model
,
zero_model
,
loose
=
True
)
else
:
check_grads
(
model
,
zero_model
,
loose
=
True
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_zero_data_parallel/test_shard_param.py
View file @
13886716
...
@@ -4,18 +4,16 @@
...
@@ -4,18 +4,16 @@
from
copy
import
deepcopy
from
copy
import
deepcopy
from
functools
import
partial
from
functools
import
partial
import
colossalai
import
pytest
import
pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.logging
import
disable_existing_loggers
,
get_dist_logger
import
colossalai
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.zero.sharded_param
import
ShardedTensor
,
ShardedParam
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.logging
import
get_dist_logger
,
disable_existing_loggers
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.zero.sharded_param
import
ShardedParam
,
ShardedTensor
from
tests.test_zero_data_parallel.common
import
Net
,
CONFIG
,
allclose
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
from
tests.test_zero_data_parallel.common
import
CONFIG
,
Net
,
allclose
def
_run_shard_tensor
(
rank
,
world_size
,
port
):
def
_run_shard_tensor
(
rank
,
world_size
,
port
):
...
@@ -47,7 +45,7 @@ def _run_shard_param_v2(rank, world_size, port):
...
@@ -47,7 +45,7 @@ def _run_shard_param_v2(rank, world_size, port):
param_ref
=
deepcopy
(
param
)
param_ref
=
deepcopy
(
param
)
sparam
=
ShardedParamV2
(
param
=
param
,
process_group
=
None
)
sparam
=
ShardedParamV2
(
param
=
param
,
process_group
=
None
)
allclose
(
sparam
.
data
,
param_ref
.
data
)
allclose
(
sparam
.
data
.
payload
,
param_ref
.
data
)
sparam
.
remove_torch_payload
()
sparam
.
remove_torch_payload
()
assert
(
param
.
data
.
numel
()
==
1
)
assert
(
param
.
data
.
numel
()
==
1
)
...
...
tests/test_zero_data_parallel/test_sharded_model_with_ctx.py
0 → 100644
View file @
13886716
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
copy
from
functools
import
partial
import
colossalai
import
pytest
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils.tensor_shard_strategy
import
\
TensorShardStrategy
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
common
import
CONFIG
,
check_grads
,
check_grads_padding
def
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
enable_autocast
=
False
):
model
.
train
()
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
enable_autocast
):
y
=
model
(
data
)
loss
=
criterion
(
y
,
label
)
loss
=
loss
.
float
()
if
isinstance
(
model
,
ShardedModelV2
):
model
.
backward
(
loss
)
else
:
loss
.
backward
()
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
]
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
shard_strategy
=
TensorShardStrategy
()
with
ZeroInitContext
(
convert_fp16
=
True
,
convert_cuda
=
True
,
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
zero_model
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
zero_model
=
zero_model
()
model
=
copy
.
deepcopy
(
zero_model
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
)
model_state_dict
=
zero_model
.
state_dict
()
for
n
,
p
in
model
.
named_parameters
():
p
.
data
=
model_state_dict
[
n
]
model
=
model
.
half
().
cuda
()
if
dist
.
get_world_size
()
>
1
:
model
=
DDP
(
model
)
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
2
:
break
data
,
label
=
data
.
half
().
cuda
(),
label
.
cuda
()
run_fwd_bwd
(
model
,
data
,
label
,
criterion
,
False
)
run_fwd_bwd
(
zero_model
,
data
,
label
,
criterion
,
False
)
if
dist
.
get_world_size
()
>
1
:
check_grads_padding
(
model
,
zero_model
,
loose
=
True
)
else
:
check_grads
(
model
,
zero_model
,
loose
=
True
)
@
pytest
.
mark
.
dist
def
test_shard_model_v2
():
world_size
=
2
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_shard_model_v2
()
tests/test_zero_data_parallel/test_sharded_optim_v2.py
View file @
13886716
...
@@ -56,7 +56,7 @@ def run_dist(rank, world_size, port):
...
@@ -56,7 +56,7 @@ def run_dist(rank, world_size, port):
check_params
(
model
,
zero_model
)
check_params
(
model
,
zero_model
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
skip
def
test_sharded_optim_v2
():
def
test_sharded_optim_v2
():
world_size
=
2
world_size
=
2
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
...
...
tests/test_zero_data_parallel/test_state_dict.py
0 → 100644
View file @
13886716
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
copy
import
deepcopy
from
functools
import
partial
import
colossalai
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
from
colossalai.zero.shard_utils.tensor_shard_strategy
import
\
TensorShardStrategy
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
common
import
CONFIG
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
]
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
model
=
model
()
shard_strategy
=
TensorShardStrategy
()
model
=
model
.
half
().
cuda
()
zero_model
=
ShardedModelV2
(
deepcopy
(
model
),
shard_strategy
)
zero_state_dict
=
zero_model
.
state_dict
()
for
key
,
val
in
model
.
state_dict
().
items
():
assert
torch
.
equal
(
val
,
zero_state_dict
[
key
])
@
pytest
.
mark
.
dist
def
test_zero_state_dict
():
world_size
=
2
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_zero_state_dict
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment