OpenDAS / ColossalAI · Commits

Unverified commit eb1b8990, authored Apr 21, 2022 by Jiarui Fang, committed by GitHub on Apr 21, 2022
[refactor] moving InsertPostInitMethodToModuleSubClasses to utils. (#824)
Parent: 2ecc3d7a

Showing 3 changed files, with 85 additions and 77 deletions:
- colossalai/utils/__init__.py (+3 / -1)
- colossalai/utils/model/init_context.py (+81 / -0)
- colossalai/zero/init_ctx/init_context.py (+1 / -76)
colossalai/utils/__init__.py (view file @ eb1b8990)
@@ -11,6 +11,7 @@ from .memory import (report_memory_usage, colo_device_memory_used, colo_set_proc
                     colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
 from .timer import MultiTimer, Timer
 from .tensor_detector import TensorDetector
+from .model.init_context import InsertPostInitMethodToModuleSubClasses
 
 __all__ = [
     'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
@@ -20,5 +21,6 @@ __all__ = [
     'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
     'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
     'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity'
+    'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity',
+    'InsertPostInitMethodToModuleSubClasses'
 ]
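With this re-export, the context manager becomes part of the public colossalai.utils API. A minimal sketch of the two equivalent import paths after this commit:

# Both paths resolve to the same class once the re-export above lands:
from colossalai.utils import InsertPostInitMethodToModuleSubClasses
# equivalent to importing from the module directly:
# from colossalai.utils.model.init_context import InsertPostInitMethodToModuleSubClasses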
colossalai/utils/model/init_context.py (new file, mode 100644; view file @ eb1b8990)
import torch
import functools
from typing import Optional


def _substitute_init_recursively(cls, func):
    for subcls in cls.__subclasses__():
        _substitute_init_recursively(subcls, func)
        func(subcls)


class InsertPostInitMethodToModuleSubClasses(object):

    def __init__(self, default_dtype: Optional[torch.dtype] = None):
        self._old_default_dtype = None
        self._default_dtype = default_dtype

    def __enter__(self):
        r"""
        Enter the context scope.
        """
        if self._default_dtype is not None:
            self._old_default_dtype = torch.get_default_dtype()
            torch.set_default_dtype(self._default_dtype)

        def preprocess_after(f):

            @functools.wraps(f)
            def wrapper(module: torch.nn.Module, *args, **kwargs):
                f(module, *args, **kwargs)
                self._post_init_method(module)

            return wrapper

        def _enable_class(cls):
            cls._old_init = cls.__init__
            cls.__init__ = preprocess_after(cls.__init__)

        # The function is called during init subclass.
        def _init_subclass(cls, **kwargs):
            cls.__init__ = preprocess_after(cls.__init__)

        # Replace .__init__() for all existing subclasses of torch.nn.Module.
        # Execute self._post_init_method after the default init function.
        _substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)

        # Hold on to the current __init_subclass__ for exit.
        torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
        # Replace .__init__() for future subclasses of torch.nn.Module.
        torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)

        self._pre_context_exec()

    def __exit__(self, exc_type, exc_value, traceback):
        if self._default_dtype is not None:
            torch.set_default_dtype(self._old_default_dtype)

        def _disable_class(cls):
            cls.__init__ = cls._old_init

        # Restore .__init__() for all existing subclasses of torch.nn.Module.
        _substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)

        # Restore .__init_subclass__ for future subclasses of torch.nn.Module.
        torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)

        self._post_context_exec()
        # Now that we cleaned up the metaclass injection, propagate the exception.
        if exc_type is not None:
            return False

    # To be implemented by inheriting classes
    def _post_init_method(self, module):
        pass

    def _pre_context_exec(self):
        pass

    def _post_context_exec(self):
        pass
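Entering the context wraps __init__ on every existing and future torch.nn.Module subclass so that _post_init_method runs right after each module is constructed; exiting restores the originals. A minimal usage sketch follows; the LoggingInitContext subclass and its print call are illustrative only, not part of this commit:

import torch
from colossalai.utils import InsertPostInitMethodToModuleSubClasses


class LoggingInitContext(InsertPostInitMethodToModuleSubClasses):
    """Illustrative subclass: observe every nn.Module built in the scope."""

    def _post_init_method(self, module):
        # Called right after the module's original __init__ returns.
        print(f'initialized {module.__class__.__name__}')


with LoggingInitContext():
    # Linear, ReLU, and Sequential are all nn.Module subclasses, so each
    # constructor here triggers _post_init_method once it finishes.
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

# Outside the context, nn.Module construction behaves as usual again.

The optional default_dtype argument, when given, sets torch's default floating-point dtype for the duration of the scope and restores the previous one on exit.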
colossalai/zero/init_ctx/init_context.py (view file @ eb1b8990)
@@ -13,82 +13,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from contextlib import AbstractContextManager
+from colossalai.utils import InsertPostInitMethodToModuleSubClasses
 
 class ZeroContextConfig(object):

(The 76 deleted lines in this hunk are the _substitute_init_recursively helper and the InsertPostInitMethodToModuleSubClasses class, removed from this file verbatim; they now live in colossalai/utils/model/init_context.py, shown above.)