ColossalAI · Commit 055fbf5b

Unverified commit, authored Apr 01, 2022 by HELSON, committed via GitHub on Apr 01, 2022.

[zero] adapt zero for unsharded paramters (Optimizer part) (#601)
Parent: 229382c8

Showing 8 changed files with 208 additions and 44 deletions (+208 −44).
colossalai/utils/checkpointing.py                  +4   −1
colossalai/zero/init_ctx/init_context.py           +19  −6
colossalai/zero/sharded_optim/sharded_optim_v2.py  +30  −22
tests/components_to_test/no_leaf_module.py         +2   −1
tests/test_moe/test_moe_zero_init.py               +5   −2
tests/test_moe/test_moe_zero_model.py              +1   −1
tests/test_moe/test_moe_zero_optim.py              +134 −0
tests/test_zero_data_parallel/common.py            +13  −11
colossalai/utils/checkpointing.py

@@ -6,7 +6,10 @@ import torch.distributed as dist
 from colossalai.communication.collective import scatter_object_list
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+try:
+    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+except ImportError:
+    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'

 from .common import is_using_pp
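The guarded import keeps checkpointing.py working on PyTorch releases that predate `_EXTRA_STATE_KEY_SUFFIX` (the fallback value matches the upstream constant). A minimal sketch of how the constant is typically used when walking a state dict; `is_extra_state_key` is a hypothetical helper, not part of this commit:

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
    # Older PyTorch: the constant is missing, but the key convention is the same.
    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'


def is_extra_state_key(key: str) -> bool:
    # Hypothetical helper: spot `get_extra_state()` entries in a state dict.
    return key.endswith(_EXTRA_STATE_KEY_SUFFIX)


print(is_extra_state_key('encoder.layer1._extra_state'))  # True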
colossalai/zero/init_ctx/init_context.py

@@ -11,6 +11,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
+from contextlib import AbstractContextManager


 def _substitute_init_recursively(cls, func):

@@ -88,6 +89,7 @@ class ZeroContextConfig(object):
     """The configuration used to control zero context initialization.

     Args:
+        target_device (torch.device): The device where param data are after exiting the context.
         replicated (bool, optional): Whether the param is replicated across data parallel group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.

@@ -99,8 +101,13 @@ class ZeroContextConfig(object):
         See torchvision resnet18. Defaults to False.
     """

-    def __init__(self, replicated: bool = True, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+    def __init__(self,
+                 target_device: torch.device,
+                 replicated: bool = True,
+                 shard_param: bool = False,
+                 rm_torch_payload_on_the_fly: bool = False):
         super().__init__()
+        self.target_device = target_device
         self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly

@@ -114,7 +121,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.

     Args:
-        target_device (torch.device): The device where param data after exiting the context.
+        target_device (torch.device): The device where param data are after exiting the context.
         shard_strategy (BaseShardStrategy): Shard strategy instance.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.

@@ -136,17 +143,22 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  dp_process_group: Optional[ProcessGroup] = None):

         super().__init__()
-        self.target_device = target_device
         self.shard_strategy = shard_strategy
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)

-        self.config = ZeroContextConfig(replicated=True,
+        self.config = ZeroContextConfig(target_device=target_device,
+                                        replicated=True,
                                         shard_param=shard_param,
                                         rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)

         ZeroContextMgr().current_context = self

+    @property
+    def target_device(self):
+        return self.config.target_device
+
     @property
     def is_replicated(self):
         return self.config.is_replicated

@@ -235,8 +247,9 @@ class ZeroContextMgr(metaclass=SingletonMeta):
         self.current_context.config = old_config


-def no_shard_zero_context(is_replicated: bool = True):
-    return ZeroContextMgr().hijack_context_config(replicated=is_replicated,
+def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
+    return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
+                                                  replicated=is_replicated,
                                                   shard_param=False,
                                                   rm_torch_payload_on_the_fly=False)
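For reference, `target_device` now threads through `ZeroContextConfig` into the public `ZeroInitContext` API. A minimal usage sketch, assuming a distributed context has already been set up with `colossalai.launch`:

import torch
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# Parameters are sharded while the module is constructed, and their data lands
# on `target_device` once the context exits.
with ZeroInitContext(target_device=torch.device('cpu'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     rm_torch_payload_on_the_fly=False):
    model = torch.nn.Linear(4, 4)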
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -12,13 +12,12 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.zero.shard_utils.tensor_utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
+from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
+                                                      colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
-from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

@@ -69,6 +68,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
         growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
         hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
+        keep_unsharded (bool, optional): if True, optimizer won't shard unsharded parameters.
+            In Zero-2, set keep_unsharded to False.
+            In Zero-3, set keep_unsharded to True.
         max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
         dp_process_group (Optional[ProcessGroup], optional): data parallel process group. Defaults to None.
         mp_process_group (Optional[ProcessGroup], optional): model parallel process group. Defaults to None.
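A minimal sketch of the two `keep_unsharded` settings described in the docstring above, assuming `zero_model` is already wrapped in `ShardedModelV2` and `base_optim` is a torch optimizer built over its parameters (both placeholders here):

from colossalai.zero.sharded_optim import ShardedOptimizerV2

# ZeRO-2 style: master weights of unsharded params are sharded too (default).
optim_zero2 = ShardedOptimizerV2(zero_model, base_optim, keep_unsharded=False)

# ZeRO-3 style (e.g. MoE experts): unsharded params stay whole in the optimizer.
optim_zero3 = ShardedOptimizerV2(zero_model, base_optim, keep_unsharded=True)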
@@ -89,6 +91,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  growth_interval: float = 1000,
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
+                 keep_unsharded: bool = False,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'

@@ -122,24 +125,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
         self._logger = get_dist_logger("ShardedOptimizerV2")

-        # Store fp32 param shards
-        self.master_params: Dict[Parameter, StatefulTensor] = {}
-
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
-                    # TODO (ver217): we may not use shard / gather here
-                    # Param is not sharded, which means we use ZeRO-2 here
-                    # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = StatefulTensor(
-                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
-                if not is_param_sharded:
-                    # In this branch, there's no need to shard param
-                    # So we gather here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+        assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \
+            "Keeping unsharded parameters can't be used with hybrid OS placement right now."
+        self.keep_unshard = keep_unsharded
+        # Store fp32 param shards
+        self._register_master_weight()

         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!",
                            ranks=[0])

@@ -283,6 +274,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def sync_grad(self):
         pass

+    def _register_master_weight(self):
+        self.master_params: Dict[Parameter, StatefulTensor] = {}
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
+                if not is_param_sharded and not self.keep_unshard:
+                    # Please use keep_unsharded to control whether to shard unsharded parameters
+                    # As we only store param shard, we shard it here
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = StatefulTensor(
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
+                if not is_param_sharded and not self.keep_unshard:
+                    # In this branch, there's no need to shard param
+                    # So we gather here
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+
     def _maybe_move_fp32_shards(self):
         if self._should_move_fp32_shards_h2d:
             self._should_move_fp32_shards_h2d = False
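The shard-then-gather sequence in `_register_master_weight` exists so that each rank persists only its own fp32 shard of the master weight while the full fp16 parameter stays available for forward and backward. A single-process sketch of that layout with plain tensors (sizes illustrative, no distributed calls):

import torch

world_size, rank = 4, 1                            # illustrative rank layout
fp16_param = torch.randn(10, dtype=torch.half)     # full (unsharded) fp16 param

# "shard": take only this rank's chunk of the flattened parameter ...
shard = torch.flatten(fp16_param).chunk(world_size)[rank]
# ... snapshot an fp32 master copy of just that chunk ...
master = shard.float().clone()
# ... "gather": the full fp16 param is restored afterwards, so the per-rank cost
# of master weights is roughly numel / world_size, not numel.
print(fp16_param.numel(), master.numel())          # 10 vs 3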
@@ -328,7 +336,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
+                if not is_param_sharded and not self.keep_unshard:
                     # We use ZeRO-2 here
                     # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here

@@ -342,7 +350,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 p.colo_attr.sharded_data_tensor.reset_payload(
                     colo_model_tensor_clone(p.half(), torch.cuda.current_device()))

-                if not is_param_sharded:
+                if not is_param_sharded and not self.keep_unshard:
                     # We gather full fp16 param here
                     self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 p.data = p.colo_attr.sharded_data_tensor.payload
tests/components_to_test/no_leaf_module.py

@@ -42,4 +42,5 @@ def get_training_components():
     testloader = DummyDataLoader()
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
+    from colossalai.nn.optimizer import HybridAdam
+    return model_builder, trainloader, testloader, HybridAdam, criterion
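`HybridAdam` replaces `torch.optim.Adam` here so the same test component works whether ZeRO keeps its parameters on GPU or offloads them to CPU. A minimal standalone sketch, assuming a CUDA device is available:

import torch
from colossalai.nn.optimizer import HybridAdam

model = torch.nn.Linear(8, 8).cuda()
optim = HybridAdam(model.parameters(), lr=1e-3)    # drop-in replacement for Adam here

loss = model(torch.randn(2, 8, device='cuda')).sum()
loss.backward()
optim.step()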
tests/test_moe/test_moe_zero_init.py

@@ -76,8 +76,11 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
         else:
             assert param.is_replicated

-        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        if param.colo_attr.param_is_sharded:
+            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+                f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        else:
+            assert param.colo_attr.sharded_data_tensor.payload.device.type == 'cuda'


 def _run_dist(rank, world_size, port):
tests/test_moe/test_moe_zero_model.py

@@ -67,7 +67,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("world_size", [2])
 @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_moe_zero_model(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
tests/test_moe/test_moe_zero_optim.py (new file, mode 100644)

from functools import partial

import colossalai
from colossalai.utils.cuda import get_current_device
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.utils import get_current_device
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.context import MOE_CONTEXT
from colossalai.testing import assert_equal_in_group

from tests.test_zero_data_parallel.common import CONFIG, check_sharded_model_params
from tests.test_moe.test_moe_zero_init import MoeModel


def _run_step(model, optimizer, data, label, criterion, grad_handler):
    model.train()
    optimizer.zero_grad()

    if criterion:
        y = model(data)
        loss = criterion(y, label)
    else:
        loss = model(data, label)

    loss = loss.float()
    if isinstance(model, ShardedModelV2):
        optimizer.backward(loss)
    else:
        loss.backward()

    if grad_handler is not None:
        grad_handler.handle_gradient()

    optimizer.step()


@parameterize("cpu_offload", [True, False])
@parameterize("use_cpuadam", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
    MOE_CONTEXT.reset_loss()
    shard_strategy = shard_strategy_class()
    if use_cpuadam and cpu_offload is False:
        return

    get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
    _, train_dataloader, _, optimizer_class, criterion = get_components_func()

    with ZeroInitContext(
            target_device=torch.device('cpu') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
            shard_strategy=shard_strategy,
            shard_param=True,
            rm_torch_payload_on_the_fly=False):
        zero_model = MoeModel()

    zero_model = ShardedModelV2(
        zero_model,
        shard_strategy,
        offload_config=dict(device='cpu') if cpu_offload else None,
        use_memory_tracer=gpu_margin_mem_ratio > 0.0,
        reuse_fp16_shard=use_cpuadam,
    )

    # check whether parameters are identical in ddp
    for name, p in zero_model.named_parameters():
        if not p.colo_attr.param_is_sharded and p.is_replicated:
            assert_equal_in_group(p.data.to(get_current_device()))

    model = MoeModel().half()
    col_model_deepcopy(zero_model, model)
    model = model.cuda().float()

    if use_cpuadam:
        optimizer_class = CPUAdam
    optim = optimizer_class(model.parameters(), lr=1e-3)
    sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
    sharded_optim = ShardedOptimizerV2(zero_model,
                                       sharded_optim,
                                       cpu_offload=cpu_offload,
                                       initial_scale=2**5,
                                       gpu_margin_mem_ratio=gpu_margin_mem_ratio,
                                       keep_unsharded=True)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
    apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
    apex_grad_handler = MoeGradientHandler(model)

    # Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
        if 'gate' in n:
            p.data = p.float()
            p.data.copy_(zp.data)

    for i, (data, label) in enumerate(train_dataloader):
        if i > 5:
            break
        data, label = data.cuda(), label.cuda()
        _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler)
        _run_step(zero_model, sharded_optim, data, label, criterion, None)
        check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
        for param in model.parameters():
            assert not has_inf_or_nan(param)


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MOE_CONTEXT.setup(seed=42)
    _run_test_sharded_optim_v2()


# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_moe_zero_optim(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_zero_optim(world_size=2)
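The test above drives an apex-AMP reference model and the ZeRO model through identical batches and compares parameters after every step. A minimal single-process sketch of that lockstep-comparison pattern, using two plain `torch.nn.Linear` models in place of the MoE pair:

import torch

ref = torch.nn.Linear(4, 4)
dut = torch.nn.Linear(4, 4)
dut.load_state_dict(ref.state_dict())              # start from identical weights

opt_ref = torch.optim.Adam(ref.parameters(), lr=1e-3)
opt_dut = torch.optim.Adam(dut.parameters(), lr=1e-3)

data = torch.randn(2, 4)
for model, opt in ((ref, opt_ref), (dut, opt_dut)):
    opt.zero_grad()
    model(data).sum().backward()
    opt.step()

# Same inputs, same weights, same optimizer => parameters must still agree.
for p, q in zip(ref.parameters(), dut.parameters()):
    assert torch.allclose(p, q)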
tests/test_zero_data_parallel/common.py

@@ -124,16 +124,18 @@ def check_params_padding(model, zero_model, loose=False):
 def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
     rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        if reuse_fp16_shard:
-            zero_p = zero_p.data.to(p.device).float()
-        else:
-            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
-        chunks = torch.flatten(p).chunk(dist.get_world_size())
-        if rank >= len(chunks):
-            continue
-        p = chunks[rank].float()
-        if zero_p.size(0) > p.size(0):
-            zero_p = zero_p[:p.size(0)]
+    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
+        if zero_p.colo_attr.param_is_sharded:
+            if reuse_fp16_shard:
+                zero_p = zero_p.data.to(p.device).float()
+            else:
+                zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
+            chunks = torch.flatten(p).chunk(dist.get_world_size())
+            if rank >= len(chunks):
+                continue
+            p = chunks[rank].float()
+            if zero_p.size(0) > p.size(0):
+                zero_p = zero_p[:p.size(0)]
+        else:
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
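The chunk-and-truncate logic above mirrors the shard layout the test assumes from `TensorShardStrategy`: the flattened parameter is split with `Tensor.chunk`, the trailing shard may be padded, and some ranks may receive no chunk at all. A small standalone illustration:

import torch

world_size = 4
p = torch.arange(10, dtype=torch.float32)          # full parameter, 10 elements
chunks = torch.flatten(p).chunk(world_size)        # chunk sizes: 3, 3, 3, 1
for rank, c in enumerate(chunks):
    print(rank, c.size(0))

# With only 2 elements, chunk() returns just 2 chunks; ranks 2 and 3 get none,
# which is what the `rank >= len(chunks)` guard skips. A padded zero_p longer
# than its reference chunk is what the size-truncation above accounts for.
print(len(torch.arange(2).chunk(world_size)))      # 2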