OpenDAS / ColossalAI · Commits

Unverified commit 8c90d4df, authored Mar 29, 2022 by HELSON, committed by GitHub on Mar 29, 2022
Parent: ec5086c4

[zero] add zero context manager to change config during initialization (#546)

Showing 5 changed files with 185 additions and 18 deletions (+185, -18)
colossalai/nn/layer/moe/experts.py          +2   -0
colossalai/nn/layer/moe/layers.py           +13  -6
colossalai/zero/init_ctx/__init__.py        +2   -2
colossalai/zero/init_ctx/init_context.py    +71  -10
tests/test_moe/test_moe_zero_init.py        +97  -0
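Taken together, the changes below let MoE modules opt out of ZeRO parameter sharding while a ZeroInitContext is active, either with the new no_shard_zero_decrator decorator or the no_shard_zero_context() context manager. A minimal sketch of how the pieces compose, assuming a distributed environment already initialized via colossalai.launch (the module names MyMoeBlock and MyModel below are illustrative, not part of this commit):

# Sketch only: how the helpers added in this commit are meant to compose.
# Requires a process group set up by colossalai.launch, since ZeroInitContext
# uses the DATA parallel group.
import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
from colossalai.zero.shard_utils import TensorShardStrategy


class MyMoeBlock(nn.Module):    # hypothetical module, for illustration

    @no_shard_zero_decrator    # params created inside __init__ stay unsharded
    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Linear(dim, 4)


class MyModel(nn.Module):    # hypothetical module, for illustration

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)    # sharded as usual under ZeroInitContext
        self.moe = MyMoeBlock(dim)         # left whole by the decorator
        with no_shard_zero_context():      # same effect, expressed as a context manager
            self.combine = nn.Linear(dim, 2)


with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     model_numel_tensor=torch.zeros(1, dtype=torch.int),
                     rm_torch_payload_on_the_fly=False):
    model = MyModel()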
colossalai/nn/layer/moe/experts.py (view file @ 8c90d4df)

@@ -5,6 +5,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.zero.init_ctx import no_shard_zero_decrator
 from typing import Type

@@ -34,6 +35,7 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

+    @no_shard_zero_decrator
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
colossalai/nn/layer/moe/layers.py (view file @ 8c90d4df)

 import functools
+import math

 import torch

@@ -9,6 +10,7 @@ from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts
 from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup

@@ -205,7 +207,7 @@ class Top2Router(nn.Module):
         return cb_weight, sec_mask


-class FP32LinearGate(nn.Linear):
+class FP32LinearGate(nn.Module):
     """Gate module used in MOE layer. Just a linear function without bias.
     But it should be kept as fp32 forever.

@@ -217,9 +219,13 @@ class FP32LinearGate(nn.Linear):
         weight (ForceFP32Parameter): The weight of linear gate
     """

-    def __init__(self, d_model: int, num_experts: int):
-        super().__init__(d_model, num_experts, bias=False, device=get_current_device())
-        self.weight = ForceFP32Parameter(self.weight)
+    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
+        super().__init__()
+        self.weight = ForceFP32Parameter(torch.empty(num_experts, d_model, device=get_current_device()))
+        nn.init.trunc_normal_(self.weight, std=math.sqrt(scale / d_model))
+
+    def forward(self, x: torch.Tensor):
+        return F.linear(x, self.weight)


 class MoeLayer(nn.Module):

@@ -235,6 +241,7 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

+    @no_shard_zero_decrator
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model

@@ -361,7 +368,6 @@ class MoeModule(nn.Module):
                                      min_capacity=min_capacity,
                                      noisy_func=noisy_func,
                                      drop_tks=drop_tks)
-
         self.use_residual = use_residual
         if use_residual:
             if residual_instance is not None:

@@ -371,7 +377,8 @@ class MoeModule(nn.Module):
                     "Expert class can't be None when residual instance is not given"
                 self.residual_module = expert_cls(**expert_args)

-            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+            with no_shard_zero_context():
+                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

         if expert_instance is not None:
             self.experts = expert_instance
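The FP32LinearGate rewrite above drops the nn.Linear base class: the gate now owns a ForceFP32Parameter weight initialized with a truncated normal (std = sqrt(scale / d_model)) and applies F.linear in forward, so the gate weight can stay in fp32 even when the surrounding model is cast to half. A small sketch of the resulting behaviour, assuming FP32LinearGate is imported from the module changed above and that get_current_device() resolves to an available device (typically a GPU):

# Sketch only: exercising the rewritten FP32LinearGate. The import path mirrors
# the file changed above; running it needs a ColossalAI install and the device
# returned by get_current_device().
import torch
from colossalai.nn.layer.moe.layers import FP32LinearGate

gate = FP32LinearGate(d_model=8, num_experts=4)          # scale defaults to 0.1
tokens = torch.randn(16, 8, device=gate.weight.device)   # (num_tokens, d_model)
logits = gate(tokens)                                     # F.linear(x, weight)

assert gate.weight.dtype == torch.float32                 # ForceFP32Parameter keeps fp32
assert logits.shape == (16, 4)                            # one logit per expert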
colossalai/zero/init_ctx/__init__.py (view file @ 8c90d4df)

-from .init_context import ZeroInitContext
+from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator

-__all__ = ['ZeroInitContext']
+__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']
colossalai/zero/init_ctx/init_context.py (view file @ 8c90d4df)

 import contextlib
 import functools
 from typing import Optional

 import torch

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
@@ -82,6 +84,25 @@ class InsertPostInitMethodToModuleSubClasses(object):
         pass


+class ZeroContextConfig(object):
+    """The configuration used to control zero context initialization.
+
+    Args:
+        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
+        rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
+            This will reduce memory usage when initializing model.
+            But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
+            If set to `False`, remove tensor payload on param.data after the context exits.
+            This is used when you add some logic to operate tensors in __init__ of module.
+            See torchvision resnet18. Defaults to False.
+    """
+
+    def __init__(self, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+        super().__init__()
+        self.shard_param: bool = shard_param
+        self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
+
+
 class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """A context to initialize model.
@@ -90,11 +111,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.

     Args:
         convert_fp16 (bool): Whether to convert params to fp16.
         target_device (torch.device): The device where param data after exiting the context.
         shard_strategy (BaseShardStrategy): Shard strategy instance.
-        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
-        shard_grad (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
             This will reduce memory usage when initializing model.
             But it's not suitable for all models, especially when there are `weight init` operations in `__init__`.
@@ -115,13 +134,23 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self.target_device = target_device
-        self.shard_param = shard_param
         self.shard_strategy = shard_strategy
-        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)

+        self.config = ZeroContextConfig(shard_param=shard_param, rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
+
+        ZeroContextMgr().current_context = self
+
+    @property
+    def shard_param(self):
+        return self.config.shard_param
+
+    @property
+    def rm_torch_payload_on_the_fly(self):
+        return self.config.rm_torch_payload_on_the_fly
+
     def _pre_context_exec(self):
         """
         The Callback function when entering the context
@@ -143,6 +172,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
+
+        def half_fn(t: torch.Tensor):
+            return t.half() if t.is_floating_point() else t
+
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'col_attr'):
@@ -150,23 +183,24 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             self.model_numel_tensor += param.numel()

-            target_device = self.target_device
-
-            # convert to fp16
-            param.data = param.data.to(torch.half)
+            # convert parameters to half
+            param_half = half_fn(param)
+            param.data = param_half
             if param.grad is not None:
-                param.grad = param.grad.to(torch.half)
+                grad_half = half_fn(param.grad)
+                param.grad.data = grad_half

             # move torch parameters to the target device
+            target_device = self.target_device
             param.data = param.data.to(target_device)
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)

             param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

-            self.initialized_param_list.append(param)
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+
+            self.initialized_param_list.append(param)

         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
@@ -174,3 +208,30 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
             buffer.data = cast_tensor_to_fp16(buffer.data)
+
+
+class ZeroContextMgr(metaclass=SingletonMeta):
+    current_context: Optional[ZeroInitContext] = None
+
+    @contextlib.contextmanager
+    def hijack_context_config(self, **kwargs):
+        if self.current_context is None:
+            yield
+        else:
+            old_config = self.current_context.config
+            self.current_context.config = ZeroContextConfig(**kwargs)
+            yield
+            self.current_context.config = old_config
+
+
+def no_shard_zero_context():
+    return ZeroContextMgr().hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False)
+
+
+def no_shard_zero_decrator(init_func):
+
+    def _no_shard(*args, **kwargs):
+        with no_shard_zero_context():
+            init_func(*args, **kwargs)
+
+    return _no_shard
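ZeroContextMgr above is a singleton that tracks the active ZeroInitContext; hijack_context_config temporarily swaps the context's ZeroContextConfig for one built from the given kwargs and restores the original on exit, which is what no_shard_zero_context() and no_shard_zero_decrator build on. A minimal sketch of that swap, assuming only that ColossalAI is importable (when no ZeroInitContext is active, the manager simply yields and nothing is changed):

# Sketch only: the config swap behind no_shard_zero_context().
from colossalai.zero.init_ctx.init_context import ZeroContextMgr, ZeroContextConfig

mgr = ZeroContextMgr()

# No ZeroInitContext entered yet -> current_context is None and hijacking is a no-op.
with mgr.hijack_context_config(shard_param=False, rm_torch_payload_on_the_fly=False):
    pass

# Once a ZeroInitContext is active it registers itself as mgr.current_context,
# and the same call replaces its config with
# ZeroContextConfig(shard_param=False, rm_torch_payload_on_the_fly=False)
# for the duration of the with-block, then restores the old config.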
tests/test_moe/test_moe_zero_init.py (new file, 0 → 100644, view file @ 8c90d4df)

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer import MoeModule
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.testing import rerun_on_exception
from colossalai.utils import get_current_device
from tests.test_zero_data_parallel.common import CONFIG


class MoeModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        expert_cls = nn.Linear
        expert_args_dict = dict(in_features=8, out_features=8)
        self.moe = MoeModule(dim_model=8,
                             num_experts=8,
                             noisy_policy='Jitter',
                             use_residual=True,
                             expert_cls=expert_cls,
                             **expert_args_dict)
        self.proj2 = nn.Linear(8, 4)
    def forward(self, x):
        x = self.proj1(x)
        x = self.moe(x)
        x = self.proj2(x)
        return x

@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_moe_zero_init(init_device_type, shard_strategy_class):
    logger = get_dist_logger("test_moe_zero_init")

    if init_device_type == 'cuda':
        init_device = torch.device(f"cuda:{get_current_device()}")
    elif init_device_type == 'cpu':
        init_device = torch.device("cpu")
    else:
        raise NotImplementedError("Unknown device found.")

    model_numel_tensor = torch.zeros(1, dtype=torch.int)
    with ZeroInitContext(target_device=init_device,
                         shard_strategy=shard_strategy_class(),
                         shard_param=True,
                         model_numel_tensor=model_numel_tensor,
                         rm_torch_payload_on_the_fly=False):
        model = MoeModel()

    for name, param in model.named_parameters():
        assert hasattr(param, 'col_attr')

        # the weights in the gate should be fp32
        if 'gate' in name:
            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
        else:
            assert param.col_attr.sharded_data_tensor.dtype == torch.half

        # the parameters in moe experts and its gate should not be sharded
        if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
            assert not param.col_attr.sharded_data_tensor.is_sharded
        else:
            assert param.col_attr.sharded_data_tensor.is_sharded

        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MOE_CONTEXT.setup(seed=42)
    run_moe_zero_init()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_moe_zero_init(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_zero_init(world_size=2)