Unverified Commit 2b2dc1c8 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[pipeline] refactor the pipeline module (#1087)

* [pipeline] refactor the pipeline module

* polish code
parent bad5d4c0
from .builder import (build_schedule, build_lr_scheduler, build_model, from .builder import build_from_config, build_from_registry, build_gradient_handler
build_optimizer, build_layer, build_loss, build_hooks,
build_dataset, build_transform, build_data_sampler,
build_gradient_handler, build_ophooks)
from .pipeline import build_pipeline_model, build_pipeline_model_from_cfg
__all__ = [ __all__ = [
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_gradient_handler', 'build_from_config', 'build_from_registry'
'build_layer', 'build_loss', 'build_hooks', 'build_dataset',
'build_transform', 'build_data_sampler', 'build_gradient_handler',
'build_pipeline_model', 'build_pipeline_model_from_cfg', 'build_ophooks'
] ]
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import inspect import inspect
from collections.abc import Iterable
from colossalai.registry import * from colossalai.registry import *
...@@ -64,84 +63,6 @@ def build_from_registry(config, registry: Registry): ...@@ -64,84 +63,6 @@ def build_from_registry(config, registry: Registry):
return obj return obj
def build_layer(config):
    """Build a layer by resolving `config` against the ``LAYERS`` registry.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            layer to construct; it is handed to :func:`build_from_registry`
            as-is.

    Returns:
        Whatever :func:`build_from_registry` creates for this config
        (expected to be a :class:`torch.nn.Module`).
    """
    layer = build_from_registry(config, LAYERS)
    return layer
def build_loss(config):
    """Build a loss object by resolving `config` against the ``LOSSES`` registry.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            loss to construct; it is handed to :func:`build_from_registry`
            as-is.

    Returns:
        Whatever :func:`build_from_registry` creates for this config
        (expected to be a :class:`torch.nn.modules.loss._Loss`).
    """
    loss = build_from_registry(config, LOSSES)
    return loss
def build_model(config):
    """Build a model by resolving `config` against the ``MODELS`` registry.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            model to construct; it is handed to :func:`build_from_registry`
            as-is.

    Returns:
        Whatever :func:`build_from_registry` creates for this config
        (expected to be a :class:`torch.nn.Module`).
    """
    model = build_from_registry(config, MODELS)
    return model
def build_dataset(config):
    """Build a dataset by resolving `config` against the ``DATASETS`` registry.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            dataset to construct; it is handed to :func:`build_from_registry`
            as-is.

    Returns:
        Whatever :func:`build_from_registry` creates for this config
        (expected to be a :class:`torch.utils.data.Dataset`).
    """
    dataset = build_from_registry(config, DATASETS)
    return dataset
def build_optimizer(config, model):
    """Build an optimizer for `model` from the ``OPTIMIZERS`` registry.

    The caller's `config` is left untouched: a copy is taken and the model's
    parameters are stored in that copy under the ``'params'`` key before the
    registry lookup.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            optimizer to construct.
        model (:class:`nn.Module`): Model whose ``parameters()`` are given to
            the optimizer.

    Returns:
        Whatever :func:`build_from_registry` creates for the augmented config
        (expected to be a :class:`torch.optim.Optimizer`).
    """
    optimizer_cfg = config.copy()
    optimizer_cfg['params'] = model.parameters()
    optimizer = build_from_registry(optimizer_cfg, OPTIMIZERS)
    return optimizer
def build_gradient_handler(config, model, optimizer): def build_gradient_handler(config, model, optimizer):
"""Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`, """Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
`model` and `optimizer`. `model` and `optimizer`.
...@@ -160,100 +81,3 @@ def build_gradient_handler(config, model, optimizer): ...@@ -160,100 +81,3 @@ def build_gradient_handler(config, model, optimizer):
config_['model'] = model config_['model'] = model
config_['optimizer'] = optimizer config_['optimizer'] = optimizer
return build_from_registry(config_, GRADIENT_HANDLER) return build_from_registry(config_, GRADIENT_HANDLER)
def build_hooks(config, trainer):
    """Build a trainer hook from the ``HOOKS`` registry.

    The caller's `config` is left untouched: a copy is taken and `trainer` is
    stored in that copy under the ``'trainer'`` key before the registry lookup.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            hook to construct.
        trainer: The trainer object passed on to the hook via the config.

    Returns:
        Whatever :func:`build_from_registry` creates for the augmented config
        (expected to be a :class:`colossalai.trainer.hooks.BaseHook`).
    """
    hook_cfg = config.copy()
    hook_cfg['trainer'] = trainer
    hook = build_from_registry(hook_cfg, HOOKS)
    return hook
def build_ophooks(config):
    """Build an operator hook from the ``OPHOOKS`` registry.

    A copy of `config` (not the caller's object) is what gets passed to the
    registry lookup.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            operator hook to construct.

    Returns:
        Whatever :func:`build_from_registry` creates for the copied config
        (expected to be a ``BaseOpHook``).
    """
    return build_from_registry(config.copy(), OPHOOKS)
def build_transform(config):
    """Build a transform by resolving `config` against the ``TRANSFORMS`` registry.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            transform to construct; it is handed to
            :func:`build_from_registry` as-is.

    Returns:
        Whatever :func:`build_from_registry` creates for this config
        (expected to be a :class:`torchvision.transforms` object).
    """
    transform = build_from_registry(config, TRANSFORMS)
    return transform
def build_data_sampler(config, dataset):
    """Build a data sampler for `dataset` from the ``DATA_SAMPLERS`` registry.

    The caller's `config` is left untouched: a copy is taken and `dataset` is
    stored in that copy under the ``'dataset'`` key before the registry lookup.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            sampler to construct.
        dataset (:class:`torch.utils.data.Dataset`): Dataset the sampler will
            draw from.

    Returns:
        Whatever :func:`build_from_registry` creates for the augmented config
        (expected to be a ``BaseSampler``).
    """
    sampler_cfg = config.copy()
    sampler_cfg['dataset'] = dataset
    sampler = build_from_registry(sampler_cfg, DATA_SAMPLERS)
    return sampler
def build_lr_scheduler(config, optimizer):
    """Build a learning rate scheduler from the ``LR_SCHEDULERS`` registry.

    The caller's `config` is left untouched: a copy is taken and `optimizer`
    is stored in that copy under the ``'optimizer'`` key before the registry
    lookup.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            scheduler to construct.
        optimizer (:class:`torch.optim.Optimizer`): Optimizer whose learning
            rate the scheduler will drive.

    Returns:
        Whatever :func:`build_from_registry` creates for the augmented config
        (expected to be a :class:`torch.optim.lr_scheduler` object).
    """
    scheduler_cfg = config.copy()
    scheduler_cfg['optimizer'] = optimizer
    scheduler = build_from_registry(scheduler_cfg, LR_SCHEDULERS)
    return scheduler
def build_schedule(config):
    """Build a schedule by resolving `config` against the ``SCHEDULE`` registry.

    Args:
        config (dict or :class:`colossalai.context.Config`): Description of the
            schedule to construct; it is handed to
            :func:`build_from_registry` as-is.

    Returns:
        Whatever :func:`build_from_registry` creates for this config
        (expected to be a :class:`colossalai.engine.schedule.BaseSchedule`).
    """
    schedule = build_from_registry(config, SCHEDULE)
    return schedule
...@@ -2,6 +2,5 @@ from .layer import * ...@@ -2,6 +2,5 @@ from .layer import *
from .loss import * from .loss import *
from .lr_scheduler import * from .lr_scheduler import *
from .metric import * from .metric import *
from .model import *
from .optimizer import * from .optimizer import *
from ._ops import * from ._ops import *
from .lambda_wrapper import LambdaWrapper
from .pipeline_wrapper import PipelineSharedModuleWrapper from .pipeline_wrapper import PipelineSharedModuleWrapper
__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper'] __all__ = ['PipelineSharedModuleWrapper']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class LambdaWrapper(nn.Module):
    """Turn a user function into an ``nn.Module`` that can reach a set of built layers.

    The wrapped function is invoked with this module as its first argument,
    so it can read ``self.layers`` (and ``self.func``) while computing the
    forward pass.

    Args:
        func (``Callable``): User-defined function, called as
            ``func(module, *args, **kwargs)``.
        layers_cfg (optional): Iterable of layer configs, each built with
            :func:`build_layer`. Defaults to None, in which case
            ``self.layers`` is None.
    """

    def __init__(self, func, layers_cfg: dict = None):
        super().__init__()
        self.func = func
        self.layers = self._build_layers(layers_cfg)

    def _build_layers(self, layers_cfg: dict):
        """Build one layer per config entry; return None when no config is given."""
        if layers_cfg is None:
            return None
        # The result is a plain Python list, exactly as many entries as configs.
        return [build_layer(layer_cfg) for layer_cfg in layers_cfg]

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped function, passing this module first."""
        return self.func(self, *args, **kwargs)
from .model_from_config import ModelFromConfig
__all__ = ['ModelFromConfig']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
import torch.nn as nn
from colossalai.builder import build_layer
class ModelFromConfig(nn.Module, ABC):
    """Abstract base for models assembled from a list of layer configs.

    Subclasses fill ``self.layers_cfg`` with layer configs and call
    :meth:`build_from_cfg` to instantiate (a slice of) them into
    ``self.layers``. They must also implement :meth:`init_weights`.
    """

    def __init__(self):
        super().__init__()
        # Built layers, in the order they were appended by build_from_cfg.
        self.layers = nn.ModuleList()
        # Layer configs; each entry is passed to build_layer when built.
        self.layers_cfg = []

    def build_from_cfg(self, start=None, end=None):
        """Build the layers described by ``self.layers_cfg[start:end]``.

        Args:
            start (int, optional): Index of the first config to build.
                Defaults to None, meaning 0.
            end (int, optional): Index one past the last config to build.
                Defaults to None, meaning ``len(self.layers_cfg)``.

        Each built layer is appended to ``self.layers``; calling this method
        repeatedly keeps extending the list.
        """
        assert hasattr(self, 'layers_cfg'), 'Cannot find attribute layers_cfg from the module, please check the ' \
                                            'spelling and if you have initialized this variable'
        if start is None:
            start = 0
        if end is None:
            end = len(self.layers_cfg)
        for cfg in self.layers_cfg[start:end]:
            self.layers.append(build_layer(cfg))

    @abstractmethod
    def init_weights(self):
        """Initialize the model weights; must be provided by subclasses."""
        pass

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Return the state dict used when saving checkpoints.

        Subclasses can override this to customize what is saved. The default
        simply forwards to ``state_dict``.

        Args:
            destination (optional): Passed through to ``state_dict``.
            prefix (str, optional): Passed through to ``state_dict``.
            keep_vars (bool, optional): Passed through to ``state_dict``.

        Returns:
            The mapping returned by ``self.state_dict``.
        """
        # Keyword arguments are used because passing these positionally to
        # Module.state_dict is deprecated in recent PyTorch releases.
        return self.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
from .pipelinable import PipelinableContext, PipelinableModel
from .layer_sepc import LayerSpec
__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
\ No newline at end of file
import torch
from colossalai.utils.model.utils import call_to_str
class LayerSpec:
    """Deferred description of a ``torch.nn.Module``: its class and constructor arguments.

    The module is only instantiated when :meth:`build` is called. Positional
    or keyword arguments that are themselves ``LayerSpec`` objects are built
    first, so nested module hierarchies can be described.

    Args:
        typename (type): A subclass of ``torch.nn.Module``.
        *module_args: Positional arguments for ``typename``'s constructor.
        **module_kwargs: Keyword arguments for ``typename``'s constructor.

    Raises:
        RuntimeError: If ``typename`` is not a ``torch.nn.Module`` subclass.
    """

    def __init__(self, typename, *module_args, **module_kwargs):
        # issubclass() raises a bare TypeError for non-class objects, so check
        # for a class first to always surface the intended RuntimeError.
        if not (isinstance(typename, type) and issubclass(typename, torch.nn.Module)):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')

        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs
        self.children = None
        self._param_count = 0

    def __repr__(self):
        return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)

    @property
    def param_count(self):
        """Parameter count cached by the last :meth:`count_params` call (0 if none)."""
        return self._param_count

    @staticmethod
    def _materialize(value):
        """Build `value` if it is a LayerSpec; otherwise return it unchanged."""
        return value.build() if isinstance(value, LayerSpec) else value

    def build(self):
        """Build the stored specification.

        Returns:
            A new instance of ``typename`` constructed from the stored
            arguments, with any nested ``LayerSpec`` arguments built first.
        """
        recovered_args = tuple(self._materialize(arg) for arg in self.module_args)
        recovered_kwargs = {key: self._materialize(value) for key, value in self.module_kwargs.items()}
        return self.typename(*recovered_args, **recovered_kwargs)

    def set_children(self, children):
        """Store `children` on this spec as given."""
        self.children = children

    def count_params(self):
        """Build the module, count the elements of all its parameters, cache and return the total."""
        self._param_count = 0
        layer = self.build()
        self._param_count = sum(param.numel() for param in layer.parameters())
        return self._param_count

    def reset_param_count(self):
        """Reset the cached parameter count to 0."""
        self._param_count = 0
\ No newline at end of file
import torch import torch
import inspect import inspect
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.builder.pipeline import partition_uniform, partition_balanced from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
from colossalai.nn.layer.utils import CheckpointModule from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoTensor from colossalai.tensor import ColoParameter
from .layer_sepc import LayerSpec
class PipelinableContext(InsertPostInitMethodToModuleSubClasses): class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
"""
A context manager to split the model into pipeline stages.
"""
def __init__(self): def __init__(self, policy: str="balanced"):
super().__init__() super().__init__()
self._layer_spec_dict = {} self._layer_spec_dict = {}
self._root_children = None self._root_children = None
self._model = None self._model = None
self._layer_spec_list = [] self._layer_spec_list = []
self._func_dict = {} self._func_dict = {}
self._policy = "balanced" self._policy = policy
@property @property
def policy(self): def policy(self):
return self._policy return self._policy
@policy.setter
def policy(self, policy: str):
self._policy = policy
@property @property
def layers_count(self): def layers_count(self):
return len(self._layer_spec_list) return len(self._layer_spec_list)
...@@ -30,10 +38,9 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): ...@@ -30,10 +38,9 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
return len(self._func_dict) return len(self._func_dict)
def _pre_context_exec(self): def _pre_context_exec(self):
""" """
The Callback function when entering the context The Callback function when entering the context
""" """
# reserve rng states # reserve rng states
self.cpu_rng_state = torch.get_rng_state() self.cpu_rng_state = torch.get_rng_state()
self.cuda_rng_state = torch.cuda.get_rng_state() self.cuda_rng_state = torch.cuda.get_rng_state()
...@@ -52,35 +59,50 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): ...@@ -52,35 +59,50 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
The function to call at the end of the constructor of each module. The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times. NOTE() The module may be passed to this function multiple times.
""" """
module_id = id(module) # iterate over the positional arguments
# to check if an argument is a torch Module
# if found any torch Module, replace it with its layer spec
# for storage purpose
modified_args = [] modified_args = []
for obj in args: for arg in args:
if issubclass(obj.__class__, torch.nn.modules.module.Module): if isinstance(arg, torch.nn.Module):
obj = self._layer_spec_dict[id(obj)] arg = self._layer_spec_dict[id(arg)]
modified_args.append(obj) modified_args.append(arg)
# do the same for the keyword arguments
modified_kwargs = {} modified_kwargs = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
if issubclass(v.__class__, torch.nn.modules.module.Module): if isinstance(v, torch.nn.Module):
v = self._layer_spec_dict[id(v)] v = self._layer_spec_dict[id(v)]
# (lyl)TODO: analyse ColoTensor as well # (lyl)TODO: analyse ColoTensor as well
modified_kwargs[k] = v modified_kwargs[k] = v
modified_args = tuple(modified_args) # keep track of the module children
# as torch.nn.Module.__init__ is called from inner module to outer module,
# the final value of self._model will be the outermost model
# e.g. if the model is torchvision.models.resnet18, then the final value of self._model
# will be the ``ResNet`` object.
self._root_children = list(module.children()) self._root_children = list(module.children())
self._model = module self._model = module
# store the children to keep the module hierarchy
layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs) layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
layer_spec.set_children(module.children()) layer_spec.set_children(module.children())
# store the layer spec in this context
module_id = id(module)
self._layer_spec_dict[module_id] = layer_spec self._layer_spec_dict[module_id] = layer_spec
# convert all torch.nn.Parameter to colossalai.tensor.ColoParameter
name_list = [] name_list = []
for name, param in module.named_parameters(): for name, param in module.named_parameters():
if isinstance(param, ColoTensor): if isinstance(param, ColoParameter):
continue continue
name_list.append((name, param)) name_list.append((name, param))
for name, param in name_list: for name, param in name_list:
delattr(module, name) delattr(module, name)
setattr(module, name, ColoTensor.from_torch_tensor(param)) setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))
def to_layer_list(self, exec_seq=None): def to_layer_list(self, exec_seq=None):
""" """
...@@ -100,7 +122,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): ...@@ -100,7 +122,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
if id(module) == id(child_in_container): if id(module) == id(child_in_container):
children_name.append(name) children_name.append(name)
break break
else: else:
self._layer_spec_list.append(layer_spec) self._layer_spec_list.append(layer_spec)
for name, module in self._model.named_modules(): for name, module in self._model.named_modules():
...@@ -110,10 +131,16 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): ...@@ -110,10 +131,16 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
else: else:
front_funcs_list = [] front_funcs_list = []
named_modules = dict(self._model.named_modules())
for index, element in enumerate(exec_seq): for index, element in enumerate(exec_seq):
if isinstance(element, str): if isinstance(element, str):
module = dict(self._model.named_modules())[element] assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.'
# get the layer spec based on the module ID
module = named_modules[element]
layer_spec = self._layer_spec_dict[id(module)] layer_spec = self._layer_spec_dict[id(module)]
# check whether there are functions which should be executed before this module
if len(front_funcs_list) != 0: if len(front_funcs_list) != 0:
func_key = (layer_spec, "front") func_key = (layer_spec, "front")
if func_key not in self._func_dict: if func_key not in self._func_dict:
...@@ -121,6 +148,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): ...@@ -121,6 +148,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
for f in front_funcs_list: for f in front_funcs_list:
self._func_dict[func_key].append(f) self._func_dict[func_key].append(f)
front_funcs_list = [] front_funcs_list = []
func_key = (layer_spec, "behind") func_key = (layer_spec, "behind")
self._layer_spec_list.append(layer_spec) self._layer_spec_list.append(layer_spec)
elif isinstance(element, tuple) and element[1] == "front": elif isinstance(element, tuple) and element[1] == "front":
...@@ -172,70 +200,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): ...@@ -172,70 +200,6 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
return pipeline_model return pipeline_model
def load_policy(self, policy):
    """Record `policy` as this context's partition policy (stored in ``_policy``)."""
    setattr(self, '_policy', policy)
def _build_kwargs_for_module(function, kw_dict):
    """Select the entries of `kw_dict` that `function` accepts after its first parameter.

    The first parameter of a module's forward is treated as the input coming
    from the previous layer, so it is never filled from `kw_dict`.

    Args:
        function: Callable whose signature is inspected.
        kw_dict (dict): Candidate keyword arguments.

    Returns:
        None if `function` has at most one parameter; otherwise a dict with
        the entries of `kw_dict` whose keys name a parameter other than the
        first (possibly empty).
    """
    param_names = list(inspect.signature(function).parameters)
    if len(param_names) <= 1:
        return None
    accepted = param_names[1:]
    return {key: value for key, value in kw_dict.items() if key in accepted}
def _build_kwargs_for_function(function, kw_dict):
    """Select the entries of `kw_dict` whose keys are parameters of `function`.

    Args:
        function: Callable whose signature is inspected.
        kw_dict (dict): Candidate keyword arguments.

    Returns:
        A dict with the matching entries, or None when nothing matches.
    """
    accepted = inspect.signature(function).parameters
    matched = {}
    for name, value in kw_dict.items():
        if name in accepted:
            matched[name] = value
    return matched if matched else None
def _exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
    """Run `func` either on the input tensor or on selected keyword arguments.

    Two modes are supported, chosen by `kw_dict`:

    * ``kw_dict is None``: `func` transforms the input, e.g.
      ``lambda x: torch.flatten(x, 1)``; its result is returned.
    * otherwise: `func` is called as ``func(**kw_dict)`` and its result is
      written back into `kwargs` under the keys of `kw_dict`. A tuple result
      is distributed positionally over those keys; any other result is
      assigned to every key. `input_tensor` is returned unchanged.

    Args:
        func: Callable to execute.
        kw_dict (dict or None): Keyword arguments to call `func` with.
        input_tensor: Current input flowing through the pipeline.
        kwargs (dict): Shared keyword arguments, updated in place in the
            second mode.

    Returns:
        ``func(input_tensor)`` in the first mode, `input_tensor` in the second.
    """
    if kw_dict is None:
        return func(input_tensor)

    result = func(**kw_dict)
    if isinstance(result, tuple):
        for position, key in enumerate(kw_dict):
            kwargs[key] = result[position]
    else:
        for key in kw_dict:
            kwargs[key] = result
    return input_tensor
def _exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
    """Execute the function(s) registered under `func_key`, threading the input through.

    Args:
        func_dict (dict): Maps keys to a single callable or a list of callables.
        func_key: Key to look up; must be present in `func_dict`.
        input_tensor: Current input flowing through the pipeline.
        kwargs (dict): Shared keyword arguments; each callable receives the
            subset matching its signature and may update them.

    Returns:
        The input after all registered callables have been applied in order.
    """
    assert func_key in func_dict, f"{func_key} is not in the function_dict."
    registered = func_dict[func_key]
    # A single callable is handled like a one-element list.
    pending = registered if isinstance(registered, list) else [registered]
    for func in pending:
        func_kwargs = _build_kwargs_for_function(func, kwargs)
        input_tensor = _exec_func_with_kwargs(func, func_kwargs, input_tensor, kwargs)
    return input_tensor
class PipelinableModel(torch.nn.Module): class PipelinableModel(torch.nn.Module):
...@@ -250,16 +214,16 @@ class PipelinableModel(torch.nn.Module): ...@@ -250,16 +214,16 @@ class PipelinableModel(torch.nn.Module):
for module in self._module_list: for module in self._module_list:
if id(module) in self._front_func_dict: if id(module) in self._front_func_dict:
input_tensor = _exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs) input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
if isinstance(module, CheckpointModule): if isinstance(module, CheckpointModule):
forward_func = module._forward forward_func = module._forward
else: else:
forward_func = module.forward forward_func = module.forward
if input_tensor is None: if input_tensor is None:
module_kwargs = _build_kwargs_for_function(forward_func, kwargs) module_kwargs = build_kwargs_for_function(forward_func, kwargs)
else: else:
module_kwargs = _build_kwargs_for_module(forward_func, kwargs) module_kwargs = build_kwargs_for_module(forward_func, kwargs)
if module_kwargs is not None and input_tensor is not None: if module_kwargs is not None and input_tensor is not None:
if isinstance(module, CheckpointModule): if isinstance(module, CheckpointModule):
convert_kwargs_to_args = [] convert_kwargs_to_args = []
...@@ -288,57 +252,9 @@ class PipelinableModel(torch.nn.Module): ...@@ -288,57 +252,9 @@ class PipelinableModel(torch.nn.Module):
input_tensor = module(input_tensor) input_tensor = module(input_tensor)
if id(module) in self._behind_func_dict: if id(module) in self._behind_func_dict:
input_tensor = _exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs) input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
return input_tensor return input_tensor
class LayerSpec:
    """Deferred description of a ``torch.nn.Module``: its class and constructor arguments.

    The module is only instantiated when :meth:`build` is called. Positional
    or keyword arguments that are themselves ``LayerSpec`` objects are built
    first, so nested module hierarchies can be described.

    Args:
        typename (type): A subclass of ``torch.nn.Module``.
        *module_args: Positional arguments for ``typename``'s constructor.
        **module_kwargs: Keyword arguments for ``typename``'s constructor.

    Raises:
        RuntimeError: If ``typename`` is not a ``torch.nn.Module`` subclass.
    """

    def __init__(self, typename, *module_args, **module_kwargs):
        # issubclass() raises a bare TypeError for non-class objects, so check
        # for a class first to always surface the intended RuntimeError.
        if not (isinstance(typename, type) and issubclass(typename, torch.nn.Module)):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')

        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs
        self.children = None
        self._param_count = 0

    def __repr__(self):
        return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)

    @property
    def param_count(self):
        """Parameter count cached by the last :meth:`count_params` call (0 if none)."""
        return self._param_count

    @staticmethod
    def _materialize(value):
        """Build `value` if it is a LayerSpec; otherwise return it unchanged."""
        return value.build() if isinstance(value, LayerSpec) else value

    def build(self):
        """Build the stored specification.

        Returns:
            A new instance of ``typename`` constructed from the stored
            arguments, with any nested ``LayerSpec`` arguments built first.
        """
        recovered_args = tuple(self._materialize(arg) for arg in self.module_args)
        recovered_kwargs = {key: self._materialize(value) for key, value in self.module_kwargs.items()}
        return self.typename(*recovered_args, **recovered_kwargs)

    def set_children(self, children):
        """Store `children` on this spec as given."""
        self.children = children

    def count_params(self):
        """Build the module, count the elements of all its parameters, cache and return the total."""
        self._param_count = 0
        layer = self.build()
        self._param_count = sum(param.numel() for param in layer.parameters())
        return self._param_count

    def reset_param_count(self):
        """Reset the cached parameter count to 0."""
        self._param_count = 0
import copy
import heapq import heapq
import inspect
from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
import torch.nn as nn from typing import List
def _binary_partition(weights: List, start: int, end: int):
def _binary_partition(weights, st, ed):
"""Returns the binary partition position of `weights`, given the start """Returns the binary partition position of `weights`, given the start
position `st` and the end position `ed`. position `st` and the end position `ed`.
Args: Args:
weights (list): A python list to be binary partitioned weights (list): A python list to be binary partitioned
st (int): the start position of the binary partition start (int): the start position of the binary partition
ed (int): the end position of the binary partition end (int): the end position of the binary partition
Returns: Returns:
int: the binary partition position of `weights` int: the binary partition position of `weights`
""" """
w_sum = weights[ed - 1] w_sum = weights[end - 1]
prefix = 0 prefix = 0
if st > 0: if start > 0:
w_sum -= weights[st - 1] w_sum -= weights[start - 1]
prefix = weights[st - 1] prefix = weights[start - 1]
minimum = float("inf") minimum = float("inf")
for idx in range(st + 1, ed): for idx in range(start + 1, end):
front = weights[idx - 1] - prefix front = weights[idx - 1] - prefix
diff = abs(w_sum - 2 * front) diff = abs(w_sum - 2 * front)
if diff < minimum: if diff < minimum:
pos = idx pos = idx
minimum = diff minimum = diff
return st, pos, ed return start, pos, end
def _heap_addition(weights, intervals, add_cnt): def _heap_addition(weights: List, intervals: int, add_cnt: int):
""" """
""" """
...@@ -150,117 +146,62 @@ def partition_balanced(weights, pipeline_parallel_size, num_chunks): ...@@ -150,117 +146,62 @@ def partition_balanced(weights, pipeline_parallel_size, num_chunks):
return parts return parts
def count_layer_params(layers): def build_kwargs_for_module(function, kw_dict):
"""Count the number of parameters in each layer
""" """
param_counts = [0] * len(layers) Generally, the first argument of module.forward is an input tensor come from the previous layer.
for idx, cfg in enumerate(layers): Therefore, we just filter the kwargs from second element of the dictionary.
layer = build_layer(cfg) """
params = filter(lambda p: p.requires_grad, layer.parameters()) sig = inspect.signature(function)
param_counts[idx] = sum(p.numel() for p in params) if len(sig.parameters) <= 1:
return None
return param_counts args_name_list = list(sig.parameters.keys())
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
return kw_dict
def build_pipeline_model_from_cfg(config,
num_chunks: int = 1,
partition_method: str = 'parameter',
verbose: bool = False):
"""An initializer to split the model into different stages for pipeline parallelism.
An example for the model config is shown below. The class VisionTransformerFromConfig should
inherit colossalai.nn.model.ModelFromConfig to allow this initializer to build model from a sequence
of layer configurations.
:: def build_kwargs_for_function(function, kw_dict):
sig = inspect.signature(function)
kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters}
if len(kw_dict) == 0:
return None
return kw_dict
model_config = dict(
type='VisionTransformerFromConfig',
embedding_cfg=dict(...),
...
)
Args: def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
config (dict): Configuration of the model. """
num_chunks (int, optional): The number of chunks you want to have on the current stage. We suppose the callable object passed to to_layer_list method in two purpose:
This value should be 1 in most cases unless you are using virtual pipeline parallelism. a. use the callable object to modify input tensor, such as \
partition_method (str, optional): This parameter determines how you want to split your model lambda x: torch.flatten(x, 1)
layers into stages, you can set it as 'layer' or 'parameter'. b. use the callable object to modify kwargs value, such as \
verbose (bool, optional): Whether to print the logs. def foo(attention_mask=None):
if attention_mask is not None:
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
return attention_mask
""" """
ori_model = build_model(config)
layers = ori_model.layers_cfg
layer_length = len(layers)
logger = get_dist_logger()
if verbose:
logger.info(f"The total length of layers is {layer_length}", ranks=[0])
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
method = partition_method.lower()
# Make a partition
if method == 'layer':
num_layers = len(layers)
parts = partition_uniform(num_layers, pipeline_parallel_size, num_chunks)
elif method == 'parameter':
param_counts = count_layer_params(layers)
# print_rank_0(param_counts)
parts = partition_balanced(param_counts, pipeline_parallel_size, num_chunks)
else:
raise ValueError("Method should be a pre-set string in [layer, parameter]")
# Display the partition
if verbose:
log_str = 'Layer allocation after partitioning: \n'
for stage in range(pipeline_parallel_size):
num_layers = 0
for st, ed in parts[stage]:
num_layers += ed - st
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
for st, ed in parts[stage]:
for idx, layer in enumerate(layers[st:ed]):
log_str += f'\t{idx + st:2d}: {layer}\n'
logger.info(log_str, ranks=[0])
# Save the partition if kw_dict is not None:
interval = parts[pipeline_rank] rst = func(**kw_dict)
if isinstance(rst, tuple):
for i, k in enumerate(kw_dict.keys()):
kwargs[k] = rst[i]
else:
for k in kw_dict.keys():
kwargs[k] = rst
return input_tensor
return func(input_tensor)
models = []
for st, ed in interval:
model = copy.deepcopy(ori_model)
model.build_from_cfg(st, ed)
models.append(model)
return nn.ModuleList(models) if len(models) > 1 else models[0] def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
assert func_key in func_dict, f"{func_key} is not in the function_dict."
funcs_to_exec = func_dict[func_key]
if isinstance(funcs_to_exec, list):
for f in funcs_to_exec:
f_kwargs = build_kwargs_for_function(f, kwargs)
input_tensor = exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs)
else:
f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs)
input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False): return input_tensor
"""An intializer to split the model into different stages for pipeline parallelism. \ No newline at end of file
Note that `layer` must be `torch.nn.Sequential`.
Args:
layers (`torch.nn.Sequential`): Layers of model
num_chunks: The number of chunks you want to have on the current stage. This value should be 1
in most cases unless you are using virtual pipeline parallelism.
verbose (bool, optional): Whether to print the logs.
"""
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
module_list = []
for start, end in partitions[pipeline_rank]:
module_list.append(
nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end],
*[nn.Identity() for _ in range(len(layers) - end)]))
if verbose:
logger = get_dist_logger()
logger.info(f'Total {len(layers)} layers', ranks=[0])
for rank, part in enumerate(partitions):
log_str = f'===== stage={rank} =====\n'
for chunk, (start, end) in enumerate(part):
log_str += f'===== chunk={chunk}, layer=[{start}-{end}] =====\n'
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
logger.info(log_str, ranks=[0])
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
...@@ -6,7 +6,6 @@ from pathlib import Path ...@@ -6,7 +6,6 @@ from pathlib import Path
import pytest import pytest
from colossalai.context.config import Config from colossalai.context.config import Config
from colossalai.builder import build_ophooks
@pytest.mark.cpu @pytest.mark.cpu
......
...@@ -17,11 +17,14 @@ from colossalai.logging import get_dist_logger ...@@ -17,11 +17,14 @@ from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader from colossalai.utils import is_using_pp, get_dataloader
from colossalai.utils.model.pipelinable import PipelinableContext from colossalai.pipeline.pipelinable import PipelinableContext
from tqdm import tqdm from tqdm import tqdm
from torchvision.datasets import CIFAR10
from titans.dataloader.cifar10 import build_cifar from torchvision.transforms import transforms
from titans.model.vit import vit_tiny_patch4_32 try:
from titans.model.vit import vit_tiny_patch4_32
except:
pass
BATCH_SIZE = 4 BATCH_SIZE = 4
NUM_EPOCHS = 60 NUM_EPOCHS = 60
...@@ -49,7 +52,14 @@ def run_trainer(rank, world_size, port): ...@@ -49,7 +52,14 @@ def run_trainer(rank, world_size, port):
# create dataloaders # create dataloaders
root = Path(os.environ['DATA']) root = Path(os.environ['DATA'])
train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, pad_if_needed=True, crop=32, resize=32) transform_train = transforms.Compose([
transforms.RandomCrop(224, padding=4, pad_if_needed=True),
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
# create loss function # create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1) criterion = CrossEntropyLoss(label_smoothing=0.1)
......
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.utils.model.pipelinable import PipelinableContext from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_on_exception
...@@ -33,7 +33,7 @@ def run_pipelinable(rank): ...@@ -33,7 +33,7 @@ def run_pipelinable(rank):
model = MLP() model = MLP()
assert pipelinable.policy == "balanced" assert pipelinable.policy == "balanced"
pipelinable.load_policy("uniform") pipelinable.policy = "uniform"
assert pipelinable.policy == "uniform" assert pipelinable.policy == "uniform"
pipelinable.to_layer_list() pipelinable.to_layer_list()
......
from .layers import *
from .resnet import VanillaResNet
from .basic_block import ResNetBasicBlock
from .bottleneck import ResNetBottleneck
from .reslayer import ResLayer
\ No newline at end of file
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3
@LAYERS.register_module
class ResNetBasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    The main branch is ``conv3x3 -> norm -> ReLU -> conv3x3 -> norm``. Its
    result is added in place to the shortcut branch and the sum goes through
    a final ReLU.

    Args:
        inplanes: Number of channels fed to the first convolution.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the block input to build the
            shortcut; when ``None`` the input itself is used as the shortcut.
        groups: Only ``1`` is accepted.
        base_width: Only ``64`` is accepted.
        dilation: Values greater than ``1`` are rejected.
        norm_layer: Callable that builds a normalisation layer from a channel
            count; ``nn.BatchNorm2d`` is used when ``None``.

    Raises:
        ValueError: If ``groups`` is not 1 or ``base_width`` is not 64.
        NotImplementedError: If ``dilation`` is greater than 1.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Reject configurations this block does not implement before any
        # sub-module is created.
        if not (groups == 1 and base_width == 64):
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")

        # When stride != 1, both conv1 and the downsample module reduce the
        # spatial size of their input.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run the main branch, add the shortcut in place and apply ReLU."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # The shortcut is evaluated after the main branch, as in the
        # statement order of the class body.
        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3, conv1x1
@LAYERS.register_module
class ResNetBottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions.

    Args:
        inplanes: Number of channels fed to the first 1x1 convolution.
        planes: Base channel count; the block outputs
            ``planes * expansion`` channels.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the block input to build the
            shortcut; when ``None`` the input itself is used as the shortcut.
        groups: Group count of the 3x3 convolution.
        base_width: Width per group; the inner channel count is
            ``int(planes * (base_width / 64.)) * groups``.
        dilation: Dilation of the 3x3 convolution.
        norm_layer: Callable that builds a normalisation layer from a channel
            count; ``nn.BatchNorm2d`` is used when ``None``.
    """

    # Bottleneck in torchvision places the stride for downsampling at the 3x3
    # convolution (self.conv2), while the original implementation places the
    # stride at the first 1x1 convolution (self.conv1), according to
    # "Deep residual learning for image recognition"
    # https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy
    # according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        inner_planes = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # When stride != 1, both conv2 and the downsample module reduce the
        # spatial size of their input.
        self.conv1 = conv1x1(inplanes, inner_planes)
        self.bn1 = make_norm(inner_planes)
        self.conv2 = conv3x3(inner_planes, inner_planes, stride, groups, dilation)
        self.bn2 = make_norm(inner_planes)
        self.conv3 = conv1x1(inner_planes, out_planes)
        self.bn3 = make_norm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run the three-convolution branch, add the shortcut, apply ReLU."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))

        # The shortcut is evaluated after the main branch, as in the
        # statement order of the class body.
        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Kernel dilation; also used as the padding amount.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **options)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.registry import LAYERS
from .conv import conv1x1
@LAYERS.register_module
class ResLayer(nn.Module):
    """A stack of residual blocks wrapped in a single ``nn.Sequential``.

    The block class and the normalisation-layer factory are obtained from the
    ``LAYERS`` registry via ``LAYERS.get_module`` using the given names.

    Args:
        block_type: Registry name of the residual block class; the class must
            expose an ``expansion`` attribute.
        norm_layer_type: Registry name of the normalisation layer factory.
        inplanes: Number of channels entering the first block.
        planes: Base channel count passed to every block.
        blocks: Total number of blocks in the stack.
        groups: Group count forwarded to every block.
        base_width: Width per group forwarded to every block.
        stride: Stride used by the first block (and its shortcut).
        dilation: Dilation in effect before this layer.
        dilate: When true, the stride is folded into the dilation and the
            layer itself runs with stride 1.
    """

    def __init__(self,
                 block_type: str,
                 norm_layer_type: str,
                 inplanes: int,
                 planes: int,
                 blocks: int,
                 groups: int,
                 base_width: int,
                 stride: int = 1,
                 dilation: int = 1,
                 dilate: bool = False,
                 ):
        super().__init__()
        # Resolve the registered classes first, then record the plain
        # hyper-parameters that _make_layer reads (and partly rewrites).
        self.block = LAYERS.get_module(block_type)
        self.norm_layer = LAYERS.get_module(norm_layer_type)
        self.inplanes, self.planes = inplanes, planes
        self.blocks = blocks
        self.groups, self.base_width = groups, base_width
        self.stride, self.dilation = stride, dilation
        self.dilate = dilate
        self.layer = self._make_layer()

    def _make_layer(self):
        """Create the block stack.

        Side effects: ``self.inplanes`` is updated to the output channel
        count, and with ``dilate`` set ``self.dilation`` is multiplied by the
        stride while ``self.stride`` is reset to 1.
        """
        make_norm = self.norm_layer
        # The first block keeps the dilation that was in effect on entry.
        first_dilation = self.dilation
        if self.dilate:
            self.dilation *= self.stride
            self.stride = 1

        out_planes = self.planes * self.block.expansion

        # A projection shortcut is needed whenever the first block changes
        # the spatial size or the channel count.
        shortcut = None
        if not (self.stride == 1 and self.inplanes == out_planes):
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_planes, self.stride),
                make_norm(out_planes),
            )

        head = self.block(self.inplanes, self.planes, self.stride, shortcut, self.groups,
                          self.base_width, first_dilation, make_norm)
        self.inplanes = out_planes

        tail = [
            self.block(self.inplanes, self.planes, groups=self.groups,
                       base_width=self.base_width, dilation=self.dilation,
                       norm_layer=make_norm)
            for _ in range(1, self.blocks)
        ]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Apply the stacked blocks to ``x``."""
        return self.layer(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment