Unverified commit fae6c92e, authored by Hongxin Liu, committed by GitHub

Merge branch 'main' into feature/shardformer

parents bd186784 ac178ca5
 class Registry:
-    # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
+    # TODO: refactor the registry classes used in colossalai.legacy.registry, colossalai.fx and here
     def __init__(self, name):
         self.name = name
......
@@ -3,6 +3,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
+from types import MethodType
 from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
     unwrap_optimizer,
 )
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer

 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']

+
+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+    def __init__(self, module: nn.Module, precision: str) -> None:
+        super().__init__(module)
+        self.dtype = None
+        if precision == 'fp16':
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        if self.dtype is not None:
+            module = module.to(self.dtype)
+        module = module.to(get_current_device())
+        self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
+    def unwrap(self):
+        # TODO(ver217): this is a workaround for loading model
+        return self
+

 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
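The LowLevelZeroModel added in the hunk above now takes only (module, precision): it casts the wrapped module to fp16/bf16, moves it to the current device, and converts floating-point inputs on every forward call via tree_map. The helper _convert_floating_point is referenced but not shown in this hunk, so the sketch below assumes a minimal cast-floats-only behaviour purely to illustrate the conversion path.

# Hedged sketch: the body of _convert_floating_point is an assumption (the diff only
# references it); tree_map is torch's pytree helper, as used by the plugin.
from functools import partial

import torch
from torch.utils._pytree import tree_map


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # assumed behaviour: cast floating-point tensors, pass everything else through unchanged
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


convert_fn = partial(_convert_floating_point, dtype=torch.bfloat16)
batch = {'input_ids': torch.randint(0, 100, (2, 8)), 'pixel_values': torch.rand(2, 3, 224, 224)}
batch = tree_map(convert_fn, batch)    # only 'pixel_values' is cast to bfloat16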
@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         sharded_optimizer_loading_epilogue(optimizer)

+    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+                             use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)

-class LowLevelZeroModel(ModelWrapper):
-
-    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
-        super().__init__(module)
-        self.dtype = None
-        if precision == 'fp16':
-            self.dtype = torch.float16
-        elif precision == 'bf16':
-            self.dtype = torch.bfloat16
-        module = zero_model_wrapper(module, zero_stage=stage)
-        if self.dtype is not None:
-            module = module.to(self.dtype)
-        module = module.to(get_current_device())
-        self.module = module
-        self.convert_fn = None
-        if self.dtype is not None:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-
-    def forward(self, *args, **kwargs):
-        if self.convert_fn is not None:
-            args = tree_map(self.convert_fn, args)
-            kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+                                   use_safetensors)
+
+    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_unsharded_model(model.module, checkpoint, strict)
+        model.update_master_params()
+
+    def load_sharded_model(self,
+                           model: LowLevelZeroModel,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()


 class LowLevelZeroPlugin(DPPluginBase):
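The checkpoint hooks added in the hunk above route saves and loads to the inner module and, after loading, call model.update_master_params() so the optimizer's fp32 master copies match the freshly loaded fp16/bf16 weights. A hedged round-trip sketch through the Booster checkpoint API follows; file names and objects are placeholders, not taken from the diff.

# Hedged sketch: assumes `booster`, `model` and `optimizer` come from a prior
# booster.boost(...) call with LowLevelZeroPlugin; paths are placeholders.
booster.save_model(model, 'model.pt')              # -> LowLevelZeroCheckpointIO.save_unsharded_model
booster.save_optimizer(optimizer, 'optimizer.pt')

booster.load_model(model, 'model.pt')              # -> load_unsharded_model, which then calls
                                                   #    model.update_master_params() to refresh fp32 masters
booster.load_optimizer(optimizer, 'optimizer.pt')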
@@ -248,22 +283,24 @@
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
+        assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
         self.stage = stage
         self.precision = precision
-        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
-                                      communication_dtype=communication_dtype,
-                                      overlap_communication=overlap_communication,
-                                      cpu_offload=cpu_offload)
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+        self.zero_optim_kwargs = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+            clip_grad_norm=max_norm,
+            reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            partition_grad=(stage == 2),
+        )
         self.verbose = verbose

         # set class name with stage, for better error message
@@ -294,15 +331,15 @@
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.stage, self.precision)
+            model = LowLevelZeroModel(model, self.precision)

         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = zero_optim_wrapper(model.unwrap(),
-                                           optimizer,
-                                           optim_config=self.zero_optim_config,
-                                           **self.optim_kwargs,
-                                           verbose=self.verbose)
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+                                                                     **self.zero_optim_kwargs,
+                                                                     verbose=self.verbose)
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)

         return model, optimizer, criterion, dataloader, lr_scheduler
......
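Putting the pieces together: configure() now wraps the model in LowLevelZeroModel, builds a LowLevelZeroOptimizer directly from zero_optim_kwargs (instead of going through zero_optim_wrapper), and binds update_master_params onto the wrapped model. A minimal, hedged usage sketch under these assumptions; the model, optimizer and data below are placeholders, and the script is meant to run under torchrun with a GPU.

# Hedged end-to-end sketch of how the reworked plugin is used through the Booster API.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

model = nn.Linear(32, 4)                                    # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
booster = Booster(plugin=plugin)
# boost() calls configure(): wraps the model in LowLevelZeroModel, builds a
# LowLevelZeroOptimizer from zero_optim_kwargs, and injects update_master_params.
model, optimizer, *_ = booster.boost(model, optimizer)

out = model(torch.rand(8, 32).cuda())    # inputs are cast to fp16 by LowLevelZeroModel.forward
loss = out.float().mean()                # toy loss, just to drive a step
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()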
@@ -15,8 +15,8 @@ from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
 from colossalai.context.config import Config
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
 from colossalai.logging import get_dist_logger
-from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
......
@@ -2,8 +2,9 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist

 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
......
@@ -3,7 +3,7 @@ import math
 import torch.distributed as dist

 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
......
@@ -4,9 +4,10 @@
 import math

 import torch.distributed as dist

+from colossalai.context import Config
 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
......
@@ -6,7 +6,7 @@ import math
 import torch.distributed as dist

 from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
......
@@ -3,7 +3,7 @@
 from torch import distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
......
@@ -2,9 +2,11 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer

 @DIST_GROUP_INITIALIZER.register_module
......
@@ -3,7 +3,7 @@
 from torch import distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .process_group_initializer import ProcessGroupInitializer
......
@@ -2,7 +2,7 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER

 from ..parallel_mode import ParallelMode
 from .initializer_tensor import Initializer_Tensor
......
@@ -3,9 +3,10 @@
 import torch.distributed as dist

-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer

 @DIST_GROUP_INITIALIZER.register_module
......
@@ -17,13 +17,13 @@ from torch.utils.data import DataLoader
 from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
-from colossalai.engine.gradient_accumulation import accumulate_gradient
-from colossalai.engine.schedule import (
+from colossalai.legacy.builder.builder import build_gradient_handler
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
+from colossalai.legacy.engine.schedule import (
     InterleavedPipelineSchedule,
     NonPipelineSchedule,
     PipelineSchedule,
......
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
 from .optimizer import OptimizerWrapper

-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
@@ -23,3 +23,14 @@ class ModelWrapper(nn.Module):
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+    """This mixin class defines the interface for AMP training.
+    """
+
+    def update_master_params(self):
+        """
+        Update the master parameters for AMP training.
+        """
+        pass
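AMPModelMixin gives model wrappers a default no-op update_master_params(); a plugin can then rebind it to the optimizer's real implementation with types.MethodType, which is exactly what LowLevelZeroPlugin.configure() does above. Below is a small, self-contained sketch of that injection pattern; the classes are illustrative stand-ins, not ColossalAI APIs.

# Hedged sketch of the update_master_params injection pattern; names are hypothetical.
from types import MethodType

import torch.nn as nn


class DummyOptimizer:

    def update_master_params(self, model: nn.Module) -> None:
        # a real ZeRO optimizer would copy the model's working params into its fp32 masters here
        print(f'refreshing master params for {type(model).__name__}')


class MyWrapper(nn.Module):    # stands in for a ModelWrapper that mixes in AMPModelMixin

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def update_master_params(self) -> None:
        pass    # default no-op, as defined by AMPModelMixin


model = MyWrapper(nn.Linear(4, 2))
optimizer = DummyOptimizer()
# same injection as LowLevelZeroPlugin.configure(): bind the optimizer's method onto the model
model.update_master_params = MethodType(optimizer.update_master_params, model)
model.update_master_params()    # now delegates to the optimizer instead of the no-op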
@@ -3,7 +3,7 @@
 import inspect

-from colossalai.registry import *
+from colossalai.legacy.registry import *

 def build_from_config(module, config: dict):
@@ -71,7 +71,7 @@ def build_gradient_handler(config, model, optimizer):
         optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler

     Returns:
-        An object of :class:`colossalai.engine.BaseGradientHandler`
+        An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
     """
     config_ = config.copy()
     config_['model'] = model
......
@@ -8,11 +8,17 @@ from torch import Tensor
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss

-from colossalai.engine.gradient_handler import BaseGradientHandler
-from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.schedule import (
+    BaseSchedule,
+    InterleavedPipelineSchedule,
+    NonPipelineSchedule,
+    PipelineSchedule,
+)
 from colossalai.logging import get_dist_logger
-from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively


 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
......