Unverified commit 26b7aac0, authored by ver217, committed by GitHub

[zero] reorganize zero/gemini folder structure (#3424)

* [zero] refactor low-level zero folder structure

* [zero] fix legacy zero import path

* [zero] fix legacy zero import path

* [zero] remove useless import

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] fix test import path

* [zero] fix test

* [zero] fix circular import

* [zero] update import
parent b09adff7
@@ -14,17 +14,16 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 import colossalai
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
-from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
-logger = get_dist_logger(__name__)
+from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero.gemini.utils import get_static_torch_model
 from .base import Strategy
 from .ddp import DDPStrategy
+logger = get_dist_logger(__name__)
 class ColossalAIStrategy(DDPStrategy):
     """
...
@@ -4,8 +4,8 @@ from typing import Optional, Set
 import torch
 import torch.nn as nn
-from colossalai.gemini.tensor_utils import free_storage
 from colossalai.nn.parallel.data_parallel import _cast_float
+from colossalai.zero.legacy.gemini.tensor_utils import free_storage
 from .region_manager import RegionManager
 from .util import GlobalRuntimeInfo
...
-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple
 import torch
 from torch.fx import Node
-from colossalai.gemini.tensor_utils import alloc_storage, free_storage
+from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
 class Region:
     """
@@ -52,15 +55,13 @@ class Region:
         Map the parameters in the region to a contiguous memory space.
         """
-        self.fp16_data = torch.zeros(
-            self.param_num, dtype=torch.half, device='cuda')
+        self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
         offset = 0
         for param in self.fp16_params:
             param.data = param.data.cuda()
             p_num = param.data.numel()
             self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
-            param.data = self.fp16_data[offset:offset +
-                                        p_num].view(param.data.shape)
+            param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
             self.param_to_range[param] = (offset, offset + p_num)
             offset += p_num
@@ -141,4 +142,4 @@ class Region:
     def __update_params_ptr(self) -> None:
         for param in self.fp16_params:
             begin, end = self.param_to_range[param]
-            param.data = self.fp16_data[begin:end].view(param.data.shape)
\ No newline at end of file
+            param.data = self.fp16_data[begin:end].view(param.data.shape)
@@ -14,12 +14,12 @@ from torch.utils.data.distributed import DistributedSampler
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
-from colossalai.gemini.memory_tracer import MemStats
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.utils import get_current_device
-from colossalai.utils.model.colo_init_context import _convert_to_coloparam
+from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
+from colossalai.zero.gemini.memory_tracer import MemStats
 from .plugin_base import Plugin
...
@@ -10,8 +10,8 @@ from torch.nn.modules.loss import _Loss
 from colossalai.engine.gradient_handler import BaseGradientHandler
 from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
-from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively
 from colossalai.logging import get_dist_logger
+from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
 class Engine:
...
@@ -157,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
         return self._move_to_device(mciro_batch_data)

     def pre_processing(self, engine):
-        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+        from colossalai.zero.legacy import ShardedModelV2
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
...
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .gemini_mgr import GeminiManager
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory

__all__ = [
    'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager',
    'search_chunk_configuration'
]
@@ -29,13 +29,12 @@ from colossalai.engine.schedule import (
     PipelineSchedule,
     get_tensor_shape,
 )
-from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
 from colossalai.utils.moe import sync_moe_model_param
-from colossalai.zero import convert_to_zero_v2
-from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2
+from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
 def get_default_parser():
...
@@ -9,7 +9,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
-from colossalai.zero.init_ctx import no_shard_zero_decrator
+from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator
 class MoeExperts(nn.Module):
...
@@ -18,7 +18,7 @@ from colossalai.nn.layer.moe.experts import Experts, MoeExperts
 from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
 from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
 from colossalai.utils import get_current_device
-from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
+from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 @no_shard_zero_decrator(is_replicated=True)
...
from typing import Any

import torch

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer

__all__ = ['GeminiAdamOptimizer']


class GeminiAdamOptimizer(ZeroOptimizer):

    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
        optimizer = HybridAdam(model.parameters(), **defaults)
        super().__init__(optimizer, model, **defaults)
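For orientation, a minimal usage sketch of the GeminiAdamOptimizer shown above. The toy module and learning rate are illustrative assumptions, not part of this diff; in practice the model is expected to already be wrapped by GeminiDDP/ZeroDDP before the optimizer is constructed.

import torch.nn as nn

# hypothetical module; real usage wraps the model with GeminiDDP/ZeroDDP first
model = nn.Linear(16, 16)

# builds a HybridAdam over model.parameters() and hands both the optimizer
# and the model to ZeroOptimizer.__init__, exactly as in the file above
optimizer = GeminiAdamOptimizer(model, lr=1e-3)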
-from .data_parallel import ColoDDP, ZeroDDP
-from .gemini_parallel import GeminiDDP
-from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper
+from .data_parallel import ColoDDP
-__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper']
+__all__ = [
+    'ColoDDP',
+]
This diff is collapsed.
from typing import Optional

import torch

from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import MemStats

from .data_parallel import ZeroDDP


class GeminiDDP(ZeroDDP):

    def __init__(self,
                 module: torch.nn.Module,
                 device: torch.device,
                 placement_policy: str = "cpu",
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
                 strict_ddp_mode: bool = False,
                 search_range_mb: int = 32,
                 hidden_dim: Optional[int] = None,
                 min_chunk_size_mb: float = 32,
                 memstats: Optional[MemStats] = None) -> None:
        """
        A torch.nn.Module wrapper using ZeRO-DP and Gemini.
        ZeRO is for parallelism. Gemini is for memory management.
        WARNING: The class will modify the module in-place!

        Example:
            model is initialized under the context of ColoInitContext
            >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
            >>> logits = model(x)
            >>> loss = criterion(logits, labels)
            >>> model.backward(loss)

        Args:
            module (torch.nn.Module): the model to be wrapped.
            device (torch.device): device to place the model.
            placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
            pin_memory (bool, optional): use pinned memory on the CPU. Defaults to False.
            force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
            search_range_mb (int, optional): chunk size searching range in MegaBytes. Defaults to 32.
            hidden_dim (int, optional): the hidden dimension of the DNN.
                Users can provide this argument to speed up searching.
                If users do not know this argument before training, it is ok. We will use a default value 1024.
            min_chunk_size_mb (float, optional): the minimum chunk size in MegaBytes.
                If the aggregate size of parameters is still smaller than the minimum chunk size,
                all parameters will be compacted into one small chunk.
            memstats (MemStats, optional): memory statistics collected by a runtime memory tracer.
        """
        # some ugly hotfix for the compatibility with Lightning
        if search_range_mb is None:
            search_range_mb = 32

        chunk_manager = init_chunk_manager(model=module,
                                           init_device=device,
                                           hidden_dim=hidden_dim,
                                           search_range_mb=search_range_mb,
                                           min_chunk_size_mb=min_chunk_size_mb,
                                           strict_ddp_flag=strict_ddp_mode)
        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
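The docstring above already shows the forward/backward calls; the sketch below simply puts GeminiDDP together with ColoInitContext and GeminiAdamOptimizer using the import paths introduced by this commit. The distributed launch, toy model, input tensor, and loss are placeholder assumptions for illustration only.

import torch
import torch.nn as nn

from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiDDP, GeminiAdamOptimizer

# assumes colossalai.launch(...) has already set up the distributed runtime
with ColoInitContext(device=get_current_device()):
    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

# ZeRO-DP handles parallelism, Gemini handles parameter placement
model = GeminiDDP(model, get_current_device(), placement_policy="auto")
optimizer = GeminiAdamOptimizer(model, lr=1e-3)

logits = model(torch.randn(8, 32, device=get_current_device()))
loss = logits.sum()    # placeholder loss
model.backward(loss)   # backward goes through the ZeroDDP wrapper
optimizer.step()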
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-
-from colossalai.logging import get_dist_logger
-from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
-from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
-
-from ..nn.optimizer.zero_optimizer import ZeroOptimizer
-
-
-def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
-                       optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
-    """
-    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
-
-    :param model: Your model object
-    :type model: :class:`torch.nn.Module`
-    :param optimizer_config: Your optimizer object
-    :type optimizer_config: :class:`dict`
-
-    :return: (model, optimizer)
-    :rtype: Tuple
-    """
-    logger = get_dist_logger('convert_to_zero_v2')
-
-    logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
-    if optimizer_config is None:
-        optimizer_config = dict()
-    logger.info(f'model_config is {model_config}', ranks=[0])
-    if model_config is None:
-        model_config = dict()
-
-    zero_model = ShardedModelV2(model, **model_config)
-    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
-    return zero_model, zero_optimizer
-
-__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
+from .gemini import (
+    ColoInitContext,
+    GeminiAdamOptimizer,
+    GeminiDDP,
+    ZeroDDP,
+    ZeroOptimizer,
+    get_static_torch_model,
+    post_process_colo_init_ctx,
+)
+from .low_level import LowLevelZeroOptimizer
+from .wrapper import zero_model_wrapper, zero_optim_wrapper
+
+__all__ = [
+    'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
+    'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
+]
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .gemini_ddp import GeminiDDP, ZeroDDP
from .gemini_mgr import GeminiManager
from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
from .utils import get_static_torch_model

__all__ = [
    'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
    'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
]
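Taken together, the hunks in this commit amount to the following import-path migration; this before/after sketch only restates paths that appear in the diff (the symbols themselves are unchanged).

# before this commit
#   from colossalai.nn.parallel import GeminiDDP, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
#   from colossalai.gemini.memory_tracer import MemStats
#   from colossalai.utils.model.colo_init_context import ColoInitContext
#   from colossalai.zero.init_ctx import no_shard_zero_decrator    # old ZeRO, now "legacy"

# after this commit
from colossalai.zero import ColoInitContext, GeminiDDP, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.memory_tracer import MemStats
from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator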
@@ -3,10 +3,11 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
 import torch
-from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
 from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device
+from .chunk import Chunk, ChunkFullError, TensorState
 class ChunkManager:
     """
...
@@ -5,9 +5,9 @@ import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
 from colossalai.tensor import ColoParameter
 from colossalai.utils import is_ddp_ignored
+from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
 def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
...