Unverified Commit 01a80cd8 authored by アマデウス's avatar アマデウス Committed by GitHub
Browse files

Hotfix/Colossalai layers (#92)



* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch
Co-authored-by: default avatarBoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent 0fedef4f
......@@ -2,7 +2,7 @@ BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 200
......
......@@ -72,13 +72,11 @@ def train_cifar():
os.mkdir(log_path)
logger.log_to_file(log_path)
tp = gpc.config.parallel.tensor.mode
model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
model = vit_lite_depth7_patch4_32()
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
......@@ -107,7 +105,7 @@ def train_cifar():
LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
AccuracyHook(accuracy_func=Accuracy()),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
......
......@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 300
......
......@@ -159,14 +159,12 @@ def train_imagenet():
os.mkdir(log_path)
logger.log_to_file(log_path)
tp = gpc.config.parallel.tensor.mode
model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
model = vit_small_patch16_224(num_classes=100, init_method='jax')
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
......@@ -192,7 +190,7 @@ def train_imagenet():
LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
AccuracyHook(accuracy_func=Accuracy()),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
......
......@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 300
......
......@@ -159,14 +159,12 @@ def train_imagenet():
os.mkdir(log_path)
logger.log_to_file(log_path)
tp = gpc.config.parallel.tensor.mode
model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
model = vit_small_patch16_224(num_classes=1000, init_method='jax')
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
......@@ -192,7 +190,7 @@ def train_imagenet():
LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
AccuracyHook(accuracy_func=Accuracy()),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
......
......@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'
# intializer
INITIALIZER_MAPPING = {
......@@ -16,6 +17,9 @@ INITIALIZER_MAPPING = {
'sequence': 'Initializer_Sequence'
}
# 1D parallel
PARALLEL_INPUT_1D = 'parallel_input_1d'
# 2D paralllel
SUMMA_DIM = 'SUMMA_DIM'
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import random
from typing import Union
import numpy as np
import torch
import torch.distributed as dist
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
from colossalai.context.config import Config
from colossalai.logging import get_dist_logger
from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
......@@ -386,6 +387,7 @@ class ParallelContext:
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
self.check_sanity()
pg_init = []
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import torch.distributed as dist
from colossalai.context import Config
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
from colossalai.constants import PARALLEL_INPUT_1D
@DIST_GROUP_INITIALIZER.register_module
......@@ -29,6 +30,7 @@ class Initializer_1D(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_1D
os.environ[PARALLEL_INPUT_1D] = ''
for i in range(self.num_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
......
......@@ -10,7 +10,7 @@ from typing import Iterable, Union, List, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.nn.layer import split_batch
class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation.
......@@ -59,7 +59,11 @@ class BaseSchedule(ABC):
else:
data, label = batch_data
data, label = self._to_list(data), self._to_list(label)
if isinstance(label, (tuple, list)):
self.batch_size = label[0].size(0)
else:
self.batch_size = label.size(0)
data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
return self._move_to_device(data), self._move_to_device(label)
def pre_processing(self, engine: Engine):
......
from .colossalai_layer import *
from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .utils import *
from .vanilla import *
from .wrapper import *
from ._utils import split_batch
from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear
from .normalization import LayerNorm
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']
from torch import Tensor
from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_tensor_3d
from ..utils import get_tensor_parallel_mode
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}
def split_batch(input_) -> Tensor:
tensor_parallel_mode = get_tensor_parallel_mode()
if tensor_parallel_mode in _parallel_split_batch:
if isinstance(input_, (tuple, list)):
return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
else:
return _parallel_split_batch[tensor_parallel_mode](input_)
else:
return input_
from contextlib import nullcontext
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import conditional_context
from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
class Dropout(nn.Module):
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
super().__init__()
self.tensor_parallel = get_tensor_parallel_mode()
if self.tensor_parallel == '1d':
self.drop = Dropout1D(p, inplace)
else:
self.drop = nn.Dropout(p, inplace)
def forward(self, *args):
cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
with cm:
return self.drop(*args)
......@@ -3,51 +3,19 @@ from typing import Callable, Optional
from colossalai.utils import get_current_device
from torch import dtype, nn
from torch.nn.modules.activation import *
from torch.nn.modules.adaptive import *
from torch.nn.modules.batchnorm import *
from torch.nn.modules.channelshuffle import *
from torch.nn.modules.conv import *
from torch.nn.modules.distance import *
from torch.nn.modules.dropout import *
from torch.nn.modules.flatten import *
from torch.nn.modules.fold import *
from torch.nn.modules.instancenorm import *
from torch.nn.modules.linear import *
from torch.nn.modules.normalization import *
from torch.nn.modules.padding import *
from torch.nn.modules.pixelshuffle import *
from torch.nn.modules.pooling import *
from torch.nn.modules.rnn import *
from torch.nn.modules.sparse import *
from torch.nn.modules.transformer import *
from torch.nn.modules.upsampling import *
from .. import init as init
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
from .vanilla import *
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
None: VanillaClassifier,
'1d': VanillaClassifier,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
_parallel_embedding = {'3d': Embedding3D}
_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}
_parallel_patchembedding = {
None: VanillaPatchEmbedding,
'None': VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D,
......@@ -55,65 +23,6 @@ _parallel_patchembedding = {
}
class Linear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None,
**kwargs) -> None:
super().__init__()
if tensor_parallel is None:
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
@property
def weight(self):
return self.norm.weight
@property
def bias(self):
return self.norm.bias
def forward(self, *args):
return self.norm(*args)
class Embedding(nn.Module):
def __init__(self,
num_embeddings: int,
......@@ -121,11 +30,11 @@ class Embedding(nn.Module):
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
tensor_parallel: Optional[str] = None,
*args,
**kwargs) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.embed = nn.Embedding(num_embeddings,
embedding_dim,
padding_idx=padding_idx,
......@@ -163,9 +72,9 @@ class PatchEmbedding(nn.Module):
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_(),
tensor_parallel: Optional[str] = None) -> None:
position_embed_initializer: Callable = init.zeros_()) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
self.embed = _parallel_patchembedding[tensor_parallel](
img_size,
patch_size,
......@@ -196,36 +105,3 @@ class PatchEmbedding(nn.Module):
def forward(self, *args):
return self.embed(*args)
class Classifier(nn.Module):
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None) -> None:
super().__init__()
self.layer = _parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
import math
from typing import Callable, Optional
from colossalai.nn.layer.parallel_1d.layers import Classifier1D
from colossalai.utils import get_current_device
from torch import dtype, nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
'None': VanillaClassifier,
'1d': Classifier1D,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
class Linear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
class Classifier(nn.Module):
def __init__(
self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
) -> None:
super().__init__()
self.layer = _parallel_classifier[get_tensor_parallel_mode()](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
from typing import Optional
from colossalai.utils import get_current_device
from torch import nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
@property
def weight(self):
return self.norm.weight
@property
def bias(self):
return self.norm.bias
def forward(self, *args):
return self.norm(*args)
# adapted from Megatron-LM
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
import torch
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
\ No newline at end of file
from .layers import Linear1D_Col, Linear1D_Row
from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']
__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import torch
import torch.distributed as dist
from colossalai.constants import PARALLEL_INPUT_1D
from colossalai.core import global_context as gpc
from .._common_utils import divide
from ..utils import divide
def set_parallel_input(input_parallel: bool):
os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else ''
def get_parallel_input():
return bool(os.environ[PARALLEL_INPUT_1D])
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment