OpenDAS / ColossalAI · Commits

Commit cc0ed7cf (Unverified)
[Gemini] ZeROHookV2 -> GeminiZeROHook (#1972)
Authored Nov 17, 2022 by Jiarui Fang; committed by GitHub, Nov 17, 2022.
Parent: f8a7148d
Showing 4 changed files with 12 additions and 10 deletions (+12 -10).
colossalai/nn/parallel/data_parallel.py (+2 -2)
colossalai/zero/utils/gemini_hook.py (+8 -6)
docs/colossalai/colossalai.zero.utils.rst (+1 -1)
docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst (+1 -1)
colossalai/nn/parallel/data_parallel.py

@@ -14,7 +14,7 @@ from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.utils import get_current_device
-from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
+from colossalai.zero.utils.gemini_hook import GeminiZeROHook
 from .reducer import Reducer

@@ -210,7 +210,7 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = ZeROHookV2(gemini_manager)
+        self.param_op_hook = GeminiZeROHook(gemini_manager)
         self.fp32_params: List[ColoTensor] = []
         self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
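For downstream code that imported the hook under its old name, the rename is a one-line substitution: the constructor argument is unchanged, only the module and class name differ. A minimal migration sketch (gemini_manager here stands for an existing GeminiManager instance, as in the ZeroDDP constructor above):

    # Before this commit:
    from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
    param_op_hook = ZeROHookV2(gemini_manager)

    # After this commit: same argument, new module and class name.
    from colossalai.zero.utils.gemini_hook import GeminiZeROHook
    param_op_hook = GeminiZeROHook(gemini_manager)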
colossalai/zero/utils/zero_hook_v2.py → colossalai/zero/utils/gemini_hook.py

-import torch
-from colossalai.tensor.param_op_hook import ParamOpHook
-from colossalai.gemini import TensorState
-from enum import Enum
-from typing import List
 from contextlib import contextmanager
+from enum import Enum
 from functools import partial
+from typing import List
+import torch
+from colossalai.gemini import TensorState
+from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.tensor.param_op_hook import ParamOpHook

 class TrainingPhase(Enum):

@@ -13,7 +15,7 @@ class TrainingPhase(Enum):
     BACKWARD = 1

-class ZeROHookV2(ParamOpHook):
+class GeminiZeROHook(ParamOpHook):

     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
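The import list of the renamed file hints at the hook's shape: it fires callbacks around operations that touch managed parameters, and it uses a context manager (plus functools.partial) to switch the TrainingPhase temporarily around the backward pass. Below is a self-contained sketch of that pattern under stated assumptions, not ColossalAI's actual implementation: the class body, the pre_op/post_op names, and FORWARD = 0 are illustrative guesses (the diff shows only BACKWARD = 1 and the __init__ signature), and the real GeminiZeROHook subclasses ParamOpHook and delegates to GeminiManager.

    from contextlib import contextmanager
    from enum import Enum
    from functools import partial
    from typing import List

    import torch


    class TrainingPhase(Enum):
        FORWARD = 0    # assumption: the diff only shows BACKWARD = 1
        BACKWARD = 1


    class SketchGeminiZeROHook:
        """Stand-in for GeminiZeROHook(ParamOpHook): callbacks around param ops."""

        def __init__(self, gemini_manager) -> None:
            # The real hook stores a GeminiManager and drives its chunk manager.
            self._gemini_manager = gemini_manager
            self._training_phase = TrainingPhase.FORWARD

        def pre_op(self, params: List[torch.Tensor]) -> None:
            # A Gemini-style hook would fetch the chunks holding these params
            # into device memory before the op runs (hypothetical behavior).
            print(f"pre  {self._training_phase.name}: fetch {len(params)} param(s)")

        def post_op(self, params: List[torch.Tensor]) -> None:
            # ...and release or offload them afterwards (hypothetical behavior).
            print(f"post {self._training_phase.name}: release {len(params)} param(s)")

        @contextmanager
        def switch_training_phase(self, phase: TrainingPhase = TrainingPhase.BACKWARD):
            # contextmanager in the real import list suggests exactly this kind
            # of temporary phase switch around the backward pass.
            prev, self._training_phase = self._training_phase, phase
            try:
                yield
            finally:
                self._training_phase = prev


    # Usage: partial (also in the real imports) can pre-bind a phase.
    hook = SketchGeminiZeROHook(gemini_manager=None)
    enter_backward = partial(hook.switch_training_phase, TrainingPhase.BACKWARD)
    with enter_backward():
        hook.pre_op([torch.zeros(2, 2)])
        hook.post_op([torch.zeros(2, 2)])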
docs/colossalai/colossalai.zero.utils.rst

@@ -9,4 +9,4 @@ colossalai.zero.utils
    :maxdepth: 2

    colossalai.zero.utils.zero_hook
-   colossalai.zero.utils.zero_hook_v2
+   colossalai.zero.utils.gemini_hook
docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst

 colossalai.zero.utils.zero\_hook\_v2
 ====================================

-.. automodule:: colossalai.zero.utils.zero_hook_v2
+.. automodule:: colossalai.zero.utils.gemini_hook
    :members: