Commit 404ecbdc authored by zbian

Migrated project

parent 2ebaefc5
from ._base_schedule import BaseSchedule
from ._no_pipeline import NoPipelineSchedule
from ._pipeline import PipelineSchedule
__all__ = ['BaseSchedule', 'NoPipelineSchedule', 'PipelineSchedule']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
import torch
from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device
class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation.
"""
def __init__(self):
self.initialized = False
self.logger = get_global_dist_logger()
self.data_iter = None
@property
@abstractmethod
def num_steps(self):
"""The number of batches in training or evaluation.
"""
pass
def initialize(self,
dataloader=None,
model=None,
criterion=None,
optimizer=None,
lr_scheduler=None):
"""Initializes the schedule and set parameters before running.
:param dataloader: DataLoader in training or evaluation
:param model: The neural network model
:param criterion: Criterion for calculating loss
:param optimizer: Optimizer for updating the parameters
:param lr_scheduler: Learning rate scheduler in the process
"""
self.dataloader = dataloader
assert model is not None, "Schedule requires a model"
self.model = model
assert criterion is not None, "Schedule requires a criterion"
self.criterion = criterion
assert optimizer is not None, "Schedule requires an optimizer"
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.initialized = True
def check_initialized(self):
"""Checks whether the schedule is initialized.
"""
assert self.initialized, \
'Schedule is not initialized. Call schedule.initialize(...) before using it.'
def load_batch(self):
"""Loads a batch of dataset. It returns the data and labels which are
already in the same GPU as where the model's.
:return: (data, label)
:rtype: (Tensor, Tensor)
"""
self.check_initialized()
if self.data_iter is None:
raise RuntimeError('Dataloader is not set. Call schedule.train(dataloader) before loading a batch.')
data, label = next(self.data_iter)
return self._move_to_device(data), self._move_to_device(label)
def _move_to_device(self, data):
if isinstance(data, (
tuple,
list,
)):
data = tuple([
d.to(get_current_device()).detach() for d in data
if torch.is_tensor(d)
])
elif torch.is_tensor(data):
data = data.to(get_current_device()).detach()
return data
def train(self, dataloader=None, mode=True):
"""Sets the dataloader to be used and turn the model to
training or evaluation mode.
:param dataloader: Dataloader to be used
:param mode: If True, the model will set as training mode. Otherwise, evaluation mode.
"""
self.check_initialized()
if mode:
self.model.train()
else:
self.model.eval()
if dataloader is not None:
self.dataloader = dataloader
self.data_iter = iter(dataloader)
def zero_grad(self, forward_only=False):
"""Cleans gradients with the optimizer.
"""
if not forward_only:
self.check_initialized()
self.optimizer.zero_grad()
def get_lr(self):
"""Returns the current learning rate.
"""
if self.lr_scheduler is not None:
return self.lr_scheduler.get_lr()[0]
else:
return self.optimizer.param_groups[0]['lr']
def step(self):
"""Updates the parameters and learning rate with the optimizer.
"""
self.check_initialized()
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
@abstractmethod
def forward_backward_step(self, forward_only=False, return_loss=True):
"""The process function over a batch of dataset for training or evaluation.
:param forward_only: If True, the process won't include backward.
:param return_loss: If False, the loss won't be returned.
"""
pass
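# Illustrative sketch (added for clarity; not part of the original commit):
# a minimal, hypothetical BaseSchedule subclass plus a driver loop showing the
# call order the class above expects: initialize() -> train() ->
# forward_backward_step() -> step(). Every name introduced here (_ToySchedule,
# _example_training_loop and its arguments) is a placeholder, and the sketch
# assumes init_global_dist_logger() has already been called and that the
# dataset yields ((inputs,), (labels,)) tuples as the datasets in this commit do.


class _ToySchedule(BaseSchedule):
    """A toy schedule running a plain forward/backward pass per batch."""

    @property
    def num_steps(self):
        return len(self.dataloader)

    def forward_backward_step(self, forward_only=False, return_loss=True):
        data, label = self.load_batch()
        output = self.model(*data)
        loss = self.criterion(output, *label) if return_loss else None
        if not forward_only:
            loss.backward()
        return output, label, loss


def _example_training_loop(dataloader, model, criterion, optimizer):
    schedule = _ToySchedule()
    schedule.initialize(dataloader, model, criterion, optimizer)
    schedule.train(dataloader, mode=True)  # sets model.train() and the data iterator
    for _ in range(schedule.num_steps):
        schedule.zero_grad()
        output, label, loss = schedule.forward_backward_step()
        schedule.step()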
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
try:
import apex.amp as apex_amp
except ImportError:
print('apex is required for mixed precision training')
try:
import torch.cuda.amp as torch_amp
except ImportError:
print('PyTorch amp is not supported with the current PyTorch version')
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.amp_type import AMP_TYPE
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from ._utils import convert_to_fp16
from ._base_schedule import BaseSchedule
class NoPipelineSchedule(BaseSchedule):
"""A helper schedule class for no pipeline parallelism running environment.
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed precision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(
self,
amp_type: AMP_TYPE = None,
amp_config: dict = None,
):
super().__init__()
# mixed precision training
assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
'unrecognised value for argument amp_type, it can only be None or a member of AMP_TYPE'
# LSG: check compatibility
# LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
ParallelMode.TENSOR) > 1:
assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
'You can only use AMP_TYPE.PARALLEL for tensor parallel training'
self.use_zero_level_2_3 = False
if amp_type is not None:
self.fp16 = True
self.amp_type = amp_type
if amp_config is not None:
assert isinstance(amp_config, dict), \
f'expected argument amp_config to be of type dict, but got {type(amp_config)}'
if self.amp_type == AMP_TYPE.TORCH:
# torch amp
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.APEX:
# apex amp
if amp_config is None:
amp_config = dict(opt_level='O2')
self.logger.warning(
'apex is deprecated, please consider using torch.cuda.amp instead.'
)
self.amp_cfg = amp_config
elif self.amp_type == AMP_TYPE.PARALLEL:
# use fp16 optimizer for tensor parallel training
if amp_config is None:
amp_config = dict()
self.amp_cfg = amp_config
else:
self.fp16 = False
self.amp_type = None
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
self.use_zero_level_2_3 = True
assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
if self.fp16:
if self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
elif self.amp_type == AMP_TYPE.APEX:
self.model, self.optimizer = apex_amp.initialize(
self.model, self.optimizer, **self.amp_cfg)
def forward_backward_step(self, forward_only=False, return_loss=True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
data, label = self.load_batch()
loss = None
# LSG: leave for debug, make sure dataloader is deterministic
# if forward_only:
# img = data[0]
# rank = gpc.get_local_rank(ParallelMode.DATA)
# world_size = gpc.get_world_size(ParallelMode.DATA)
# group = gpc.get_group(ParallelMode.DATA)
# input_list = [img.clone() for _ in range(world_size)]
# output_list = [torch.empty_like(img) for _ in range(world_size)]
# output_list[rank] = img.clone()
# dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
# assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])
# forward
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
with torch_amp.autocast():
output = self.model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
else:
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
data = convert_to_fp16(data)
output = self.model(*data)
if not isinstance(output, (tuple, list)):
output = (output,)
if return_loss:
loss = self.criterion(*output, *label)
if not forward_only:
# backward
if self.use_zero_level_2_3:
self.optimizer.backward(loss)
elif self.fp16:
if self.amp_type == AMP_TYPE.APEX:
with apex_amp.scale_loss(loss,
self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.scale(loss).backward()
elif self.amp_type == AMP_TYPE.PARALLEL:
loss = self.optimizer.scale_loss(loss)
loss.backward()
# scale back to display the original value in logs
loss.div_(self.optimizer.grad_scaler.scale)
else:
loss.backward()
if return_loss:
return output, label, loss
else:
return output, None, None
def step(self):
# step optimizer
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
self._torch_amp_scaler.step(self.optimizer)
self._torch_amp_scaler.update()
else:
self.optimizer.step()
# update lr scheduler
if self.lr_scheduler is not None:
self.lr_scheduler.step()
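# Illustrative sketch (added for clarity; not part of the original commit):
# how NoPipelineSchedule is typically driven once the dataloader, model,
# criterion and optimizer have been built, here assuming torch.cuda.amp mixed
# precision. The function name and its arguments are placeholders.


def _example_no_pipeline_epoch(dataloader, model, criterion, optimizer):
    schedule = NoPipelineSchedule(amp_type=AMP_TYPE.TORCH)
    schedule.initialize(dataloader, model, criterion, optimizer)
    schedule.train(dataloader, mode=True)
    for _ in range(schedule.num_steps):
        schedule.zero_grad()
        # one forward + backward pass over a single batch
        output, label, loss = schedule.forward_backward_step(
            forward_only=False, return_loss=True)
        # GradScaler step (if enabled), optimizer step and lr scheduler update
        schedule.step()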
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union
import torch.cuda
import torch.distributed as dist
from torch import Tensor
from colossalai.communication import *
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp_type import AMP_TYPE
def squeeze(x: Union[Tensor, tuple, list]):
if isinstance(x, (tuple, list)):
return x[0]
else:
return x
class PipelineSchedule(BaseSchedule):
"""A helper schedule class for pipeline parallelism running environment.
It uses the non-interleaved 1F1B strategy. Other properties are similar to
:class:`NoPipelineSchedule`.
:param num_microbatches: The number of microbatches
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed precision
:type num_microbatches: int
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def __init__(self,
num_microbatches,
amp_type: AMP_TYPE = None,
amp_config: dict = None):
super().__init__()
self.num_microbatches = num_microbatches
self.data_sync = True  # can be turned off once the data is verified to be identical across pipeline stages
# amp
# LSG: amp_config is not used, but left here for future extension
self.amp_type = amp_type
self.amp_config = amp_config
if self.amp_type is not None:
assert self.amp_type == AMP_TYPE.PARALLEL, 'We only support AMP_TYPE.PARALLEL for pipeline training for now'
def _move_to_device(self, data):
if isinstance(data, (
tuple,
list,
)):
assert len(data) == 1, "Data tuple's length in pipeline should be 1"
data = data[0]
assert torch.is_tensor(data), "Data in pipeline should be tensor"
data = data.to(get_current_device()).detach()
return data
def _sync_data(self):
if gpc.is_first_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_global_rank()
dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_PREV)
)
if gpc.is_last_rank(ParallelMode.PIPELINE):
src_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
dist.broadcast(
tensor=self.batch_data,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
dist.broadcast(
tensor=self.batch_label,
src=src_rank,
group=gpc.get_group(ParallelMode.PIPELINE_NEXT)
)
# Pipeline schedule just puts data in memory
def load_batch(self):
self.check_initialized()
if self.data_iter is None:
raise RuntimeError('Dataloader is not set. Call schedule.train(dataloader) before loading a batch.')
self.batch_pos = 0
data, label = next(self.data_iter)
self.batch_data, self.batch_label = \
self._move_to_device(data), self._move_to_device(label)
batch_size = self.batch_data.shape[0]
assert batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = batch_size // self.num_microbatches
if self.data_sync:
self._sync_data()
def _get_data_slice(self, tensor):
return tensor[self.batch_pos: self.batch_pos + self.microbatch_size]
def load_micro_batch(self):
data = self._get_data_slice(self.batch_data)
label = self._get_data_slice(self.batch_label)
self.batch_pos += self.microbatch_size
return (data,), (label,)
@property
def num_steps(self):
return len(self.dataloader)
def initialize(self,
dataloader,
model,
criterion,
optimizer,
lr_scheduler=None):
super().initialize(dataloader,
model,
criterion,
optimizer,
lr_scheduler=lr_scheduler)
if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
)
# LSG: set default dtype to fp16 for communication
if self.amp_type == AMP_TYPE.PARALLEL:
torch.set_default_dtype(torch.half)
self.logger.info(
'default tensor dtype is set to torch.half for fp16 training',
ranks=[0])
def forward_step(self, input_tensor, return_tensors, return_loss=True):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
"""
if input_tensor is None:
input_tensor, label = self.load_micro_batch()
if self.amp_type == AMP_TYPE.PARALLEL:
input_tensor = convert_to_fp16(input_tensor)
input_tensor = squeeze(input_tensor)
output_tensor = self.model(input_tensor)
output_tensor = squeeze(output_tensor)
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_loss:
input_tensor, label = self.load_micro_batch()
loss_reduced = self.criterion(output_tensor, *
label) / self.num_microbatches
return_tensors.append(
tuple((output_tensor, label[0], loss_reduced)))
return loss_reduced
else:
return_tensors.append(output_tensor)
return output_tensor
else:
return output_tensor
def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
This is a helper function and can be ignored by users.
"""
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
output_tensor = self.optimizer.scale_loss(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
def forward_backward_step(self, forward_only=False, return_loss=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
:return: (output, label, loss)
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch()
num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE) -
gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
num_warmup_microbatches = min(num_warmup_microbatches,
self.num_microbatches)
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
if not forward_only:
input_tensors = []
output_tensors = []
return_tensors = []
# Used for tensor meta information communication
ft_shape = None
bt_shape = None
fs_checker = True
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = send_tensor_meta(output_tensor, fs_checker)
send_forward(output_tensor)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = recv_tensor_meta(ft_shape)
input_tensor = recv_forward(ft_shape)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self.forward_step(input_tensor,
return_tensors,
return_loss=return_loss)
if forward_only:
send_forward(output_tensor)
if not last_iteration:
input_tensor = recv_forward(ft_shape)
else:
output_tensor_grad = send_forward_recv_backward(
output_tensor, bt_shape)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad)
else:
input_tensor = send_backward_recv_forward(
input_tensor_grad, ft_shape)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(bt_shape)
input_tensor_grad = self.backward_step(input_tensor,
output_tensor,
output_tensor_grad)
send_backward(input_tensor_grad)
if len(return_tensors) > 0:
if return_loss:
output, label, loss = tuple(map(list, zip(*return_tensors)))
return (torch.cat(output, dim=0),
torch.cat(label, dim=0),
sum(loss))
else:
return tuple((torch.cat(return_tensors, dim=0), None, None))
else:
return tuple((None, None, None))
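# Worked example (added for clarity; not part of the original commit): with a
# 4-stage pipeline and 8 microbatches, the split computed in
# forward_backward_step above is, per pipeline rank r:
#   num_warmup_microbatches    = min(4 - r - 1, 8)  -> 3, 2, 1, 0
#   num_microbatches_remaining = 8 - warmup         -> 5, 6, 7, 8
# so the last stage runs pure 1F1B from the start, while earlier stages queue
# extra forward passes before alternating one forward with one backward step.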
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Union, List
from torch import Tensor
def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
if isinstance(data, Tensor):
ret = data.half()
elif isinstance(data, (list, tuple)):
ret = [val.half() for val in data]
else:
raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
return ret
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import argparse
import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Union
import numpy as np
import torch
from torch.utils.data import DataLoader
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss,
build_lr_scheduler, build_model, build_optimizer,
build_optimizer_wrapper)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp
def parse_args():
'''Reads user command line and uses an argument parser to parse the input arguments.
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
:return: the parsed arguments
:rtype: Namespace
'''
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file')
parser.add_argument('--host',
type=str,
default=None,
help='the master address for distributed training')
parser.add_argument('--port',
type=str,
default=None,
help='the master port for distributed training')
parser.add_argument('--world_size', type=int, help='world size for distributed training')
parser.add_argument('--local_rank',
type=int,
help='rank for the default process group')
parser.add_argument('--backend',
type=str,
default='nccl',
help='backend for torch.distributed')
return parser.parse_args()
def init_dist(config: Union[str, dict] = None,
local_rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
backend: str = None):
'''This function first parses the configuration arguments, using :func:`parse_args` when any of the input arguments is not given.
Then it initializes and sets up the distributed environment by calling global_context's functions.
:param config: config file or config file path are both acceptable
:type config: Union[str, dict], optional
:param local_rank: rank for the default process group, defaults to None
:type local_rank: int, optional
:param world_size: world size of GPUs, defaults to None
:type world_size: int, optional
:param host: the master address for distributed training, defaults to None
:type host: str, optional
:param port: the master port for distributed training, defaults to None
:type port: str, optional
:param backend: backend for torch.distributed, defaults to None
:type backend: str, optional
:raises Exception: raise exception when config type is wrong
'''
args = [config, local_rank, world_size, host, port, backend]
arg_given = [arg is not None for arg in args]
if not all(arg_given):
args = parse_args()
if config is None:
config = args.config
if local_rank is None:
local_rank = args.local_rank
if world_size is None:
world_size = args.world_size
if host is None:
host = args.host
if port is None:
port = args.port
if backend is None:
backend = args.backend
args = Config(
dict(config=config,
host=host,
port=port,
world_size=world_size,
local_rank=local_rank,
backend=backend))
# set distributed settings
dist_args = Config(
dict(local_rank=args.local_rank,
world_size=args.world_size,
backend=args.backend))
gpc.set_dist_args(dist_args)
# set config
if isinstance(args.config, dict):
cfg = args.config
elif isinstance(args.config, (str, Path)):
cfg = Config.from_file(args.config)
else:
raise Exception('Config type error: {}'.format(type(args.config)))
gpc.load_config(cfg)
# init dist groups
gpc.init_global_dist(args.host, args.port)
gpc.init_parallel_groups()
# init dist logger
init_global_dist_logger()
# set cuda device
if torch.cuda.is_available():
gpc.set_device()
def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
'''Sets up a deterministic dataloader (also configures seed workers, samplers and whether to shuffle or not).
.. note: when pipeline parallel is enabled, shuffle cannot be True
as it will result in mismatch between input data on the 1st
stage and label on the last stage
:param dataset: a :class:`torch.utils.data.Dataset` dataset
:param seed: random worker seed, defaults to 1024
:type seed: int, optional
:param add_sampler_if_possible: whether to add a :class:`DataParallelSampler` when data parallelism is in use, defaults to False
:type add_sampler_if_possible: bool, optional
:return: a dataloader built from the given dataset
:rtype: torch.utils.data.DataLoader
'''
_kwargs = kwargs.copy()
if 'shuffle' in _kwargs:
shuffle = _kwargs.pop('shuffle')
else:
shuffle = False
if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
sampler = DataParallelSampler(dataset, shuffle=shuffle)
else:
sampler = None
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
if sampler is None:
return DataLoader(dataset,
worker_init_fn=seed_worker,
shuffle=shuffle,
**_kwargs)
else:
return DataLoader(dataset,
sampler=sampler,
worker_init_fn=seed_worker,
**_kwargs)
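# Illustrative sketch (added for clarity; not part of the original commit):
# building a deterministic, optionally distributed-sampled dataloader from a
# plain torchvision dataset. The CIFAR10 construction and all keyword values
# below are assumptions made purely for illustration.


def _example_build_loader():
    from torchvision import datasets, transforms
    dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                               transform=transforms.ToTensor())
    # `shuffle` is forwarded to the DataParallelSampler when data parallelism
    # is active, otherwise it is passed straight to the DataLoader
    return get_dataloader(dataset,
                          seed=1024,
                          add_sampler_if_possible=True,
                          batch_size=256,
                          shuffle=True,
                          num_workers=4,
                          pin_memory=True)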
def initialize(config: Union[str, dict] = None,
local_rank: int = None,
world_size: int = None,
host: str = None,
port: str = None,
backend: str = None,
train_dataloader: Optional[Union[Iterable, Callable]] = None,
test_dataloader: Optional[Union[Iterable, Callable]] = None,
):
'''Core function that initializes the distributed environment, logger, cudnn, data, model, loss function, optimizer and lr_scheduler (their configs are in gpc.config).
:param config: config file or config file path are both acceptable
:type config: Union[str, dict], optional
:param local_rank: rank for the default process group, defaults to None
:type local_rank: int, optional
:param world_size: world size of GPUs, defaults to None
:type world_size: int, optional
:param host: the master address for distributed training, defaults to None
:type host: str, optional
:param port: the master port for distributed training, defaults to None
:type port: str, optional
:param backend: backend for torch.distributed, defaults to None
:type backend: str, optional
:param train_dataloader: If None, the config is used to build a dataloader; else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
:type train_dataloader: Optional[Union[Iterable, Callable]], optional
:param test_dataloader: If None, the config is used to build a dataloader; else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
:type test_dataloader: Optional[Union[Iterable, Callable]], optional
:return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler)
:rtype: tuple
'''
# initialize distributed environment
init_dist(config=config,
local_rank=local_rank,
world_size=world_size,
host=host,
port=port,
backend=backend)
# init logger
logger = get_global_dist_logger()
logger.info(f'Distributed environment is initialized, '
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
# print config
logger.info(f"\n========== Your Config ========\n"
f"{pprint.pformat(gpc.config)}\n"
f"================================", ranks=[0])
# cudnn
cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.backends.cudnn.deterministic = cudnn_deterministic
logger.info(
f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
# set seed, cuda seed is only set when cuda is avail
gpc.set_seed()
# return_items = list()
# check fp16 and zero
should_convert_model_to_half = False
should_wrap_fp16_optimizer = False
should_wrap_zero_optimizer_level_2_3 = False
if hasattr(gpc.config, 'fp16'):
fp16_mode = gpc.config.fp16.mode
if fp16_mode == AMP_TYPE.PARALLEL:
should_convert_model_to_half = True
should_wrap_fp16_optimizer = True
if hasattr(gpc.config, 'zero'):
should_wrap_zero_optimizer_level_2_3 = True
zero_type = gpc.config.zero.type
if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
should_convert_model_to_half = True
assert not should_wrap_fp16_optimizer, \
'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'
# build model
logger.info('Building model ...', ranks=[0])
assert hasattr(
gpc.config, 'model'), "Build error: configuration 'model' is missing"
if gpc.pipeline_parallel_size > 1:
model = ModelInitializer(gpc.config.model, 1, verbose=True)
model = model.model_initialize()
else:
model = build_model(gpc.config.model)
if isinstance(model, BaseModel):
model.build_from_cfg()
model = model.to(get_current_device())
sync_model_param_in_dp(model)
logger.info('Model is created', ranks=[0])
if should_convert_model_to_half:
model = model.half()
logger.info("Model is cast to fp16", ranks=[0])
# training data
if callable(train_dataloader):
logger.info(
f'Build train data loader from {train_dataloader}', ranks=[0])
train_dataloader = train_dataloader()
if train_dataloader is None and hasattr(gpc.config, 'train_data'):
logger.info('Preparing data ...', ranks=[0])
# assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing."
train_dataset = build_dataset(gpc.config.train_data.dataset)
logger.info('Train dataset is ready.', ranks=[0])
train_dataloader = get_dataloader(train_dataset,
gpc.config.get('seed', 1024),
True,
**gpc.config.train_data.dataloader,
)
logger.info(
f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])
if callable(test_dataloader):
logger.info(
f'Build test data loader from {test_dataloader}', ranks=[0])
test_dataloader = test_dataloader()
# testing data, allowed to be None
if test_dataloader is None and hasattr(gpc.config, 'test_data'):
test_dataset = build_dataset(gpc.config.test_data.dataset)
test_dataloader = get_dataloader(
test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
logger.info(
f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])
# build loss function
assert hasattr(gpc.config, 'loss'), \
'Build error: configuration \'loss\' is missing.'
criterion = build_loss(gpc.config.loss)
logger.info('Loss function is created', ranks=[0])
# build optimizer
assert hasattr(gpc.config, 'optimizer'), \
"Build error: configuration 'optimizer' is missing."
optim_type = gpc.config.optimizer.type
is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
if is_pytorch_native_zero_level_1:
original_cfg_copy = gpc.config.optimizer.copy()
original_cfg_copy.pop('type')
cfg = dict(type=optim_type, process_group=gpc.get_group(
ParallelMode.DATA), **original_cfg_copy)
optimizer = build_optimizer(cfg, model)
else:
optimizer = build_optimizer(gpc.config.optimizer, model)
if should_wrap_zero_optimizer_level_2_3:
optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)
if should_wrap_fp16_optimizer:
# replace the field mode with type
fp16_cfg = gpc.config.fp16.copy()
amp_type = fp16_cfg.pop('mode')
assert amp_type == AMP_TYPE.PARALLEL, 'FP16 optimizer should only be used with AMP_TYPE.PARALLEL'
fp16_cfg['type'] = 'FP16Optimizer'
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
logger.info('Optimizer is created', ranks=[0])
lr_scheduler = None
if hasattr(gpc.config, 'lr_scheduler'):
if hasattr(gpc.config, 'num_steps'):
total_steps = gpc.config.num_steps
elif hasattr(gpc.config, 'num_epochs'):
total_steps = int(gpc.config.num_epochs * len(train_dataloader))
else:
raise Exception(
'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
total_steps, len(train_dataloader))
logger.info('Learning rate scheduler is created', ranks=[0])
# pipeline or no pipeline schedule
if hasattr(gpc.config, 'fp16'):
amp_type = gpc.config.fp16.mode
amp_cfg = gpc.config.fp16.copy()
amp_cfg.pop('mode')
else:
amp_type = None
amp_cfg = None
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
assert hasattr(gpc.config,
'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training"
schedule = PipelineSchedule(
amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy())
else:
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler
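# Illustrative sketch (added for clarity; not part of the original commit):
# a typical launch script built around initialize(). The config path and the
# single-epoch loop are placeholders; real scripts would also handle
# evaluation, checkpointing and multiple epochs.


def _example_main():
    (model, train_dataloader, test_dataloader, criterion,
     optimizer, schedule, lr_scheduler) = initialize(config='./config.py')
    schedule.initialize(train_dataloader, model, criterion,
                        optimizer, lr_scheduler=lr_scheduler)
    schedule.train(train_dataloader, mode=True)
    for _ in range(schedule.num_steps):
        schedule.zero_grad()
        output, label, loss = schedule.forward_backward_step(
            forward_only=False, return_loss=True)
        schedule.step()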
from colossalai.core import global_context as gpc
from .logging import DistributedLogger
__all__ = ['get_global_dist_logger', 'get_dist_logger', 'DistributedLogger', 'init_global_dist_logger']
_GLOBAL_LOGGER: DistributedLogger = None
def get_dist_logger(name, level='INFO', root_path: str = None, mode='a'):
return DistributedLogger(name=name, level=level, root_path=root_path, mode=mode)
def get_global_dist_logger():
assert _GLOBAL_LOGGER is not None, 'Global distributed logger is not initialized'
return _GLOBAL_LOGGER
def init_global_dist_logger():
rank = gpc.get_global_rank()
if hasattr(gpc.config, 'logging'):
logger = get_dist_logger(name=f'rank_{rank}', **gpc.config.logging)
else:
logger = get_dist_logger(name=f'rank_{rank}', level='INFO')
global _GLOBAL_LOGGER
assert _GLOBAL_LOGGER is None, 'Global distributed logger has already been initialized'
_GLOBAL_LOGGER = logger
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import logging
from pathlib import Path
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
_FORMAT = 'colossalai - %(name)s - %(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=_FORMAT)
class DistributedLogger:
"""This is a distributed event logger class essentially based on :class:`logging`.
:param name: The name of the logger
:type name: str
:param level: The threshold for the logger. Logging messages which are less severe than `level`
will be ignored
:type level: str
:param root_path: The root path where logs are stored
:type root_path: str, optional
:param mode: The mode that the file is opened in. Defaults to 'a'
:type mode: str, optional
"""
def __init__(self, name, level='INFO', root_path: str = None, mode='a'):
self._logger = logging.getLogger(name)
self._logger.setLevel(getattr(logging, level))
if root_path is not None:
log_root_path = Path(root_path)
# create path if not exists
log_root_path.mkdir(parents=True, exist_ok=True)
log_path = log_root_path.joinpath(f'{name}.log')
file_handler = logging.FileHandler(log_path, mode)
file_handler.setLevel(getattr(logging, level))
formatter = logging.Formatter(_FORMAT)
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
if ranks is None:
getattr(self._logger, level)(message)
else:
local_rank = gpc.get_local_rank(parallel_mode)
if local_rank in ranks:
getattr(self._logger, level)(message)
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores an info log message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores a warning log message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores a debug log message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Stores an error log message.
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('error', message, parallel_mode, ranks)
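# Illustrative sketch (added for clarity; not part of the original commit):
# once init_global_dist_logger() has been called, messages can be limited to
# specific ranks of a parallel group, as done throughout this commit with
# ranks=[0]. The function name below is a placeholder.


def _example_logging():
    from colossalai.logging import get_global_dist_logger
    logger = get_global_dist_logger()
    logger.info('training started', ranks=[0])  # only global rank 0 logs this
    logger.warning('loss is not finite',
                   parallel_mode=ParallelMode.DATA, ranks=[0])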
from .data import *
from .layer import *
from .loss import *
from .lr_scheduler import *
from .model import *
from .optimizer import *
from .caltech101_dataset import Caltech101Dataset
from .cifar10_dataset import CIFAR10Dataset
from .sampler import *
import numpy as np
def pil_img_to_numpy(pil_img):
"""convert a PIL image to numpy nd-array
:param pil_img: a PIL image
:type pil_img: PIL.Image
:return: a nd-array
:rtype: numpy.ndarray
"""
np_img = np.array(pil_img)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from colossalai.builder import build_transform
class BaseDataset(Dataset, ABC):
def __init__(self, transform_pipeline: list):
transform_list = [build_transform(cfg) for cfg in transform_pipeline]
transform = transforms.Compose(transform_list)
self._transform_pipeline = transform
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import Caltech101
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class Caltech101Dataset(BaseDataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
:param transform_pipeline: A list of transform configs; each transform takes in a PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
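# let rank 0 build (and, if requested, download) the dataset first; the other
# ranks wait at the barrier below and then construct it from the cached files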
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = Caltech101(
transform=self._transform_pipeline, *args, **kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (target,)) where the type of target is specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.distributed as dist
from torchvision.datasets import CIFAR10
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module
class CIFAR10Dataset(BaseDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
:param transform_pipeline: A list of transform configs; each transform takes in a PIL image
and returns a transformed version
:type transform_pipeline: list
"""
def __init__(self, transform_pipeline: list, *args, **kwargs):
super().__init__(transform_pipeline)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() != 0:
dist.barrier()
self._dataset = CIFAR10(transform=self._transform_pipeline,
*args,
**kwargs)
if gpc.is_initialized(ParallelMode.GLOBAL) and gpc.get_global_rank() == 0:
dist.barrier()
def __len__(self):
return len(self._dataset)
def __getitem__(self, item):
"""
:param item: Index
:type item: int
:return: ((image,), (target,)) where the type of target is specified by target_type.
:rtype: tuple
"""
img, label = self._dataset.__getitem__(item)
return (img,), (label,)
from .base_sampler import BaseSampler
from .data_parallel_sampler import DataParallelSampler
__all__ = ['BaseSampler', 'DataParallelSampler']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseSampler(ABC):
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
@abstractmethod
def __len__(self):
pass
@abstractmethod
def __iter__(self):
pass
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adapted from torch.utils.data.DistributedSampler
import math
from typing import TypeVar, Iterator
import torch
from torch.utils.data import Sampler, Dataset
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import SAMPLERS
T_co = TypeVar('T_co', covariant=True)
@SAMPLERS.register_module
class DataParallelSampler(Sampler):
"""A data sampler for distributed data parallelism
:param dataset: a Dataset instance
:type dataset: torch.utils.data.Dataset
:param shuffle: whether to shuffle data, defaults to False
:type shuffle: bool, optional
:param seed: the random seed, defaults to 0
:type seed: int, optional
:param drop_last: if True, drop the tail of the data to make it evenly divisible across the number of replicas; if False, extra indices will be added so the data is evenly divisible, defaults to False
:type drop_last: bool, optional
"""
def __init__(self,
dataset: Dataset,
shuffle: bool = False,
seed: int = 0,
drop_last: bool = False) -> None:
self.dataset = dataset
self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
self.rank = gpc.get_local_rank(ParallelMode.DATA)
self.epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
# type: ignore[arg-type]
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
# `type:ignore` is required because Dataset cannot provide a default __len__
# see NOTE in pytorch/torch/utils/data/sampler.py
(len(self.dataset) - self.num_replicas) / \
self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(
len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
def __iter__(self) -> Iterator[T_co]:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
# type: ignore[arg-type]
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size /
len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
:param epoch: Epoch number.
:type epoch: int
"""
self.epoch = epoch
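# Illustrative sketch (added for clarity; not part of the original commit):
# pairing DataParallelSampler with a DataLoader and advancing the epoch so the
# shuffle order changes every epoch. It assumes the data-parallel process
# group has already been initialized; dataset and batch size are placeholders.


def _example_sampler_usage(dataset):
    from torch.utils.data import DataLoader
    sampler = DataParallelSampler(dataset, shuffle=True, seed=42)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # reseeds the shuffle for this epoch
        for batch in loader:
            pass  # training step goes here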
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .parallel_vision_transformer import *
from .vanilla_resnet import *
from .vanilla_vision_transformer import *
from .wrapper import *