OpenDAS / ColossalAI · Commit 8c90d4df (unverified)
Authored Mar 29, 2022 by HELSON; committed via GitHub on Mar 29, 2022

[zero] add zero context manager to change config during initialization (#546)

Parent: ec5086c4

Showing 5 changed files with 185 additions and 18 deletions (+185, -18)
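In short: `ZeroInitContext` now keeps its `shard_param` and `rm_torch_payload_on_the_fly` flags in a `ZeroContextConfig` object and registers itself with a singleton `ZeroContextMgr`. The manager's `hijack_context_config()` context manager temporarily swaps that config while a sub-module is being constructed, and the helpers `no_shard_zero_context()` and `no_shard_zero_decrator` use it so that MoE experts, the MoE layer (including its gate), and the residual combiner are initialized without being sharded.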
Changed files:
  colossalai/nn/layer/moe/experts.py        +2   -0
  colossalai/nn/layer/moe/layers.py         +13  -6
  colossalai/zero/init_ctx/__init__.py      +2   -2
  colossalai/zero/init_ctx/init_context.py  +71  -10
  tests/test_moe/test_moe_zero_init.py      +97  -0
colossalai/nn/layer/moe/experts.py

@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.zero.init_ctx import no_shard_zero_decrator
 from typing import Type

@@ -34,6 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

+    @no_shard_zero_decrator
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
colossalai/nn/layer/moe/layers.py

+import functools
 import math

 import torch

@@ -9,6 +10,7 @@ from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
 from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup

@@ -205,7 +207,7 @@ class Top2Router(nn.Module):
         return cb_weight, sec_mask


-class FP32LinearGate(nn.Linear):
+class FP32LinearGate(nn.Module):
     """Gate module used in MOE layer. Just a linear function without bias.
     But it should be kept as fp32 forever.

@@ -217,9 +219,13 @@ class FP32LinearGate(nn.Linear):
         weight (ForceFP32Parameter): The weight of linear gate
     """

-    def __init__(self, d_model: int, num_experts: int):
-        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
-        self.weight = ForceFP32Parameter(self.weight)
+    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
+        super().__init__()
+        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
+        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
+
+    def forward(self, x: torch.Tensor):
+        return F.linear(x, self.weight)


 class MoeLayer(nn.Module):

@@ -235,6 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

+    @no_shard_zero_decrator
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model

@@ -361,7 +368,6 @@ class MoeModule(nn.Module):
                                      min_capacity=min_capacity,
                                      noisy_func=noisy_func,
                                      drop_tks=drop_tks)
-
         self.use_residual = use_residual
         if use_residual:
             if residual_instance is not None:

@@ -371,7 +377,8 @@ class MoeModule(nn.Module):
                 "Expert class can't be None when residual instance is not given"
             self.residual_module = expert_cls(**expert_args)
-            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+            with no_shard_zero_context():
+                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

         if expert_instance is not None:
             self.experts = expert_instance
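The gate is reworked here from an `nn.Linear` subclass into a plain `nn.Module` that owns a `ForceFP32Parameter` and applies `F.linear` itself, so the fp16 cast that `ZeroInitContext` applies to every freshly created parameter leaves the gate weight in fp32 (the new test below asserts this). One way a parameter can survive that cast is simply to ignore `.half()`; the sketch below only illustrates the idea and is not the actual `ForceFP32Parameter` implementation shipped in `colossalai/nn/layer/moe/utils.py`:

import torch


class KeepFP32Parameter(torch.nn.Parameter):
    """Hypothetical parameter type that refuses half-precision casts."""

    def half(self, memory_format=None):
        # Returning the fp32 data unchanged makes the `t.half()` call in the
        # zero-init hook a no-op for this parameter, so it stays in float32.
        return self.data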
colossalai/zero/init_ctx/__init__.py

-from .init_context import ZeroInitContext
+from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator

-__all__ = ['ZeroInitContext']
+__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']
colossalai/zero/init_ctx/init_context.py

+import contextlib
 import functools
 from typing import Optional

 import torch

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16

@@ -82,6 +84,25 @@ class InsertPostInitMethodToModuleSubClasses(object):
         pass


+class ZeroContextConfig(object):
+    """The configuration used to control zero context initialization.
+
+    Args:
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finishes.
+            This will reduce memory usage when initializing the model.
+            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on `param.data` after the context exits.
+            This is used when you add some logic to operate on tensors in `__init__` of a module.
+            See torchvision resnet18. Defaults to False.
+    """
+
+    def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+        super().__init__()
+        self.shard_param: bool = shard_param
+        self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
+
+
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.

@@ -90,11 +111,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.

     Args:
-        convert_fp16 (bool): Whether to convert params to fp16.
         target_device (torch.device): The device where param data are placed after exiting the context.
         shard_strategy (BaseShardStrategy): Shard strategy instance.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
-        shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finishes.
             This will reduce memory usage when initializing the model.
             But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.

@@ -115,13 +134,23 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self.target_device = target_device
-        self.shard_param = shard_param
         self.shard_strategy = shard_strategy
-        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)

+        self.config = ZeroContextConfig(shard_param=shard_param, rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
+
+        ZeroContextMgr().current_context = self
+
+    @property
+    def shard_param(self):
+        return self.config.shard_param
+
+    @property
+    def rm_torch_payload_on_the_fly(self):
+        return self.config.rm_torch_payload_on_the_fly
+
     def _pre_context_exec(self):
         """
         The Callback function when entering the context

@@ -143,6 +172,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
+
+        def half_fn(t: torch.Tensor):
+            return t.half() if t.is_floating_point() else t
+
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'col_attr'):

@@ -150,23 +183,24 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             self.model_numel_tensor += param.numel()

-            target_device = self.target_device
-
-            # convert to fp16
-            param.data = param.data.to(torch.half)
+            # convert parameters to half
+            param_half = half_fn(param)
+            param.data = param_half
             if param.grad is not None:
-                param.grad = param.grad.to(torch.half)
+                grad_half = half_fn(param.grad)
+                param.grad.data = grad_half

             # move torch parameters to the target device
+            target_device = self.target_device
             param.data = param.data.to(target_device)
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)

             param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

+            self.initialized_param_list.append(param)
+
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-            self.initialized_param_list.append(param)

         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float

@@ -174,3 +208,30 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
             buffer.data = cast_tensor_to_fp16(buffer.data)
+
+
+class ZeroContextMgr(metaclass=SingletonMeta):
+    current_context: Optional[ZeroInitContext] = None
+
+    @contextlib.contextmanager
+    def hijack_context_config(self, **kwargs):
+        if self.current_context is None:
+            yield
+        else:
+            old_config = self.current_context.config
+            self.current_context.config = ZeroContextConfig(**kwargs)
+            yield
+            self.current_context.config = old_config
+
+
+def no_shard_zero_context():
+    return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
+
+
+def no_shard_zero_decrator(init_func):
+
+    def _no_shard(*args, **kwargs):
+        with no_shard_zero_context():
+            init_func(*args, **kwargs)
+
+    return _no_shard
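Taken together, the new pieces let a module opt out of sharding while it is being built inside a `ZeroInitContext`. Below is a minimal sketch of the intended usage; it assumes a distributed environment has already been set up with `colossalai.launch` (as the new test does) and that the remaining `ZeroInitContext` keyword arguments can be left at their defaults. `TinyGate` and `TinyModel` are made-up names for illustration only:

import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
from colossalai.zero.shard_utils import TensorShardStrategy


class TinyGate(nn.Module):

    # Parameters created inside this __init__ are built while the active
    # ZeroInitContext temporarily holds ZeroContextConfig(shard_param=False),
    # so they are registered but never sharded.
    @no_shard_zero_decrator
    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_experts, d_model))


class TinyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)         # sharded as usual
        self.gate = TinyGate(8, 4)          # kept whole by the decorator
        with no_shard_zero_context():       # equivalent inline form
            self.combine = nn.Linear(8, 2)


# Assumes colossalai.launch(...) has already created the DATA process group.
with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = TinyModel()    # proj is sharded; gate and combine are not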
tests/test_moe/test_moe_zero_init.py  (new file, mode 100644)

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer import MoeModule
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.testing import rerun_on_exception
from colossalai.utils import get_current_device
from tests.test_zero_data_parallel.common import CONFIG


class MoeModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        expert_cls = nn.Linear
        expert_args_dict = dict(in_features=8, out_features=8)
        self.moe = MoeModule(dim_model=8,
                             num_experts=8,
                             noisy_policy='Jitter',
                             use_residual=True,
                             expert_cls=expert_cls,
                             **expert_args_dict)
        self.proj2 = nn.Linear(8, 4)

    def forward(self, x):
        x = self.proj1(x)
        x = self.moe(x)
        x = self.proj2(x)
        return x


@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_moe_zero_init(init_device_type, shard_strategy_class):
    logger = get_dist_logger("test_moe_zero_init")

    if init_device_type == 'cuda':
        init_device = torch.device(f"cuda:{get_current_device()}")
    elif init_device_type == 'cpu':
        init_device = torch.device("cpu")
    else:
        raise NotImplementedError("Unknown device found.")

    model_numel_tensor = torch.zeros(1, dtype=torch.int)
    with ZeroInitContext(target_device=init_device,
                         shard_strategy=shard_strategy_class(),
                         shard_param=True,
                         model_numel_tensor=model_numel_tensor,
                         rm_torch_payload_on_the_fly=False):
        model = MoeModel()

    for name, param in model.named_parameters():
        assert hasattr(param, 'col_attr')

        # the weights in the gate should be fp32
        if 'gate' in name:
            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
        else:
            assert param.col_attr.sharded_data_tensor.dtype == torch.half

        # the parameters in moe experts and its gate should not be sharded
        if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
            assert not param.col_attr.sharded_data_tensor.is_sharded
        else:
            assert param.col_attr.sharded_data_tensor.is_sharded

        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MOE_CONTEXT.setup(seed=42)
    run_moe_zero_init()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_moe_zero_init(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_zero_init(world_size=2)
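The test is distributed: it spawns `world_size` processes (2 or 4) with `torch.multiprocessing`, launches colossalai with the `nccl` backend, and therefore needs a matching number of GPUs. Running the file directly exercises the 2-GPU case via the `__main__` guard.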