Unverified Commit 26b7aac0 authored by ver217, committed by GitHub

[zero] reorganize zero/gemini folder structure (#3424)

* [zero] refactor low-level zero folder structure

* [zero] fix legacy zero import path

* [zero] fix legacy zero import path

* [zero] remove useless import

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] fix test import path

* [zero] fix test

* [zero] fix circular import

* [zero] update import
parent b09adff7
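
Taken together, the hunks below amount to an import-path migration: modules that previously lived under `colossalai.gemini` now resolve under `colossalai.zero.gemini`, and the old sharded ZeRO stack appears to be re-homed under `colossalai.zero.legacy`. A before/after sketch assembled from paths that appear in this diff (relative imports expanded to their absolute form):

```python
# Before this commit
from colossalai.gemini import TensorState
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import MemStats

# After this commit
from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.chunk import Chunk, ChunkManager
from colossalai.zero.gemini.memory_tracer import MemStats
```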
@@ -5,10 +5,11 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.gemini.chunk import ChunkManager
-from colossalai.gemini.chunk.search_utils import search_chunk_configuration
 from colossalai.utils import is_ddp_ignored
+from .manager import ChunkManager
+from .search_utils import search_chunk_configuration
 def safe_div(a, b):
     if a == 0:
......
@@ -3,10 +3,8 @@ from typing import Any, Dict, Iterator, Optional, Tuple, Union
 import torch
 from torch import nn
-from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
 from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
-from .utils import InsertPostInitMethodToModuleSubClasses
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 # find named_params includes replica
@@ -89,6 +87,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._default_dist_spec = default_dist_spec
     def _register_colo_modules(self):
+        from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
         register_colo_module(torch.nn.Linear, ColoLinear())
         register_colo_module(torch.nn.Embedding, ColoEmbedding())
......
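
One of the commit bullets is "[zero] fix circular import", and the hunk above shows the pattern used: the `colossalai.nn.parallel.layers` import moves from module level into `_register_colo_modules`, so it is resolved at call time rather than at import time. A generic sketch of this deferred-import technique (module and function names here are illustrative, not from the repo):

```python
# module_a.py -- illustrative names, not actual ColossalAI modules
def register_modules():
    # Deferred import: module_b is resolved when this function runs,
    # not when module_a is first imported, breaking an a <-> b import cycle.
    from module_b import register_colo_module
    register_colo_module()
```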
This diff is collapsed.
@@ -5,10 +5,10 @@ from typing import List
 import torch
-from colossalai.gemini import TensorState
-from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
+from colossalai.zero.gemini import TensorState
+from colossalai.zero.gemini.gemini_mgr import GeminiManager
 class TrainingPhase(Enum):
......
@@ -4,10 +4,8 @@ from typing import List, Optional, Tuple
 import torch
-from colossalai.gemini.chunk import Chunk, ChunkManager
-from colossalai.gemini.memory_tracer import MemStats
-from .memory_tracer import ChunkMemStatsCollector
+from .chunk import Chunk, ChunkManager
+from .memory_tracer import ChunkMemStatsCollector, MemStats
 from .placement_policy import PlacementPolicyFactory
......
@@ -10,12 +10,15 @@ from torch.nn import Parameter
 from torch.optim import Optimizer
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.gemini.chunk import Chunk, ChunkManager
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
-from colossalai.nn.parallel.data_parallel import ZeroDDP
 from colossalai.utils import disposable, get_current_device, is_ddp_ignored
+from .chunk import Chunk, ChunkManager
+from .gemini_ddp import ZeroDDP
 __all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
 _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
@@ -316,3 +319,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
             fake_params_list.append(fake_param)
         group['params'] = fake_params_list
+
+
+class GeminiAdamOptimizer(ZeroOptimizer):
+
+    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
+        optimizer = HybridAdam(model.parameters(), **defaults)
+        super().__init__(optimizer, model, **defaults)
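
The second hunk adds `GeminiAdamOptimizer`, a thin convenience wrapper that constructs a `HybridAdam` over the model's parameters and hands it to `ZeroOptimizer`. A hedged usage sketch; the public import path and the training-loop calls are assumptions based on `ColossalaiOptimizer`'s usual interface, not shown in this diff:

```python
# Import path assumed from this reorganization.
from colossalai.zero.gemini import GeminiAdamOptimizer

# `zero_model` stands for a module already wrapped by ZeroDDP/Gemini;
# ZeroOptimizer expects such a wrapped model, a plain nn.Module will not do.
optimizer = GeminiAdamOptimizer(zero_model, lr=1e-3)

loss = zero_model(batch).sum()
optimizer.backward(loss)    # loss scaling is handled inside the optimizer
optimizer.step()
optimizer.zero_grad()
```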
 from typing import Optional
-from colossalai.gemini.chunk import ChunkManager
-from colossalai.gemini.memory_tracer import MemStats
 from colossalai.utils import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
+from colossalai.zero.gemini.chunk import ChunkManager
+from .memory_stats import MemStats
 from .memstats_collector import MemStatsCollector
......
@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional
 import torch
-from colossalai.gemini.memory_tracer import OrderedParamGenerator
+from .param_runtime_order import OrderedParamGenerator
 class MemStats(object):
......
 import time
-from typing import List, Optional
-import torch
-from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.gemini.stateful_tensor import StatefulTensor
-from colossalai.utils.memory import colo_device_memory_used
+from typing import Optional
+from .memory_monitor import SyncCudaMemoryMonitor
+from .memory_stats import MemStats
@@ -49,7 +44,7 @@ class MemStatsCollector:
         assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
         assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
             f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
-            f"step total {self._step_total}"
+            f"step total {self._step_total}"
         next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
         self._step_idx = (self._step_idx + 1) % self._step_total
         return next_non_model_data
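
The `_step_idx` arithmetic above implements a ring buffer: non-model memory statistics collected during the warmup iteration are replayed period by period, wrapping back to the start at the end of each iteration. A toy sketch of just that indexing scheme (the list values are illustrative, not the collector's real API):

```python
# Illustrative values: non-model CUDA data per period, recorded in one warmup iteration.
stats = [512, 1024, 768]
step_idx, step_total = 0, len(stats)

def next_period_non_model_data() -> int:
    global step_idx
    value = stats[step_idx]
    step_idx = (step_idx + 1) % step_total    # wrap around after the last period
    return value

# Replaying two full iterations cycles through the same recorded stats.
assert [next_period_non_model_data() for _ in range(6)] == [512, 1024, 768, 512, 1024, 768]
```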
@@ -75,6 +70,8 @@ class MemStatsCollector:
         Sampling model data statistics.
         """
         if self._start_flag and not self.use_outside_memstats:
+            from colossalai.zero.legacy.gemini import StatefulTensor
+
             # The following code work for ZeroInitContext, which is deprecated in v0.1.12
             cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
             self._memstats.record_max_cuda_model_data(cuda_mem)
......
 import torch.nn
-from colossalai.gemini.memory_tracer import MemStats
-from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
+from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import (
+    GradMemStats,
+    GradMemTracerHook,
+    ParamMemTracerHook,
+)
+from .memory_stats import MemStats
 __all__ = ['RuntimeMemTracer']
......
@@ -6,7 +6,7 @@ from torch.fx import symbolic_trace
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
-from colossalai.gemini.chunk import ChunkManager
+from colossalai.zero.gemini.chunk import ChunkManager
 if is_compatible_with_meta():
     from colossalai.fx.profiler import MetaTensor
......
@@ -5,11 +5,12 @@ from typing import Dict, List, Optional, Tuple, Type
 import torch
-from colossalai.gemini.chunk import Chunk, ChunkManager
-from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
 from colossalai.utils import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
+from .chunk import Chunk, ChunkManager
+from .memory_tracer import ChunkMemStatsCollector
 class PlacementPolicy(ABC):
     need_mem_stats: bool = False
......
@@ -6,9 +6,10 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.gemini.chunk import Chunk
 from colossalai.utils import get_current_device
+from .chunk import Chunk
 def get_temp_total_chunk_on_cuda(chunk: Chunk):
     if chunk.is_gathered:
@@ -77,7 +78,7 @@ def get_static_torch_model(zero_ddp_model,
     Returns:
         torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
     """
-    from colossalai.nn.parallel import ZeroDDP
+    from colossalai.zero.gemini.gemini_ddp import ZeroDDP
     assert isinstance(zero_ddp_model, ZeroDDP)
     state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
......
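
Per its docstring, `get_static_torch_model` rebuilds a plain `torch.nn.Module` from a ZeroDDP-sharded model for checkpointing or numeric checks. A hedged sketch of the intended flow; only the `zero_ddp_model` and `only_rank_0` names are confirmed by this diff, the import path and everything else are assumptions:

```python
import torch

# Path assumed from this reorganization (the file edited above).
from colossalai.zero.gemini.utils import get_static_torch_model

# `zero_ddp_model` stands for a model already wrapped by ZeroDDP/Gemini.
torch_model = get_static_torch_model(zero_ddp_model, only_rank_0=True)
if torch.distributed.get_rank() == 0:
    # The result is an ordinary torch module, so the usual checkpointing applies.
    torch.save(torch_model.state_dict(), 'checkpoint.pt')
```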
from typing import Tuple

import torch
import torch.nn as nn

from colossalai.logging import get_dist_logger

from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
from .sharded_model import ShardedModelV2
from .sharded_optim import ShardedOptimizerV2


def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
                       optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
    """
    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading

    :param model: Your model object
    :type model: :class:`torch.nn.Module`
    :param optimizer_config: Your optimizer object
    :type optimizer_config: :class:`dict`

    :return: (model, optimizer)
    :rtype: Tuple
    """
    logger = get_dist_logger('convert_to_zero_v2')

    logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
    if optimizer_config is None:
        optimizer_config = dict()
    logger.info(f'model_config is {model_config}', ranks=[0])
    if model_config is None:
        model_config = dict()

    zero_model = ShardedModelV2(model, **model_config)
    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
    return zero_model, zero_optimizer


__all__ = [
    'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context',
    'no_shard_zero_decrator'
]
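
`convert_to_zero_v2` wires a model and optimizer into the legacy sharded ZeRO stack; both config arguments may be `None`, in which case the helper substitutes empty dicts. A hedged usage sketch; the `colossalai.zero.legacy` import path follows this commit's move, while the launch step and config contents are assumptions:

```python
import torch
import torch.nn as nn

from colossalai.zero.legacy import convert_to_zero_v2  # path assumed from this reorg

# Assumes the distributed environment has already been initialized,
# e.g. via colossalai.launch, before this point.
model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# None configs are replaced with empty dicts inside the helper.
zero_model, zero_optimizer = convert_to_zero_v2(model, optimizer,
                                                model_config=None,
                                                optimizer_config=None)
```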
from .ophooks import BaseOpHook, register_ophooks_recursively
from .stateful_tensor import StatefulTensor
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy
__all__ = [
    'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy',
    'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook'
]