Unverified commit 3defa32a authored by Frank Lee, committed by GitHub

Support TP-compatible Torch AMP and Update trainer API (#27)



* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
parent 2b05de4c
@@ -42,26 +42,18 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" .
 ```python
 import colossalai
-from colossalai.engine import Engine
 from colossalai.trainer import Trainer
 from colossalai.core import global_context as gpc

-model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
-engine = Engine(
-    model=model,
-    criterion=criterion,
-    optimizer=optimizer,
-    lr_scheduler=lr_scheduler,
-    schedule=schedule
-)
+engine, train_dataloader, test_dataloader = colossalai.initialize()

 trainer = Trainer(engine=engine,
-                  hooks_cfg=gpc.config.hooks,
                   verbose=True)
 trainer.fit(
     train_dataloader=train_dataloader,
     test_dataloader=test_dataloader,
-    max_epochs=gpc.config.num_epochs,
+    epochs=gpc.config.num_epochs,
+    hooks_cfg=gpc.config.hooks,
     display_progress=True,
     test_interval=5
 )
 ```
...
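The updated snippet assumes a config file read into `gpc.config`. A minimal sketch of the fields it relies on, following the option names introduced in this PR (hook types and values are illustrative placeholders, not part of the change):

```python
# config.py -- minimal sketch; field names follow this PR, values are illustrative
from colossalai.engine import AMP_TYPE

num_epochs = 200

# options forwarded to Engine(...) by colossalai.initialize()
engine = dict(
    gradient_accumulation=4,   # number of micro-steps accumulated before each optimizer update
    gradient_clipping=1.0,     # max gradient norm; 0.0 disables clipping
)

fp16 = dict(
    mode=AMP_TYPE.TORCH,       # or AMP_TYPE.APEX / AMP_TYPE.PARALLEL; omit for FP32
)

hooks = [
    dict(type='LossHook'),     # placeholder hook name; consumed via trainer.fit(hooks_cfg=...)
]
```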
-from .builder import *
+from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper,
+                      build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
+                      build_gradient_handler)
 from .pipeline import ModelInitializer

+__all__ = [
+    'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper',
+    'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
+    'build_gradient_handler', 'ModelInitializer'
+]
@@ -181,18 +181,6 @@ def build_transform(config):
     return build_from_registry(config, TRANSFORMS)

-def build_pipe_alloc_policy(config):
-    """Returns a pipeline allocation policy object constructed from `config`.
-
-    :param config: A python dict or a :class:`colossalai.context.Config` object
-        containing information used in the construction of the return object
-    :type config: dict or :class:`colossalai.context.Config`
-    :return: A pipeline allocation policy object
-    :rtype:
-    """
-    return build_from_registry(config, PIPE_ALLOC_POLICY)

 def build_data_sampler(config, dataset):
     """Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
     constructed from `config`.
@@ -235,7 +223,7 @@ def build_optimizer_wrapper(config, optimizer, model=None):
     return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
-def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
+def build_lr_scheduler(config, optimizer):
     """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
     constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.
@@ -254,9 +242,16 @@ def build_lr_scheduler(config, optimizer, total_steps, num_steps_per_epoch):
     """
     config_ = config.copy()
     mod_type = config_.pop('type')
-    # warmup epochs will overwrite warmup steps
-    if 'warmup_epochs' in config_:
-        warmup_epochs = config_.pop('warmup_epochs')
-        config_['warmup_steps'] = int(num_steps_per_epoch * warmup_epochs)
-    return LR_SCHEDULERS.get_module(mod_type)(optimizer, total_steps, num_steps_per_epoch=num_steps_per_epoch,
-                                              **config_)
+    return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_)
+
+
+def build_schedule(config):
+    """Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.
+
+    :param config: A python dict or a :class:`colossalai.context.Config` object
+        containing information used in the construction of the return object
+    :type config: dict or :class:`colossalai.context.Config`
+    :return: An object of :class:`colossalai.engine.schedule.BaseSchedule`
+    :rtype: :class:`colossalai.engine.schedule.BaseSchedule`
+    """
+    return build_from_registry(config, SCHEDULE)
-from .amp_type import AMP_TYPE
 from ._base_engine import Engine
 from .gradient_handler import *
 from .schedule import *
+from .amp import *

 __all__ = ['Engine']
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from typing import Optional from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.builder import build_gradient_handler from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
...@@ -9,162 +11,166 @@ from colossalai.core import global_context as gpc ...@@ -9,162 +11,166 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3) ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module from .schedule import BaseSchedule
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from .schedule import BaseSchedule, NoPipelineSchedule
class Engine: class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method """Basic engine class for training and evaluation. It runs a specific process method
:meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset. :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
It controls a iteration in training.
:param train_dataloader: Dataloader in training
:param test_dataloader: Dataloader in evaluation
:param model: The neural network model :param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters :param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler ajusting learning rate during the training or evaluation :param step_schedule: Running schedule in :meth:`step`
:param schedule: Running schedule in :meth:`step` :param gradient_accumulation: Steps of gradient accumulation
:type train_dataloader: DataLoader, optional :param gradient_clipping: The norm of gradient clipping
:type test_dataloader: DataLoader, optional
:type model: Module :type model: Module
:type criterion: _Loss, optional :type optimizer: Optimizer
:type optimizer: Optimizer, optional :type step_schedule: BaseSchedule, optional
:type lr_scheduler: _LRScheduler, optional :type gradient_accumulation: int, optional
:type schedule: BaseSchedule, optional :type gradient_clipping: float, optional
""" """
def __init__(self, def __init__(self,
train_dataloader: Optional[DataLoader] = None, model: Module,
test_dataloader: Optional[DataLoader] = None, optimizer: Optimizer,
model: Module = None, criterion: _Loss,
criterion: _Loss = None, step_schedule: BaseSchedule,
optimizer: Optimizer = None, gradient_handlers: list = None,
lr_scheduler: Optional[_LRScheduler] = None, gradient_accumulation: int = 1,
schedule: BaseSchedule = None): gradient_clipping: float = 0.0,
self.train_dataloader = train_dataloader ):
self.test_dataloader = test_dataloader self._model = model
assert model is not None, "Engine requires a model" self._optimizer = optimizer
self.model = model self._criterion = criterion
self.criterion = criterion self._schedule = step_schedule
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler # schedule initialize
self.schedule = schedule if schedule is not None \ self._schedule.initialize(model, optimizer)
else NoPipelineSchedule()
# state
self.training = True # default
# gradient accumulation
assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
self._grad_accum_size = gradient_accumulation
self._grad_clip = gradient_clipping
self._logger = get_global_dist_logger() self._logger = get_global_dist_logger()
# build gradient handler # build gradient handler
self._gradient_handlers = [] self._gradient_handlers = []
gradient_handler_cfg = []
if hasattr(gpc.config, 'gradient_handler'): if gradient_handlers is not None:
assert isinstance(gpc.config.gradient_handler, list), \ assert isinstance(gradient_handlers, list), \
f'argument gradient_handler_cfg expected type list, ' \ f'argument gradient_handler_cfg expected type list, ' \
f'but got type {type(gpc.config.gradient_handler)}' f'but got type {type(gradient_handlers)}'
gradient_handler_cfg = gpc.config.gradient_handler elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)): ZeroRedundancyOptimizer_Level_3)):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')] gradient_handlers = [dict(type='ZeROGradientHandler')]
self._logger.info( self._logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically " "Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration", "added even though not specified in the configuration",
ranks=[0]) ranks=[0])
elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size( elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
ParallelMode.DATA) > 1: ParallelMode.DATA) > 1:
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] gradient_handlers = [dict(type='DataParallelGradientHandler')]
self._logger.info( self._logger.info(
"Data parallel training is detected, DataParallelGradientHandler is automatically " "Data parallel training is detected, DataParallelGradientHandler is automatically "
"added even though not specified in the configuration", "added even though not specified in the configuration",
ranks=[0]) ranks=[0])
if len(gradient_handler_cfg) == 0:
if gradient_handlers is None:
self._logger.warning( self._logger.warning(
"No gradient handler is set up, please make sure you do not need " "No gradient handler is set up, please make sure you do not need "
"to all-reduce the gradients after a training step.", "to all-reduce the gradients after a training step.",
ranks=[0]) ranks=[0])
for cfg in gradient_handler_cfg: else:
handler = build_gradient_handler(cfg, self.model, self.optimizer) for cfg in gradient_handlers:
handler = build_gradient_handler(cfg, model, optimizer)
self._gradient_handlers.append(handler) self._gradient_handlers.append(handler)
self.schedule.initialize(self.train_dataloader, self.model, @property
self.criterion, self.optimizer, def model(self):
self.lr_scheduler) return self._model
self.forward_only = False
def handle_gradient(self): @property
"""Handles all-reduce operations of gradients across different parallel groups. def optimizer(self):
""" return self._optimizer
for handler in self._gradient_handlers:
handler.handle_gradient()
def set_dataloader(self, data: DataLoader, train: bool = True): @property
"""Sets dataloader in training or evaluation. def criterion(self):
return self._criterion
:param data: Dataloader to be set @property
:param train: Set training dataloader if True, otherwise evaluation dataloader def schedule(self):
:type data: DataLoader return self._schedule
:type train: bool
"""
if train:
self.train_dataloader = data
else:
self.test_dataloader = data
def get_model(self): @property
"""Returns the neural network model in the engine. def gradient_accumulation(self):
""" return self._grad_accum_size
return self.model
def get_optimizer(self):
"""Returns optimizier in the engine.
"""
return self.optimizer
def get_lr_scheduler(self): def handle_gradient(self):
"""Returns the learning rate scheduler in the engine. """Handles all-reduce operations of gradients across different parallel groups.
""" """
return self.lr_scheduler for handler in self._gradient_handlers:
handler.handle_gradient()
def train(self): def train(self):
"""Sets the model to training mode. """Sets the model to training mode.
""" """
self.forward_only = False self.training = True
self.schedule.train(dataloader=self.train_dataloader, mode=True) self._model.train()
def eval(self): def eval(self):
"""Sets the model to evaluation mode. """Sets the model to evaluation mode.
""" """
self.forward_only = True self.training = False
self.schedule.train(dataloader=self.test_dataloader, mode=False) self._model.eval()
def is_train(self): def step(self,
"""Returns True if it is in training, otherwise False. data_iter,
""" is_last_iteration: bool = False,
return not self.forward_only return_loss=True):
def get_lr(self):
"""Gets current learning rate.
"""
return self.schedule.get_lr()
def step(self, return_loss=True):
"""A running step based on the schedule. Usually, it runs a training or """A running step based on the schedule. Usually, it runs a training or
evaluation over a batch of dataset. evaluation over a batch of dataset.
:param data_iter: Data iterator of the dataset
:param is_last_iteration: If True, this iteration is the last iteration in the epoch
:param return_loss: loss will be returned if True :param return_loss: loss will be returned if True
:type return_loss: bool :type data_iter: Iterator
:type is_last_iteration: bool, optional
:type return_loss: bool, optional
:return: (output, lablel, loss) :return: (output, lablel, loss)
""" """
self.schedule.zero_grad(forward_only=self.forward_only) if self.training:
self._optimizer.zero_grad()
output, label, loss = self.schedule.forward_backward_step(
forward_only=self.forward_only, return_loss=return_loss) # differentiate training and eval with grad accum
if self.training:
if not self.forward_only: for i in range(self._grad_accum_size):
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=False,
grad_accum_size=self._grad_accum_size,
return_loss=return_loss)
if i == self._grad_accum_size - 1:
# all reduce gradients # all reduce gradients
self.handle_gradient() self.handle_gradient()
self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
self.schedule.step() else:
output, label, loss = self._schedule.forward_backward_step(
data_iter, self._model, self._criterion, self._optimizer,
forward_only=True,
grad_accum_size=1,
return_loss=return_loss)
# consume the remaining dataset left out due to gradient accumulation
if is_last_iteration:
while True:
try:
_ = next(data_iter)
except StopIteration:
break
return output, label, loss return output, label, loss
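For orientation, a minimal sketch (not code from this PR) of how the reworked `Engine.step` is intended to be driven: the caller owns the data iterator, each step consumes `gradient_accumulation` batches, and the last step of the epoch is flagged so leftover batches are drained.

```python
# Sketch of one training epoch against the new Engine API; `engine` and
# `train_dataloader` are assumed to come from colossalai.initialize().
def run_epoch(engine, train_dataloader):
    engine.train()
    data_iter = iter(train_dataloader)
    # each engine.step() consumes `gradient_accumulation` batches
    num_steps = len(train_dataloader) // engine.gradient_accumulation
    last_loss = None
    for i in range(num_steps):
        output, label, loss = engine.step(
            data_iter,
            is_last_iteration=(i == num_steps - 1),
            return_loss=True)
        last_loss = loss
    return last_loss
```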
from .grad_scaler import GradScaler
from .amp_type import AMP_TYPE
...@@ -5,125 +5,85 @@ from abc import ABC, abstractmethod ...@@ -5,125 +5,85 @@ from abc import ABC, abstractmethod
import torch import torch
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
class BaseSchedule(ABC): class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation. """A basic helper class to control the process of training or evaluation.
It mainly composes of forward_backward_step for gradient backward and
optimizer_step for parameters update.
For the convenience to enable FP16, we aggreate all codes that contain the
control of FP16 in class schedule.
""" """
def __init__(self): def __init__(self):
self.initialized = False
self.logger = get_global_dist_logger() self.logger = get_global_dist_logger()
@property @staticmethod
@abstractmethod def _move_tensor(element):
def num_steps(self): if torch.is_tensor(element):
"""The number of batches in training or evaluation. if not element.is_cuda:
""" return element.to(get_current_device()).detach()
pass return element
def initialize(self,
dataloader=None,
model=None,
criterion=None,
optimizer=None,
lr_scheduler=None):
"""Initializes the schedule and set parameters before running.
:param dataloader: DataLoader in training or evaluation def _move_to_device(self, data):
:param model: The neural network model if isinstance(data, (tuple, list)):
:param criterion: Criterion for calculating loss data = tuple([self._move_tensor(d) for d in data])
:param optimizer: Optimizer for updating the parameters elif torch.is_tensor(data):
:param lr_scheduler: Learning rate scheduler in the process data = data.to(get_current_device()).detach()
""" return data
self.dataloader = dataloader
assert model is not None, "Schedule requires a model"
self.model = model
assert criterion is not None, "Schedule requires a criterion"
self.criterion = criterion
assert optimizer is not None, "Schedule requires an optimizer"
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.initialized = True
def check_initialized(self):
"""Checks whether the schedule is initialized.
"""
assert self.initialized, \
'Schedule is not initialized. Call schedule.initialize(...) before using it.'
def load_batch(self): def load_batch(self, data_iter):
"""Loads a batch of dataset. It returns the data and labels which are """Loads a batch from data iterator. It returns the data and labels which are
already in the same GPU as where the model's. already in the same GPU as where the model's.
:return: (data, label) :return: (data, label)
:rtype: (Tensor, Tensor) :rtype: (Tensor, Tensor)
""" """
self.check_initialized() if data_iter is None:
if self.data_iter is None:
raise RuntimeError('Dataloader is not defined.') raise RuntimeError('Dataloader is not defined.')
data, label = next(self.data_iter) data, label = next(data_iter)
return self._move_to_device(data), self._move_to_device(label) return self._move_to_device(data), self._move_to_device(label)
def _move_to_device(self, data): def initialize(self, model, optimizer):
if isinstance(data, ( """Initializes the model and the optimizer before training.
tuple, This is often used in FP16 training.
list,
)):
data = tuple([
d.to(get_current_device()).detach() for d in data
if torch.is_tensor(d)
])
elif torch.is_tensor(data):
data = data.to(get_current_device()).detach()
return data
def train(self, dataloader=None, mode=True):
"""Sets the dataloader to be used and turn the model to
training or evaluation mode.
:param dataloader: Dataloader to be used :param model: The neural network model
:param mode: If True, the model will set as training mode. Otherwise, evaluation mode. :param optimizer: Optimizer for updating the parameters
"""
self.check_initialized()
if mode:
self.model.train()
else:
self.model.eval()
if dataloader is not None:
self.dataloader = dataloader
self.data_iter = iter(dataloader)
def zero_grad(self, forward_only=False):
"""Cleans gradients with the optimizer.
""" """
if not forward_only: return model, optimizer
self.check_initialized()
self.optimizer.zero_grad()
def get_lr(self): @abstractmethod
"""Returns the current learning rate. def forward_backward_step(self,
""" data_iter,
if self.lr_scheduler is not None: model,
return self.lr_scheduler.get_lr()[0] criterion,
else: optimizer=None,
return self.optimizer.param_groups[0]['lr'] forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
"""The process function over a batch of dataset for training or evaluation.
def step(self): :param data_iter: Data iterator of the dataset
"""Updates the parameters and learning rate with the optimizer. :param model: Model used in training or evaluation
:param optimizer: Optimizer used in training or evaluation
:param criterion: Loss function
:param forward_only: If True, the process won't include backward
:param grad_accum_size: Steps of gradient accumulation
:param return_loss: If False, the loss won't be returned
""" """
self.check_initialized() pass
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
@abstractmethod @abstractmethod
def forward_backward_step(self, forward_only=False, return_loss=True): def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
"""The process function over a batch of dataset for training or evaluation. """Updates the parameters with the optimizer.
:param forward_only: If True, the process won't include backward. :param model: The neural network model
:param return_loss: If False, the loss won't be returned. :param optimizer: Optimizer for updating the parameters
:param grad_clipping: The norm of gradient clipping
:type grad_clipping: float, optional
""" """
pass pass
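To make the new contract concrete: a subclass implements only `forward_backward_step` (given a data iterator, model, criterion and optimizer) and `optimizer_step`, and keeps no reference to the model or optimizer between calls. A toy, non-pipelined FP32 sketch, assuming `BaseSchedule` is importable from `colossalai.engine.schedule` and ignoring AMP and ZeRO:

```python
import torch
from colossalai.engine.schedule import BaseSchedule  # assumed import path


class ToyFP32Schedule(BaseSchedule):
    """Toy schedule sketch: plain FP32 forward/backward plus a vanilla optimizer step."""

    def forward_backward_step(self, data_iter, model, criterion, optimizer=None,
                              forward_only=False, grad_accum_size: int = 1, return_loss=True):
        assert forward_only or return_loss
        data, label = self.load_batch(data_iter)
        output = model(*data)
        if not isinstance(output, (tuple, list)):
            output = (output,)
        loss = None
        if return_loss:
            # divide so gradients average correctly over accumulated micro-steps
            loss = criterion(*output, *label) / grad_accum_size
        if not forward_only:
            loss.backward()
        # report the unscaled loss, mirroring the schedules in this PR
        return output, label, (loss * grad_accum_size if return_loss else None)

    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
        if grad_clipping > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
        optimizer.step()
```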
@@ -4,19 +4,24 @@
 try:
     import apex.amp as apex_amp
 except:
-    print('apex is required for mixed precision training')
+    pass

 try:
     import torch.cuda.amp as torch_amp
 except:
-    print('PyTorch amp is not supported with the current PyTorch version')
+    pass

+from typing import Iterable
+
+import torch.nn as nn
+from torch.optim import Optimizer

-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.engine.amp_type import AMP_TYPE
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
-from ._utils import convert_to_fp16
+from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
 from ._base_schedule import BaseSchedule
+from ._utils import convert_to_fp16, convert_to_fp32
+from ..amp import AMP_TYPE, GradScaler


 class NoPipelineSchedule(BaseSchedule):
...@@ -30,6 +35,7 @@ class NoPipelineSchedule(BaseSchedule): ...@@ -30,6 +35,7 @@ class NoPipelineSchedule(BaseSchedule):
:type amp_type: AMP_TYPE :type amp_type: AMP_TYPE
:type amp_config: dict :type amp_config: dict
""" """
def __init__( def __init__(
self, self,
amp_type: AMP_TYPE = None, amp_type: AMP_TYPE = None,
...@@ -41,12 +47,6 @@ class NoPipelineSchedule(BaseSchedule): ...@@ -41,12 +47,6 @@ class NoPipelineSchedule(BaseSchedule):
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \ assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument fp16, it can only be None, torch or apex' 'unrecognised value for argument fp16, it can only be None, torch or apex'
# LSG: check compatibility
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
ParallelMode.TENSOR) > 1:
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
'You can only AMP_TYPE.PARALLEL for tensor parallel training'
self.use_zero_level_2_3 = False self.use_zero_level_2_3 = False
if amp_type is not None: if amp_type is not None:
...@@ -79,107 +79,110 @@ class NoPipelineSchedule(BaseSchedule): ...@@ -79,107 +79,110 @@ class NoPipelineSchedule(BaseSchedule):
self.fp16 = False self.fp16 = False
self.amp_type = None self.amp_type = None
@property def initialize(self, model: nn.Module, optimizer: Optimizer):
def num_steps(self): if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)): ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True self.use_zero_level_2_3 = True
assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL' assert self.amp_type != AMP_TYPE.PARALLEL, \
'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
if self.fp16: if self.fp16:
if self.amp_type == AMP_TYPE.TORCH: if self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg) self._torch_amp_scaler = GradScaler(**self.amp_cfg)
elif self.amp_type == AMP_TYPE.APEX: elif self.amp_type == AMP_TYPE.APEX:
self.model, self.optimizer = apex_amp.initialize( model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)
self.model, self.optimizer, **self.amp_cfg)
return model, optimizer
def forward_backward_step(self, forward_only=False, return_loss=True):
def forward_backward_step(self,
data_iter: Iterable,
model: nn.Module,
criterion: nn.modules.loss._Loss,
optimizer: Optimizer = None,
forward_only: bool = False,
grad_accum_size: int = 1,
return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model. """The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False. The returned labels and loss will None if :attr:`return_loss` is False.
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param model: Model for training and inference
:param criterion: Loss function for training
:param optimizer: Optimizer used for training
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
:param grad_accum_size: The number of iterations for gradient accumulation
:param return_loss: Loss will be returned if True
:type data_iter: Iterator
:type model: torch.nn.Module
:type criterion: torch.nn.modules.loss._Loss
:type optimizer: torch.optim.Optimizer
:type forward_only: bool, optional
:type grad_accum_size: int
:type return_loss: bool, optional
:return: (output, label, loss) :return: (output, label, loss)
""" """
assert forward_only or return_loss, \ assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
data, label = self.load_batch() data, label = self.load_batch(data_iter)
loss = None loss = None
# LSG: leave for debug, make sure dataloader is deterministic
# if forward_only:
# img = data[0]
# rank = gpc.get_local_rank(ParallelMode.DATA)
# world_size = gpc.get_world_size(ParallelMode.DATA)
# group = gpc.get_group(ParallelMode.DATA)
# input_list = [img.clone() for _ in range(world_size)]
# output_list = [torch.empty_like(img) for _ in range(world_size)]
# output_list[rank] = img.clone()
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
# forward # forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH: if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast(): with torch_amp.autocast():
output = self.model(*data) output = model(*data)
if not isinstance(output, (tuple, list)): if not isinstance(output, (tuple, list)):
output = (output,) output = (output,)
if return_loss: if return_loss:
loss = self.criterion(*output, *label) loss = criterion(*output, *label)
else: else:
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL: if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
data = convert_to_fp16(data) data = convert_to_fp16(data)
output = self.model(*data) output = model(*data)
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
output = convert_to_fp32(output)
if not isinstance(output, (tuple, list)): if not isinstance(output, (tuple, list)):
output = (output,) output = (output,)
if return_loss: if return_loss:
loss = self.criterion(*output, *label) loss = criterion(*output, *label)
loss /= grad_accum_size
if not forward_only: if not forward_only:
# backward # backward
if self.use_zero_level_2_3: if self.use_zero_level_2_3:
self.optimizer.backward(loss) optimizer.backward(loss)
elif self.fp16: elif self.fp16:
if self.amp_type == AMP_TYPE.APEX: if self.amp_type == AMP_TYPE.APEX:
with apex_amp.scale_loss(loss, with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
self.optimizer) as scaled_loss:
scaled_loss.backward() scaled_loss.backward()
elif self.amp_type == AMP_TYPE.TORCH: elif self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.scale(loss).backward() self._torch_amp_scaler.scale(loss).backward()
elif self.amp_type == AMP_TYPE.PARALLEL: elif self.amp_type == AMP_TYPE.PARALLEL:
loss = self.optimizer.scale_loss(loss) loss = optimizer.scale_loss(loss)
loss.backward() loss.backward()
# scale back to display the original value in logs # scale back to display the original value in logs
loss.div_(self.optimizer.grad_scaler.scale) loss.div_(optimizer.grad_scaler.scale)
else: else:
loss.backward() loss.backward()
if return_loss: if return_loss:
return output, label, loss return output, label, loss * grad_accum_size
else: else:
return output, None, None return output, None, None
def step(self): def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
# step optimizer # step optimizer
if self.fp16 and self.amp_type == AMP_TYPE.TORCH: if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.step(self.optimizer) if grad_clipping > 0.0:
self._torch_amp_scaler.unscale_(optimizer)
clip_grad_norm_fp32(model.parameters(), grad_clipping)
self._torch_amp_scaler.step(optimizer)
self._torch_amp_scaler.update() self._torch_amp_scaler.update()
else: else:
self.optimizer.step() if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
clip_grad_norm_fp32(model.parameters(), grad_clipping)
# update lr scheduler optimizer.step()
if self.lr_scheduler is not None:
self.lr_scheduler.step()
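Because the compatibility check was removed, PyTorch-native AMP can now be combined with tensor parallelism directly from the config. A hedged sketch of what that config might look like (scaler arguments and parallel sizes are illustrative, not prescribed by this PR):

```python
# config sketch: Torch AMP together with 2D tensor parallelism (illustrative values)
from colossalai.engine import AMP_TYPE

fp16 = dict(
    mode=AMP_TYPE.TORCH,
    # remaining keys are forwarded to the gradient scaler
    init_scale=2.**16,
    growth_interval=2000,
)

parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2d'),
)
```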
...@@ -15,7 +15,7 @@ from colossalai.nn import (ZeroRedundancyOptimizer_Level_2, ...@@ -15,7 +15,7 @@ from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16 from ._utils import convert_to_fp16
from ..amp_type import AMP_TYPE from ..amp import AMP_TYPE
def squeeze(x: Union[Tensor, tuple, list]): def squeeze(x: Union[Tensor, tuple, list]):
...@@ -93,12 +93,11 @@ class PipelineSchedule(BaseSchedule): ...@@ -93,12 +93,11 @@ class PipelineSchedule(BaseSchedule):
) )
# Pipeline schedule just puts data in memory # Pipeline schedule just puts data in memory
def load_batch(self): def load_batch(self, data_iter):
self.check_initialized() if data_iter is None:
if self.data_iter is None:
raise RuntimeError('Dataloader is not defined.') raise RuntimeError('Dataloader is not defined.')
self.batch_pos = 0 self.batch_pos = 0
data, label = next(self.data_iter) data, label = next(data_iter)
self.batch_data, self.batch_label = \ self.batch_data, self.batch_label = \
self._move_to_device(data), self._move_to_device(label) self._move_to_device(data), self._move_to_device(label)
batch_size = self.batch_data.shape[0] batch_size = self.batch_data.shape[0]
...@@ -117,23 +116,8 @@ class PipelineSchedule(BaseSchedule): ...@@ -117,23 +116,8 @@ class PipelineSchedule(BaseSchedule):
self.batch_pos += self.microbatch_size self.batch_pos += self.microbatch_size
return (data,), (label,) return (data,), (label,)
@property def initialize(self, model, optimizer):
def num_steps(self): if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
raise TypeError( raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3" "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
) )
...@@ -145,7 +129,8 @@ class PipelineSchedule(BaseSchedule): ...@@ -145,7 +129,8 @@ class PipelineSchedule(BaseSchedule):
'default tensor dtype is set to torch.half for fp16 training', 'default tensor dtype is set to torch.half for fp16 training',
ranks=[0]) ranks=[0])
def forward_step(self, input_tensor, return_tensors, return_loss=True): def forward_step(self, model, criterion, input_tensor, return_tensors,
grad_accum_size, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor """Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used. is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users. Returns output tensor. This is a helper function and can be ignored by users.
...@@ -156,14 +141,14 @@ class PipelineSchedule(BaseSchedule): ...@@ -156,14 +141,14 @@ class PipelineSchedule(BaseSchedule):
if self.amp_type == AMP_TYPE.PARALLEL: if self.amp_type == AMP_TYPE.PARALLEL:
input_tensor = convert_to_fp16(input_tensor) input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor) input_tensor = squeeze(input_tensor)
output_tensor = self.model(input_tensor) output_tensor = model(input_tensor)
output_tensor = squeeze(output_tensor) output_tensor = squeeze(output_tensor)
if gpc.is_last_rank(ParallelMode.PIPELINE): if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss: if return_loss:
input_tensor, label = self.load_micro_batch() input_tensor, label = self.load_micro_batch()
loss_reduced = self.criterion(output_tensor, * loss_reduced = criterion(output_tensor, *label) \
label) / self.num_microbatches / (self.num_microbatches * grad_accum_size)
return_tensors.append( return_tensors.append(
tuple((output_tensor, label[0], loss_reduced))) tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced return loss_reduced
...@@ -174,7 +159,7 @@ class PipelineSchedule(BaseSchedule): ...@@ -174,7 +159,7 @@ class PipelineSchedule(BaseSchedule):
else: else:
return output_tensor return output_tensor
def backward_step(self, input_tensor, output_tensor, output_tensor_grad): def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the """Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor. output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage). Returns the gradients with respect to the input tensor (None if first stage).
...@@ -187,7 +172,7 @@ class PipelineSchedule(BaseSchedule): ...@@ -187,7 +172,7 @@ class PipelineSchedule(BaseSchedule):
# Backward pass. # Backward pass.
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL: if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
output_tensor = self.optimizer.scale_loss(output_tensor) output_tensor = optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor. # Collect the grad of the input_tensor.
...@@ -197,7 +182,14 @@ class PipelineSchedule(BaseSchedule): ...@@ -197,7 +182,14 @@ class PipelineSchedule(BaseSchedule):
return input_tensor_grad return input_tensor_grad
def forward_backward_step(self, forward_only=True, return_loss=True): def forward_backward_step(self,
data_iter,
model,
criterion,
optimizer=None,
forward_only=False,
grad_accum_size: int = 1,
return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages. """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise. Returns a tuple with losses if the last stage, an empty tuple otherwise.
...@@ -207,7 +199,7 @@ class PipelineSchedule(BaseSchedule): ...@@ -207,7 +199,7 @@ class PipelineSchedule(BaseSchedule):
assert forward_only or return_loss, \ assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch() self.load_batch(data_iter)
num_warmup_microbatches = \ num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE) - (gpc.get_world_size(ParallelMode.PIPELINE) -
gpc.get_local_rank(ParallelMode.PIPELINE) - 1) gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
...@@ -233,9 +225,11 @@ class PipelineSchedule(BaseSchedule): ...@@ -233,9 +225,11 @@ class PipelineSchedule(BaseSchedule):
if not gpc.is_first_rank(ParallelMode.PIPELINE): if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape) ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape) input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step(input_tensor, output_tensor = self.forward_step(
return_tensors, model, criterion,
return_loss=return_loss) input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
)
if not gpc.is_last_rank(ParallelMode.PIPELINE): if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape bt_shape = output_tensor.shape
fs_checker = send_tensor_meta(output_tensor, fs_checker) fs_checker = send_tensor_meta(output_tensor, fs_checker)
...@@ -257,9 +251,11 @@ class PipelineSchedule(BaseSchedule): ...@@ -257,9 +251,11 @@ class PipelineSchedule(BaseSchedule):
for i in range(num_microbatches_remaining): for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1)) last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(input_tensor, output_tensor = self.forward_step(
return_tensors, model, criterion,
return_loss=return_loss) input_tensor, return_tensors,
grad_accum_size, return_loss=return_loss
)
if forward_only: if forward_only:
send_forward(output_tensor) send_forward(output_tensor)
...@@ -279,9 +275,11 @@ class PipelineSchedule(BaseSchedule): ...@@ -279,9 +275,11 @@ class PipelineSchedule(BaseSchedule):
input_tensor = input_tensors.pop(0) input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0) output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(input_tensor, input_tensor_grad = self.backward_step(
output_tensor, optimizer,
output_tensor_grad) input_tensor, output_tensor,
output_tensor_grad
)
if last_iteration: if last_iteration:
input_tensor = None input_tensor = None
...@@ -298,9 +296,11 @@ class PipelineSchedule(BaseSchedule): ...@@ -298,9 +296,11 @@ class PipelineSchedule(BaseSchedule):
output_tensor_grad = recv_backward(bt_shape) output_tensor_grad = recv_backward(bt_shape)
input_tensor_grad = self.backward_step(input_tensor, input_tensor_grad = self.backward_step(
output_tensor, optimizer,
output_tensor_grad) input_tensor, output_tensor,
output_tensor_grad
)
send_backward(input_tensor_grad) send_backward(input_tensor_grad)
...@@ -309,8 +309,11 @@ class PipelineSchedule(BaseSchedule): ...@@ -309,8 +309,11 @@ class PipelineSchedule(BaseSchedule):
output, label, loss = tuple(map(list, zip(*return_tensors))) output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0), return (torch.cat(output, dim=0),
torch.cat(label, dim=0), torch.cat(label, dim=0),
sum(loss)) sum(loss) * grad_accum_size)
else: else:
return tuple((torch.cat(return_tensors, dim=0), None, None)) return tuple((torch.cat(return_tensors, dim=0), None, None))
else: else:
return tuple((None, None, None)) return tuple((None, None, None))
def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
optimizer.step()
@@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
     else:
         raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
     return ret
def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.float()
elif isinstance(data, (list, tuple)):
ret = [val.float() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret
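The new `convert_to_fp32` helper is the counterpart of `convert_to_fp16`; in `_no_pipeline.py` above it casts the model output back to full precision before the loss is computed when ZeRO or `AMP_TYPE.PARALLEL` keeps the model in FP16. Roughly:

```python
# sketch of how the two helpers pair up around the model call (see _no_pipeline.py above)
data = convert_to_fp16(data)       # inputs cast to half precision for the FP16 model
output = model(*data)
output = convert_to_fp32(output)   # loss and metrics are then computed in full precision
```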
...@@ -6,18 +6,20 @@ import pprint ...@@ -6,18 +6,20 @@ import pprint
import random import random
from pathlib import Path from pathlib import Path
from typing import Callable, Iterable, Optional, Union from typing import Callable, Iterable, Optional, Union
from typing import Tuple
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger, init_global_dist_logger from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss, from .builder import (ModelInitializer, build_dataset, build_loss,
build_lr_scheduler, build_model, build_optimizer, build_model, build_optimizer,
build_optimizer_wrapper) build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode from .context import Config, ParallelMode
from .core import global_context as gpc from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp from .utils import get_current_device, sync_model_param_in_dp
...@@ -182,7 +184,7 @@ def initialize(config: Union[str, dict] = None, ...@@ -182,7 +184,7 @@ def initialize(config: Union[str, dict] = None,
backend: str = None, backend: str = None,
train_dataloader: Optional[Union[Iterable, Callable]] = None, train_dataloader: Optional[Union[Iterable, Callable]] = None,
test_dataloader: Optional[Union[Iterable, Callable]] = None, test_dataloader: Optional[Union[Iterable, Callable]] = None,
): ) -> Tuple[Engine, DataLoader, DataLoader]:
'''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config). '''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
:param config: config file or config file path are both acceptable :param config: config file or config file path are both acceptable
...@@ -201,7 +203,7 @@ def initialize(config: Union[str, dict] = None, ...@@ -201,7 +203,7 @@ def initialize(config: Union[str, dict] = None,
:type train_dataloader: Optional[Union[Iterable, Callable]], optional :type train_dataloader: Optional[Union[Iterable, Callable]], optional
:param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None :param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
:type test_dataloader: Optional[Union[Iterable, Callable]], optional :type test_dataloader: Optional[Union[Iterable, Callable]], optional
:return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler) :return: (engine, train_dataloader, test_dataloader, criterion)
:rtype: tuple :rtype: tuple
''' '''
# initialize distributed environment # initialize distributed environment
...@@ -337,21 +339,7 @@ def initialize(config: Union[str, dict] = None, ...@@ -337,21 +339,7 @@ def initialize(config: Union[str, dict] = None,
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer) optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
logger.info('Optimizer is created', ranks=[0]) logger.info('Optimizer is created', ranks=[0])
lr_scheduler = None # build schedule and engine
if hasattr(gpc.config, 'lr_scheduler'):
if hasattr(gpc.config, 'num_steps'):
total_steps = gpc.config.num_steps
elif hasattr(gpc.config, 'num_epochs'):
total_steps = int(gpc.config.num_epochs * len(train_dataloader))
else:
raise Exception(
'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
total_steps, len(train_dataloader))
logger.info('Learning rate scheduler is created', ranks=[0])
# pipeline or no pipeline schedule
if hasattr(gpc.config, 'fp16'): if hasattr(gpc.config, 'fp16'):
amp_type = gpc.config.fp16.mode amp_type = gpc.config.fp16.mode
amp_cfg = gpc.config.fp16.copy() amp_cfg = gpc.config.fp16.copy()
...@@ -360,12 +348,32 @@ def initialize(config: Union[str, dict] = None, ...@@ -360,12 +348,32 @@ def initialize(config: Union[str, dict] = None,
amp_type = None amp_type = None
amp_cfg = None amp_cfg = None
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: engine_cfg = gpc.config.get('engine', dict())
assert hasattr(gpc.config, schedule_cfg = engine_cfg.pop('schedule', None)
'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training"
schedule_type = None
if schedule_cfg is not None:
schedule_type = schedule_cfg.get('type', None)
if schedule_type is not None:
# run customized schedule
schedule_cfg['amp_type'] = amp_type
schedule_cfg['amp_config'] = amp_cfg
schedule = build_schedule(schedule_cfg)
elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
assert schedule_cfg is not None, \
"Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
schedule = PipelineSchedule( schedule = PipelineSchedule(
amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy()) amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
else: else:
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg) schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler engine = Engine(
model=model,
optimizer=optimizer,
criterion=criterion,
step_schedule=schedule,
**gpc.config.get('engine', dict())
)
return engine, train_dataloader, test_dataloader
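With `build_schedule` exposed, a user-defined schedule registered in the `SCHEDULE` registry can be selected from the `engine.schedule` section of the config instead of relying on the built-in pipeline/no-pipeline choice. A hedged sketch (the schedule class name is hypothetical):

```python
# config sketch: pick a custom schedule through the registry (class name is hypothetical)
engine = dict(
    schedule=dict(type='MyCustomSchedule'),  # initialize() also injects amp_type / amp_config
    gradient_accumulation=2,
    gradient_clipping=1.0,
)
```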
...@@ -7,6 +7,7 @@ from torch import Tensor ...@@ -7,6 +7,7 @@ from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
def matmul_2d(a, def matmul_2d(a,
...@@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function): ...@@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB` """Matrix multiplication for :math:`C = AB`
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
...@@ -120,10 +122,11 @@ class Matmul_AB_2D(torch.autograd.Function): ...@@ -120,10 +122,11 @@ class Matmul_AB_2D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2D.forward( with torch.no_grad():
None, A_grad = Matmul_ABT_2D.apply(
output_grad, B, output_grad, B,
ctx.summa_dim, ctx.A_shape, ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.row_rank, ctx.col_rank,
...@@ -134,8 +137,7 @@ class Matmul_AB_2D(torch.autograd.Function): ...@@ -134,8 +137,7 @@ class Matmul_AB_2D(torch.autograd.Function):
ctx.pipeline_parallel_size, ctx.pipeline_parallel_size,
ctx.tensor_parallel_size ctx.tensor_parallel_size
) )
B_grad = Matmul_ATB_2D.forward( B_grad = Matmul_ATB_2D.apply(
None,
A, output_grad, A, output_grad,
ctx.summa_dim, ctx.B_shape, ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.row_rank, ctx.col_rank,
...@@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function): ...@@ -153,6 +155,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T` """Matrix multiplication for :math:`C = AB^T`
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
...@@ -214,10 +217,12 @@ class Matmul_ABT_2D(torch.autograd.Function): ...@@ -214,10 +217,12 @@ class Matmul_ABT_2D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors A, B = ctx.saved_tensors
A_grad = Matmul_AB_2D.forward(
None, with torch.no_grad():
A_grad = Matmul_AB_2D.apply(
output_grad, B, output_grad, B,
ctx.summa_dim, ctx.A_shape, ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.row_rank, ctx.col_rank,
...@@ -228,8 +233,7 @@ class Matmul_ABT_2D(torch.autograd.Function): ...@@ -228,8 +233,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
ctx.pipeline_parallel_size, ctx.pipeline_parallel_size,
ctx.tensor_parallel_size ctx.tensor_parallel_size
) )
B_grad = Matmul_ATB_2D.forward( B_grad = Matmul_ATB_2D.apply(
None,
output_grad, A, output_grad, A,
ctx.summa_dim, ctx.B_shape, ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.row_rank, ctx.col_rank,
...@@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function): ...@@ -247,6 +251,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB` """Matrix multiplication for :math:`C = A^TB`
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
A: Tensor, A: Tensor,
B: Tensor, B: Tensor,
...@@ -308,10 +313,12 @@ class Matmul_ATB_2D(torch.autograd.Function): ...@@ -308,10 +313,12 @@ class Matmul_ATB_2D(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors A, B = ctx.saved_tensors
A_grad = Matmul_ABT_2D.forward(
None, with torch.no_grad():
A_grad = Matmul_ABT_2D.apply(
B, output_grad, B, output_grad,
ctx.summa_dim, ctx.A_shape, ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.row_rank, ctx.col_rank,
...@@ -322,8 +329,7 @@ class Matmul_ATB_2D(torch.autograd.Function): ...@@ -322,8 +329,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
ctx.pipeline_parallel_size, ctx.pipeline_parallel_size,
ctx.tensor_parallel_size ctx.tensor_parallel_size
) )
B_grad = Matmul_AB_2D.forward( B_grad = Matmul_AB_2D.apply(
None,
A, output_grad, A, output_grad,
ctx.summa_dim, ctx.B_shape, ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.row_rank, ctx.col_rank,
...@@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function): ...@@ -341,6 +347,7 @@ class Add_Bias_2D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b` """Matrix add bias: :math:`C = A + b`
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
input: Tensor, input: Tensor,
bias: Tensor, bias: Tensor,
...@@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function): ...@@ -384,6 +391,7 @@ class Add_Bias_2D(torch.autograd.Function):
return output return output
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_rank = ctx.row_rank row_rank = ctx.row_rank
col_rank = ctx.col_rank col_rank = ctx.col_rank
...@@ -423,6 +431,7 @@ class Add_Bias_2D(torch.autograd.Function): ...@@ -423,6 +431,7 @@ class Add_Bias_2D(torch.autograd.Function):
class _LayerNorm_2D(torch.autograd.Function): class _LayerNorm_2D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, def forward(ctx: Any,
input: Tensor, input: Tensor,
E_x: Tensor, E_x: Tensor,
...@@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function): ...@@ -440,6 +449,7 @@ class _LayerNorm_2D(torch.autograd.Function):
return output return output
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_parallel_mode = ctx.row_parallel_mode row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode col_parallel_mode = ctx.col_parallel_mode
...@@ -492,6 +502,7 @@ class _LayerNorm_2D(torch.autograd.Function): ...@@ -492,6 +502,7 @@ class _LayerNorm_2D(torch.autograd.Function):
class _ViT_Split_Input_2D(torch.autograd.Function): class _ViT_Split_Input_2D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, def forward(ctx: Any,
inputs: Tensor, inputs: Tensor,
batch_size: int, batch_size: int,
...@@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function): ...@@ -509,6 +520,7 @@ class _ViT_Split_Input_2D(torch.autograd.Function):
return output return output
@staticmethod @staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [b/q, s, h/q] # output_grad: [b/q, s, h/q]
# grads: [b, s, h/q] # grads: [b, s, h/q]
......
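These decorators are what make the 2D tensor-parallel matmuls usable under `torch.cuda.amp.autocast`: `@custom_fwd(cast_inputs=torch.float16)` casts the inputs of the custom autograd function to half precision whenever autocast is active, `@custom_bwd` runs the backward under the matching autocast state, and calling the sibling functions through `.apply` inside `torch.no_grad()` keeps the hand-written backward from being recorded again by autograd. A self-contained sketch of the same pattern on a plain matmul (requires a CUDA device; this is not code from the PR):

```python
import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class AmpFriendlyMatmul(torch.autograd.Function):
    """Toy autograd function using the same custom_fwd/custom_bwd pattern as the 2D layers."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        with torch.no_grad():
            grad_a = grad_out @ b.t()   # dC/dA = dC @ B^T
            grad_b = a.t() @ grad_out   # dC/dB = A^T @ dC
        return grad_a, grad_b


# under autocast, the fp32 inputs are cast to fp16 before entering forward
a = torch.randn(8, 16, device='cuda', requires_grad=True)
b = torch.randn(16, 4, device='cuda', requires_grad=True)
with torch.cuda.amp.autocast():
    out = AmpFriendlyMatmul.apply(a, b)
out.float().sum().backward()
```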
 from .cosine import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR
-from .linear import LinearWarmupLR, LinearWarmupDecay
+from .linear import LinearWarmupLR
 from .multistep import MultiStepLR, MultiStepWarmupLR
 from .onecycle import OneCycleLR
 from .poly import PolynomialLR, PolynomialWarmupLR
...
...@@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler): ...@@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
:type last_epoch: int, optional :type last_epoch: int, optional
""" """
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1, def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
**kwargs):
base_scheduler = _CosineAnnealingLR( base_scheduler = _CosineAnnealingLR(
optimizer, total_steps - warmup_steps, eta_min=eta_min) optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) super().__init__(optimizer, warmup_steps, base_scheduler)
@LR_SCHEDULERS.register_module @LR_SCHEDULERS.register_module
......
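For context, a usage sketch of the warmup-plus-cosine schedule changed above. The constructor arguments follow the new signature in this hunk; the import path is an assumption, since it is not shown in this diff.

```python
import torch
from torch.optim import SGD
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR  # assumed import path

model = torch.nn.Linear(8, 8)
optimizer = SGD(model.parameters(), lr=0.1)

# 1000 optimiser steps in total: linear warmup over the first 100,
# then cosine decay towards eta_min over the remaining 900.
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=1000,
                                    warmup_steps=100, eta_min=1e-5)

for _ in range(1000):
    optimizer.step()
    scheduler.step()
```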
...@@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler): ...@@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler):
:param last_epoch: The index of last epoch, defaults to -1 :param last_epoch: The index of last epoch, defaults to -1
:type last_epoch: int, optional :type last_epoch: int, optional
""" """
def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
if warmup_epochs < 0: self.warmup_epochs = int(warmup_epochs)
raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
self.warmup_epochs = warmup_epochs
self.after_scheduler = after_scheduler self.after_scheduler = after_scheduler
self.finished = False self.finished = False
super().__init__(optimizer, last_epoch) super().__init__(optimizer, last_epoch)
...@@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler): ...@@ -79,14 +76,10 @@ class WarmupScheduler(_LRScheduler):
if self.last_epoch >= self.warmup_epochs: if self.last_epoch >= self.warmup_epochs:
if not self.finished: if not self.finished:
self.after_scheduler.base_lrs = self.base_lrs self.after_scheduler.base_lrs = self.base_lrs
# reset lr to base_lr
for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
group['lr'] = base_lr
self.finished = True self.finished = True
with _enable_get_lr_call(self.after_scheduler):
return self.after_scheduler.get_lr() return self.after_scheduler.get_lr()
return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs] return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]
def step(self, epoch=None): def step(self, epoch=None):
if self.finished: if self.finished:
......
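A quick pure-Python check of the new warmup rule in the hunk above: during warmup the factor is `(last_epoch + 1) / warmup_epochs`, so the learning rate now reaches `base_lr` on the final warmup step, just before the wrapped scheduler takes over (the old `(last_epoch + 1) / (warmup_epochs + 1)` rule only reached it after the hand-off).

```python
base_lr = 0.1
warmup_epochs = 5

warmup_lrs = [round((t + 1) / warmup_epochs * base_lr, 3) for t in range(warmup_epochs)]
print(warmup_lrs)  # [0.02, 0.04, 0.06, 0.08, 0.1]
```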
...@@ -28,18 +28,3 @@ class LinearWarmupLR(_LRScheduler): ...@@ -28,18 +28,3 @@ class LinearWarmupLR(_LRScheduler):
else: else:
return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr for lr in return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr for lr in
self.base_lrs] self.base_lrs]
@LR_SCHEDULERS.register_module
class LinearWarmupDecay(_LRScheduler):
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, last_epoch: int = -1, **kwargs):
self.warmup_steps = int(warmup_steps)
self.total_steps = total_steps
super().__init__(optimizer, last_epoch=last_epoch)
def get_lr(self):
if self.last_epoch < self.warmup_steps:
return [(self.last_epoch + 1) / self.warmup_steps * lr for lr in self.base_lrs]
else:
return [(self.total_steps - self.last_epoch - 1) / (self.total_steps - self.warmup_steps) * lr for lr in
self.base_lrs]
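The decay branch of `LinearWarmupLR` kept above drops the learning rate linearly from `base_lr` towards zero as `last_epoch` approaches `total_steps`. A small illustrative sketch (numbers are arbitrary, not from the source):

```python
base_lr, total_steps, warmup_steps = 0.1, 10, 2

decay_lrs = [(total_steps - t) / (total_steps - warmup_steps) * base_lr
             for t in range(warmup_steps, total_steps)]
# first value equals base_lr (at t == warmup_steps);
# last value equals base_lr / (total_steps - warmup_steps), just before total_steps
```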
...@@ -27,12 +27,7 @@ class MultiStepLR(_MultiStepLR): ...@@ -27,12 +27,7 @@ class MultiStepLR(_MultiStepLR):
:type last_epoch: int, optional :type last_epoch: int, optional
""" """
def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, last_epoch: int = -1, **kwargs):
num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
if num_steps_per_epoch <= 0:
raise ValueError(
f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
milestones = [v * num_steps_per_epoch for v in milestones]
super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch) super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
...@@ -57,14 +52,11 @@ class MultiStepWarmupLR(WarmupScheduler): ...@@ -57,14 +52,11 @@ class MultiStepWarmupLR(WarmupScheduler):
""" """
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None, def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None,
gamma: float = 0.1, num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs): gamma: float = 0.1, last_epoch: int = -1, **kwargs):
if len(milestones) == 0: if len(milestones) == 0:
raise ValueError('milestones cannot be empty') raise ValueError('milestones cannot be empty')
if num_steps_per_epoch <= 0: milestones = [
raise ValueError( v - warmup_steps for v in milestones if v >= warmup_steps]
f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
milestones = [v * num_steps_per_epoch - warmup_steps for v in milestones if v *
num_steps_per_epoch >= warmup_steps]
base_scheduler = _MultiStepLR(optimizer, milestones=milestones, base_scheduler = _MultiStepLR(optimizer, milestones=milestones,
gamma=gamma) gamma=gamma)
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
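The milestone handling introduced above interprets milestones directly as optimiser steps and shifts them so the wrapped `_MultiStepLR` starts counting after warmup; milestones that fall inside the warmup window are dropped. A small sketch with illustrative numbers:

```python
warmup_steps = 100
milestones = [50, 300, 600]  # raw milestones, in optimiser steps

adjusted = [v - warmup_steps for v in milestones if v >= warmup_steps]
print(adjusted)  # [200, 500] -> the milestone inside the warmup window is discarded
```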
from torch.optim.lr_scheduler import LambdaLR as _LambdaLR from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
from torch.optim.lr_scheduler import StepLR as _StepLR from torch.optim.lr_scheduler import StepLR as _StepLR
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
from colossalai.registry import LR_SCHEDULERS from colossalai.registry import LR_SCHEDULERS
...@@ -25,11 +25,8 @@ class LambdaLR(_LambdaLR): ...@@ -25,11 +25,8 @@ class LambdaLR(_LambdaLR):
:type last_epoch: int, optional :type last_epoch: int, optional
""" """
def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1, def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
last_epoch: int = -1) -> None: super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
def func(step): return lr_lambda(step // num_steps_per_epoch)
super().__init__(optimizer, func, last_epoch=last_epoch)
@LR_SCHEDULERS.register_module @LR_SCHEDULERS.register_module
...@@ -51,11 +48,8 @@ class MultiplicativeLR(_MultiplicativeLR): ...@@ -51,11 +48,8 @@ class MultiplicativeLR(_MultiplicativeLR):
:type last_epoch: int, optional :type last_epoch: int, optional
""" """
def __init__(self, optimizer, total_steps, lr_lambda=None, num_steps_per_epoch: int = -1, def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1) -> None:
last_epoch: int = -1) -> None: super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
def func(step): return lr_lambda(step // num_steps_per_epoch)
super().__init__(optimizer, func, last_epoch=last_epoch)
@LR_SCHEDULERS.register_module @LR_SCHEDULERS.register_module
...@@ -79,14 +73,13 @@ class StepLR(_StepLR): ...@@ -79,14 +73,13 @@ class StepLR(_StepLR):
:type last_epoch: int, optional :type last_epoch: int, optional
""" """
def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, num_steps_per_epoch: int = -1, def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.1, last_epoch: int = -1) -> None:
last_epoch: int = -1) -> None: super().__init__(optimizer, step_size,
super().__init__(optimizer, step_size * num_steps_per_epoch,
gamma=gamma, last_epoch=last_epoch) gamma=gamma, last_epoch=last_epoch)
@LR_SCHEDULERS.register_module @LR_SCHEDULERS.register_module
class ExponentialLR(_LRScheduler): class ExponentialLR(_ExponentialLR):
"""Decays the learning rate of each parameter group by gamma every epoch. """Decays the learning rate of each parameter group by gamma every epoch.
When last_epoch=-1, sets initial lr as lr When last_epoch=-1, sets initial lr as lr
...@@ -102,21 +95,6 @@ class ExponentialLR(_LRScheduler): ...@@ -102,21 +95,6 @@ class ExponentialLR(_LRScheduler):
:type last_epoch: int, optional :type last_epoch: int, optional
""" """
def __init__(self, optimizer, total_steps, gamma: float = 1.0, num_steps_per_epoch: int = -1, def __init__(self, optimizer, total_steps, gamma: float = 1.0,
last_epoch: int = -1) -> None: last_epoch: int = -1) -> None:
self.gamma = gamma super().__init__(optimizer, gamma, last_epoch=last_epoch)
self.num_steps_per_epoch = num_steps_per_epoch
super().__init__(optimizer, last_epoch=last_epoch)
def get_lr(self):
if self.last_epoch == 0:
return self.base_lrs
elif (self.last_epoch + 1) % self.num_steps_per_epoch == 0:
return [group['lr'] * self.gamma
for group in self.optimizer.param_groups]
return [group['lr']
for group in self.optimizer.param_groups]
def _get_closed_form_lr(self):
return [base_lr * self.gamma ** (self.last_epoch // self.num_steps_per_epoch)
for base_lr in self.base_lrs]
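After this hunk, `ExponentialLR` is a thin wrapper over `torch.optim.lr_scheduler.ExponentialLR`, so the multiplicative decay by `gamma` happens on every `scheduler.step()` call; the old per-epoch bookkeeping via `num_steps_per_epoch` is gone. A minimal sketch of the inherited behaviour, using the torch base class directly:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=1.0)
scheduler = ExponentialLR(optimizer, gamma=0.9)

lrs = []
for _ in range(3):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
print(lrs)  # [0.9, 0.81, 0.729] (up to float rounding)
```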