Unverified Commit cc0ed7cf authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[Gemini] ZeROHookV2 -> GeminiZeROHook (#1972)

parent f8a7148d
colossalai/nn/parallel/data_parallel.py
@@ -14,7 +14,7 @@
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.utils import get_current_device
-from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
+from colossalai.zero.utils.gemini_hook import GeminiZeROHook
 from .reducer import Reducer
@@ -210,7 +210,7 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = ZeROHookV2(gemini_manager)
+        self.param_op_hook = GeminiZeROHook(gemini_manager)
         self.fp32_params: List[ColoTensor] = []
         self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
...
colossalai/zero/utils/zero_hook_v2.py -> colossalai/zero/utils/gemini_hook.py
@@ -1,10 +1,12 @@
-import torch
-from colossalai.tensor.param_op_hook import ParamOpHook
-from colossalai.gemini import TensorState
-from enum import Enum
-from typing import List
 from contextlib import contextmanager
+from enum import Enum
 from functools import partial
+from typing import List
+
+import torch
+
+from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.tensor.param_op_hook import ParamOpHook


 class TrainingPhase(Enum):
@@ -13,7 +15,7 @@ class TrainingPhase(Enum):
     BACKWARD = 1


-class ZeROHookV2(ParamOpHook):
+class GeminiZeROHook(ParamOpHook):

     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
...
docs/colossalai/colossalai.zero.utils.rst
@@ -9,4 +9,4 @@ colossalai.zero.utils
    :maxdepth: 2

    colossalai.zero.utils.zero_hook
-   colossalai.zero.utils.zero_hook_v2
+   colossalai.zero.utils.gemini_hook

docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst
@@ -1,5 +1,5 @@
 colossalai.zero.utils.zero\_hook\_v2
 ====================================

-.. automodule:: colossalai.zero.utils.zero_hook_v2
+.. automodule:: colossalai.zero.utils.gemini_hook
    :members:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment