Unverified commit 01a80cd8, authored by アマデウス and committed by GitHub

Hotfix/Colossalai layers (#92)



* Optimized the 1D layer APIs, reorganized the nn.layer modules, and fixed the tests

* Fixed a 2.5D runtime issue

* Reworked split_batch; it is now called in trainer.schedule.load_batch
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent 0fedef4f
@@ -2,7 +2,7 @@ BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 200
...
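The training script below reads gpc.config.parallel.tensor.mode, so the size/mode pair above is expected to be wired into a parallel section of the same config file. A minimal sketch of that section, assuming the usual Colossal-AI config layout (it is not part of this hunk):

# Assumed companion section of the same config file; the field names follow
# what gpc.config.parallel.tensor.mode implies, not this hunk itself.
parallel = dict(
    tensor=dict(size=TENSOR_PARALLEL_SIZE, mode=TENSOR_PARALLEL_MODE),
)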
@@ -72,13 +72,11 @@ def train_cifar():
os.mkdir(log_path)
logger.log_to_file(log_path)
-tp = gpc.config.parallel.tensor.mode
-model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
+model = vit_lite_depth7_patch4_32()
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
-criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
...
@@ -107,7 +105,7 @@ def train_cifar():
LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger),
-AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+AccuracyHook(accuracy_func=Accuracy()),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
...
@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 300
...
@@ -159,14 +159,12 @@ def train_imagenet():
os.mkdir(log_path)
logger.log_to_file(log_path)
-tp = gpc.config.parallel.tensor.mode
-model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
+model = vit_small_patch16_224(num_classes=100, init_method='jax')
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
-criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
...
@@ -192,7 +190,7 @@ def train_imagenet():
LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger),
-AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+AccuracyHook(accuracy_func=Accuracy()),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
...
@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 300
...
@@ -159,14 +159,12 @@ def train_imagenet():
os.mkdir(log_path)
logger.log_to_file(log_path)
-tp = gpc.config.parallel.tensor.mode
-model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
+model = vit_small_patch16_224(num_classes=1000, init_method='jax')
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
-criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
...
@@ -192,7 +190,7 @@ def train_imagenet():
LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger),
-AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+AccuracyHook(accuracy_func=Accuracy()),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
...
@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
+TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'
# intializer
INITIALIZER_MAPPING = {
@@ -16,6 +17,9 @@ INITIALIZER_MAPPING = {
'sequence': 'Initializer_Sequence'
}
+# 1D parallel
+PARALLEL_INPUT_1D = 'parallel_input_1d'
# 2D paralllel
SUMMA_DIM = 'SUMMA_DIM'
...
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import os
import random
from typing import Union
import numpy as np
import torch
import torch.distributed as dist
-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
+from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
from colossalai.context.config import Config
from colossalai.logging import get_dist_logger
from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
@@ -386,6 +387,7 @@ class ParallelContext:
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
    tensor_parallel_mode = parallel_config['tensor']['mode']
    assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
self.check_sanity()
pg_init = []
...
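The layer wrappers later in this diff call a get_tensor_parallel_mode() helper imported from ..utils, which is not shown in this commit view. A minimal sketch of what such a helper presumably does with the environment variable written above; only the constant and the os.environ write are from the diff, the function body is an assumption. Note that str(None) yields the string 'None', which is why the new wrappers compare against 'None' rather than None:

import os

from colossalai.constants import TENSOR_PARALLEL_MODE


def get_tensor_parallel_mode():
    # Hypothetical sketch: read back the mode string ('None', '1d', '2d',
    # '2.5d', '3d', 'sequence') that ParallelContext stored at launch time.
    return os.environ[TENSOR_PARALLEL_MODE]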
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import os
import torch.distributed as dist
from colossalai.context import Config
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
+from colossalai.constants import PARALLEL_INPUT_1D
@DIST_GROUP_INITIALIZER.register_module
@@ -29,6 +30,7 @@ class Initializer_1D(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_1D
+os.environ[PARALLEL_INPUT_1D] = ''
for i in range(self.num_group):
    ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
...
@@ -10,7 +10,7 @@ from typing import Iterable, Union, List, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
+from colossalai.nn.layer import split_batch
class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation.
@@ -59,7 +59,11 @@ class BaseSchedule(ABC):
else:
    data, label = batch_data
-data, label = self._to_list(data), self._to_list(label)
+if isinstance(label, (tuple, list)):
+    self.batch_size = label[0].size(0)
+else:
+    self.batch_size = label.size(0)
+data, label = self._to_list(split_batch(data)), self._to_list(split_batch(label))
return self._move_to_device(data), self._move_to_device(label)
def pre_processing(self, engine: Engine):
...
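A small illustration of the new bookkeeping, assuming a plain tensor label: the global batch size is recorded from the full label before split_batch shards anything, so throughput and accuracy hooks keep seeing the un-split size.

import torch

# Sketch only: mirrors the branch added to load_batch above.
label = torch.randint(0, 10, (512,))
batch_size = label[0].size(0) if isinstance(label, (tuple, list)) else label.size(0)
assert batch_size == 512  # recorded before split_batch is applied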
from .colossalai_layer import *
-from .fused_bias_gelu import bias_gelu_impl
+from .parallel_1d import *
+from .parallel_2d import *
+from .parallel_2p5d import *
+from .parallel_3d import *
+from .parallel_sequence import *
+from .utils import *
+from .vanilla import *
from .wrapper import *

from ._utils import split_batch
from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear
from .normalization import LayerNorm
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']
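Taken together with the re-exports above, the user-facing effect of the reorganization is that layers are built without a tensor_parallel argument. A hedged usage sketch, assuming colossalai.launch has already run so the parallel mode is known:

from colossalai.nn.layer import Classifier, Embedding, LayerNorm, Linear, PatchEmbedding

# Sketch only: each wrapper picks its parallel implementation from the mode
# recorded at launch time instead of a tensor_parallel=... keyword argument.
proj = Linear(768, 3072)
norm = LayerNorm(768)
head = Classifier(768, num_classes=1000)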
from torch import Tensor

from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_tensor_3d
from ..utils import get_tensor_parallel_mode

_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}


def split_batch(input_) -> Tensor:
    tensor_parallel_mode = get_tensor_parallel_mode()
    if tensor_parallel_mode in _parallel_split_batch:
        if isinstance(input_, (tuple, list)):
            return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
        else:
            return _parallel_split_batch[tensor_parallel_mode](input_)
    else:
        return input_
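A hedged usage sketch of the new helper: for '2d', '2.5d' and '3d' modes each tensor is sharded by the corresponding split_tensor_* op, while for 'None', '1d' and sequence modes the input passes through unchanged, which is why trainer.schedule.load_batch can call it unconditionally.

import torch
from colossalai.nn.layer import split_batch

# Sketch only: assumes colossalai.launch has already recorded the mode.
images = torch.randn(64, 3, 32, 32)
targets = torch.randint(0, 10, (64,))
images, targets = split_batch(images), split_batch(targets)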
from contextlib import nullcontext

import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import conditional_context

from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode


class Dropout(nn.Module):
    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        super().__init__()
        self.tensor_parallel = get_tensor_parallel_mode()
        if self.tensor_parallel == '1d':
            self.drop = Dropout1D(p, inplace)
        else:
            self.drop = nn.Dropout(p, inplace)

    def forward(self, *args):
        cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
        with cm:
            return self.drop(*args)
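A short usage sketch of the wrapper, assuming an initialized Colossal-AI context: under '1d' it delegates to Dropout1D, otherwise it is a plain nn.Dropout, and the non-1D parallel modes run under the ParallelMode.TENSOR seed context so dropout randomness is coordinated with the tensor-parallel RNG state.

import torch

# Sketch only: Dropout() resolves its backend from the recorded parallel mode.
drop = Dropout(p=0.1)
out = drop(torch.randn(4, 8))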
@@ -3,51 +3,19 @@ from typing import Callable, Optional
from colossalai.utils import get_current_device
from torch import dtype, nn
-from torch.nn.modules.activation import *
-from torch.nn.modules.adaptive import *
-from torch.nn.modules.batchnorm import *
-from torch.nn.modules.channelshuffle import *
-from torch.nn.modules.conv import *
-from torch.nn.modules.distance import *
-from torch.nn.modules.dropout import *
-from torch.nn.modules.flatten import *
-from torch.nn.modules.fold import *
-from torch.nn.modules.instancenorm import *
-from torch.nn.modules.linear import *
-from torch.nn.modules.normalization import *
-from torch.nn.modules.padding import *
-from torch.nn.modules.pixelshuffle import *
-from torch.nn.modules.pooling import *
-from torch.nn.modules.rnn import *
-from torch.nn.modules.sparse import *
-from torch.nn.modules.transformer import *
-from torch.nn.modules.upsampling import *
-from .. import init as init
-from .vanilla import *
-from .parallel_1d import *
-from .parallel_2d import *
-from .parallel_2p5d import *
-from .parallel_3d import *
-from .parallel_sequence import *
+from ... import init as init
+from ..parallel_1d import *
+from ..parallel_2d import *
+from ..parallel_2p5d import *
+from ..parallel_3d import *
+from ..utils import get_tensor_parallel_mode
+from ..vanilla import *
-_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
-_parallel_classifier = {
-    None: VanillaClassifier,
-    '1d': VanillaClassifier,
-    '2d': Classifier2D,
-    '2.5d': Classifier2p5D,
-    '3d': Classifier3D
-}
-_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
-_parallel_embedding = {'3d': Embedding3D}
+_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}
_parallel_patchembedding = {
-    None: VanillaPatchEmbedding,
+    'None': VanillaPatchEmbedding,
    '1d': VanillaPatchEmbedding,
    '2d': PatchEmbedding2D,
    '2.5d': PatchEmbedding2p5D,
@@ -55,65 +23,6 @@ _parallel_patchembedding = {
}
-class Linear(nn.Module):
-    def __init__(self,
-                 in_features: int,
-                 out_features: int,
-                 bias: bool = True,
-                 dtype: dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 tensor_parallel: Optional[str] = None,
-                 **kwargs) -> None:
-        super().__init__()
-        if tensor_parallel is None:
-            self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
-            weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
-            if bias:
-                bias_initializer(self.layer.bias, fan_in=in_features)
-        else:
-            self.layer = _parallel_linear[tensor_parallel](
-                in_features,
-                out_features,
-                bias=bias,
-                dtype=dtype,
-                weight_initializer=weight_initializer,
-                bias_initializer=bias_initializer,
-                **kwargs,
-            )
-
-    @property
-    def weight(self):
-        return self.layer.weight
-
-    @property
-    def bias(self):
-        return self.layer.bias
-
-    def forward(self, *args):
-        return self.layer(*args)
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
-        super().__init__()
-        if tensor_parallel in [None, '1d']:
-            self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
-        else:
-            self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
-
-    @property
-    def weight(self):
-        return self.norm.weight
-
-    @property
-    def bias(self):
-        return self.norm.bias
-
-    def forward(self, *args):
-        return self.norm(*args)
class Embedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
@@ -121,11 +30,11 @@ class Embedding(nn.Module):
                 padding_idx: int = None,
                 dtype: dtype = None,
                 weight_initializer: Callable = init.normal_(),
-                 tensor_parallel: Optional[str] = None,
                 *args,
                 **kwargs) -> None:
        super().__init__()
-        if tensor_parallel in [None, '1d']:
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel == 'None':
            self.embed = nn.Embedding(num_embeddings,
                                      embedding_dim,
                                      padding_idx=padding_idx,
@@ -163,9 +72,9 @@ class PatchEmbedding(nn.Module):
                 flatten: bool = True,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 position_embed_initializer: Callable = init.zeros_(),
-                 tensor_parallel: Optional[str] = None) -> None:
+                 position_embed_initializer: Callable = init.zeros_()) -> None:
        super().__init__()
+        tensor_parallel = get_tensor_parallel_mode()
        self.embed = _parallel_patchembedding[tensor_parallel](
            img_size,
            patch_size,
@@ -196,36 +105,3 @@ class PatchEmbedding(nn.Module):
    def forward(self, *args):
        return self.embed(*args)
-class Classifier(nn.Module):
-    def __init__(self,
-                 in_features: int,
-                 num_classes: int,
-                 weight: nn.Parameter = None,
-                 bias: bool = True,
-                 dtype: dtype = None,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 tensor_parallel: Optional[str] = None) -> None:
-        super().__init__()
-        self.layer = _parallel_classifier[tensor_parallel](
-            in_features,
-            num_classes,
-            weight=weight,
-            bias=bias,
-            dtype=dtype,
-            weight_initializer=weight_initializer,
-            bias_initializer=bias_initializer,
-        )
-
-    @property
-    def weight(self):
-        return self.layer.weight
-
-    @property
-    def bias(self):
-        return self.layer.bias
-
-    def forward(self, *args):
-        return self.layer(*args)
import math
from typing import Callable, Optional
from colossalai.nn.layer.parallel_1d.layers import Classifier1D
from colossalai.utils import get_current_device
from torch import dtype, nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}

_parallel_classifier = {
    'None': VanillaClassifier,
    '1d': Classifier1D,
    '2d': Classifier2D,
    '2.5d': Classifier2p5D,
    '3d': Classifier3D
}


class Linear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: dtype = None,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 **kwargs) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == 'None':
            self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
            weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
            if bias:
                bias_initializer(self.layer.bias, fan_in=in_features)
        else:
            self.layer = _parallel_linear[tensor_parallel](
                in_features,
                out_features,
                bias=bias,
                dtype=dtype,
                weight_initializer=weight_initializer,
                bias_initializer=bias_initializer,
                **kwargs,
            )

    @property
    def weight(self):
        return self.layer.weight

    @property
    def bias(self):
        return self.layer.bias

    def forward(self, *args):
        return self.layer(*args)


class Classifier(nn.Module):
    def __init__(
        self,
        in_features: int,
        num_classes: int,
        weight: nn.Parameter = None,
        bias: bool = True,
        dtype: dtype = None,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
    ) -> None:
        super().__init__()
        self.layer = _parallel_classifier[get_tensor_parallel_mode()](
            in_features,
            num_classes,
            weight=weight,
            bias=bias,
            dtype=dtype,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer,
        )

    @property
    def weight(self):
        return self.layer.weight

    @property
    def bias(self):
        return self.layer.bias

    def forward(self, *args):
        return self.layer(*args)
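For reference, a hedged sketch of how the dispatch above is expected to behave; the mode value is an assumption and extra keyword arguments are simply forwarded to the chosen parallel implementation:

# Sketch only: assumes the '1d' mode was set at launch, in which case the
# wrappers build a Linear1D and a Classifier1D under the hood; with 'None'
# they fall back to nn.Linear and VanillaClassifier.
fc = Linear(1024, 4096)
head = Classifier(1024, num_classes=100)
print(type(fc.layer).__name__, type(head.layer).__name__)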
from typing import Optional
from colossalai.utils import get_current_device
from torch import nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel in ['None', '1d']:
            self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
        else:
            self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)

    @property
    def weight(self):
        return self.norm.weight

    @property
    def bias(self):
        return self.norm.bias

    def forward(self, *args):
        return self.norm(*args)
# adapted from Megatron-LM
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
import torch
@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply
\ No newline at end of file
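A brief usage sketch of the fused bias-GeLU autograd function above (tanh approximation, adapted from Megatron-LM). Note that forward takes the input first and the bias second, while the jitted kernels take the bias first, and that backward returns the same gradient for input and bias, which assumes they share a shape:

import torch

x = torch.randn(16, requires_grad=True)
bias = torch.zeros(16, requires_grad=True)

out = bias_gelu_impl(x, bias)   # calls GeLUFunction.forward(ctx, input, bias)
out.sum().backward()            # both gradients come from bias_gelu_back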
-from .layers import Linear1D_Col, Linear1D_Row
+from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D

-__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']
+__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import os
import torch
import torch.distributed as dist
+from colossalai.constants import PARALLEL_INPUT_1D
from colossalai.core import global_context as gpc
-from .._common_utils import divide
+from ..utils import divide
+def set_parallel_input(input_parallel: bool):
+    os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else ''
+def get_parallel_input():
+    return bool(os.environ[PARALLEL_INPUT_1D])
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
...
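A hedged sketch of how these helpers are presumably used by the 1D layers: a layer whose output is already split can mark the fact, so the next layer knows it does not need to split its input again. Only the two functions are from this diff; the call sites below are hypothetical.

# Hypothetical call sites (the real ones live in the 1D layer implementations).
set_parallel_input(True)        # writes 'true' into os.environ[PARALLEL_INPUT_1D]
if get_parallel_input():        # bool('') is False, bool('true') is True
    pass  # consume an activation that is already split across the 1D group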