OpenDAS / ColossalAI · Commits

Commit 31c64402 (unverified), authored Nov 30, 2022 by Jiarui Fang, committed by GitHub on Nov 30, 2022.

[hotfix] hotfix Gemini for no leaf modules bug (#2043)

Parent: 384cd263
Showing 2 changed files with 82 additions and 28 deletions:

  colossalai/utils/model/colo_init_context.py   +68 -22
  tests/test_gemini/update/test_optim.py        +14 -6
colossalai/utils/model/colo_init_context.py (view file @ 31c64402)

-from typing import Dict, Iterator, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 import torch
 from torch import nn
 from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
-from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup, ShardSpec
+from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
 from .utils import InsertPostInitMethodToModuleSubClasses
...
@@ -26,6 +26,34 @@ def _named_params_with_replica(
         yield name, val
+
+
+def _convert_to_coloparam(param: torch.nn.Parameter,
+                          device: torch.device,
+                          dtype=torch.float,
+                          default_pg: Optional[ProcessGroup] = None,
+                          default_dist_spec: Optional[Any] = None) -> ColoParameter:
+    if isinstance(param, ColoParameter):
+        return param
+    # Detaching the tensor is necessary for optimizers.
+    requires_grad = param.requires_grad
+    # param is the global tensor.
+    colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
+    # If a default shard plan exists, shard the param during initialization.
+    # This can reduce the model size after initialization.
+    # NOTE: embeddings usually cannot be sharded correctly, so the except
+    # clause handles params that the default plan cannot shard.
+    if default_pg is not None:
+        colo_param.set_process_group(default_pg)
+    if default_dist_spec is not None:
+        try:
+            colo_param.set_dist_spec(default_dist_spec)
+        except Exception:
+            pass
+    return colo_param
+
+
 def ColoModulize(module):
     """
     Replacing the parameters() and named_parameters() with our customized ones
...
@@ -94,26 +122,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             if param in replaced_tensors:
                 colo_param = replaced_tensors[param]
             else:
-                # Detaching the tensor is necessary for optimizers.
-                requires_grad = param.requires_grad
-                # param is the global tensor.
-                colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
-                                           requires_grad=requires_grad)
-                # If a default shard plan exists, shard the param during initialization.
-                # This can reduce the model size after initialization.
-                # NOTE: embeddings usually cannot be sharded correctly, so the except
-                # clause handles params that the default plan cannot shard.
-                if self._default_pg is not None:
-                    colo_param.set_process_group(self._default_pg)
-                if self._default_dist_spec is not None:
-                    try:
-                        colo_param.set_dist_spec(self._default_dist_spec)
-                    except Exception:
-                        pass
+                colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg,
+                                                   self._default_dist_spec)
                 replaced_tensors[param] = colo_param
             delattr(submodule, param_name)
             setattr(submodule, param_name, colo_param)
...
@@ -121,3 +131,39 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         module.to(self._device)
         ColoModulize(module)
+
+
+def post_process_colo_init_ctx(model: torch.nn.Module,
+                               device: torch.device = torch.device('cpu'),
+                               dtype: torch.dtype = torch.float,
+                               default_pg: Optional[ProcessGroup] = None,
+                               default_dist_spec=None):
+    """post_process_colo_init_ctx
+
+    This function is called after `ColoInitContext`.
+
+    Args:
+        model (torch.nn.Module): the model.
+        device (torch.device, optional): device of the model params. Defaults to torch.device('cpu').
+        dtype (torch.dtype, optional): dtype of the model params. Defaults to torch.float.
+        default_pg (Optional[ProcessGroup], optional): default process group. Defaults to None.
+            Indicates a DP-only process group.
+        default_dist_spec (Any, optional): default dist spec of params. Defaults to None.
+
+    Raises:
+        RuntimeError: raised if any parameter is still not a ColoTensor after conversion.
+    """
+    torch_params = []
+    for n, p in model.named_parameters():
+        if not isinstance(p, ColoParameter):
+            print(f"{n} is not a ColoParameter. We are going to convert it to ColoParameter")
+            torch_params.append((n, p))
+
+    for (n, param) in torch_params:
+        delattr(model, n)
+        setattr(model, n, _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec))
+
+    del torch_params
+    for n, p in model.named_parameters():
+        if not isinstance(p, ColoTensor):
+            raise RuntimeError
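
Taken together, the two additions form a two-step initialization recipe, which the updated tests below follow: build the model under `ColoInitContext`, then call `post_process_colo_init_ctx` to sweep up any `torch.nn.Parameter` the context missed and route it through `_convert_to_coloparam`. Below is a minimal sketch of that recipe; the `Net` module is hypothetical, written to hold a bare `nn.Parameter` on the root module the way a "no leaf modules" model would:

import torch
import torch.nn as nn

from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx


class Net(nn.Module):
    """Hypothetical model with no leaf modules: its weight is a bare
    nn.Parameter on the root module, not owned by a leaf such as nn.Linear."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.weight


init_dev = get_current_device()
with ColoInitContext(device=init_dev):
    model = Net()

# Convert any parameters ColoInitContext did not reach; this raises a
# RuntimeError if something is still not a ColoTensor afterwards.
post_process_colo_init_ctx(model, device=init_dev)
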
tests/test_gemini/update/test_optim.py (view file @ 31c64402)

@@ -15,10 +15,11 @@ from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
 from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed
...
@@ -40,8 +41,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
 # 'gpt2', 'bert',
-TEST_MODELS = ['gpt2', 'bert']
-EXAMPLE_MODELS = ['simple_net']
+TEST_MODELS = ['no_leaf_module', 'gpt2', 'bert', 'simple_net', 'nested_model', 'repeated_computed_layers']


 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
...
@@ -57,8 +57,12 @@ def exam_model_step(placement_policy, model_name: str):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

-    with ColoInitContext(device=get_current_device()):
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
         model = model_builder()

+    post_process_colo_init_ctx(model, device=init_dev)
+
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
...
@@ -99,7 +103,7 @@ def exam_model_step(placement_policy, model_name: str):
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', EXAMPLE_MODELS)
+@parameterize('model_name', TEST_MODELS)
 def exam_tiny_example(placement_policy, model_name: str):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
...
@@ -111,8 +115,12 @@ def exam_tiny_example(placement_policy, model_name: str):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

-    with ColoInitContext(device=get_current_device()):
+    init_dev = get_current_device()
+    with ColoInitContext(device=init_dev):
         model = model_builder()

+    post_process_colo_init_ctx(model, device=init_dev)
+
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
...
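
Both test hunks apply the same change: hoist `get_current_device()` into `init_dev`, build the model under `ColoInitContext`, then call `post_process_colo_init_ctx(model, device=init_dev)` so any straggler parameters are converted on the same device the context used. The widened `TEST_MODELS` list, now led by `'no_leaf_module'`, is what actually exercises the bug this hotfix targets. The invariant the post-processing step establishes can be spelled out as a short check; this is a sketch only, reusing the `model` from either test, and the error message is illustrative rather than the test's own:

from colossalai.tensor import ColoTensor

# After post_process_colo_init_ctx, every parameter must be a ColoTensor
# (post_process_colo_init_ctx itself raises a bare RuntimeError otherwise).
for name, param in model.named_parameters():
    if not isinstance(param, ColoTensor):
        raise RuntimeError(f"{name} is still a plain torch.nn.Parameter")
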