Unverified Commit da01c234 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

Develop/experiments (#59)



* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>

* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook
Co-authored-by: default avatarFrank Lee <somerlee.9@gmail.com>
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline
Co-authored-by: default avatarFrank Lee <somerlee.9@gmail.com>
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>

* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

update api for better usability
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarver217 <lhx0217@gmail.com>
Co-authored-by: default avatarpuck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>
Co-authored-by: default avatarアマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: default avatarBoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent eb2f8b1f
......@@ -11,7 +11,7 @@ from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
from ._utils import get_summa_dim_from_env, assert_summa_initialization
from .._common_utils import divide, set_tensor_parallel_attribute
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from ..base_layer import ParallelLayer
......@@ -36,8 +36,9 @@ class Linear2D(ParallelLayer):
out_features: int,
bias: bool = True,
dtype=None,
skip_bias_add: bool = False
):
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'):
super().__init__()
self.in_features = in_features
......@@ -72,31 +73,45 @@ class Linear2D(ParallelLayer):
self.register_parameter('bias', None)
# initialize parameters
self.reset_parameters()
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute(self.bias)
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def reset_parameters(self) -> None:
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
fan_in = self.in_features
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
with seed(ParallelMode.TENSOR):
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
with seed(ParallelMode.TENSOR):
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
......@@ -192,28 +207,19 @@ class LayerNorm2D(ParallelLayer):
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if self.row_rank == 0:
self.gamma = Parameter(torch.ones(
self.partitioned_partition,
**factory_kwargs))
self.beta = Parameter(torch.zeros(
self.partitioned_partition,
**factory_kwargs))
else:
self.gamma = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self.beta = Parameter(torch.tensor(
1.0,
requires_grad=True,
**factory_kwargs))
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.gamma)
set_tensor_parallel_attribute(self.beta)
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
......
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Sum_2p5D, Add_Bias_2p5D
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Add_Bias_2p5D
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D
from ._vit import (ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D,
ViTInputSplitter2p5D)
from ._vit import ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D, ViTInputSplitter2p5D
from .layers import Linear2p5D, LayerNorm2p5D
__all__ = [
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Sum_2p5D', 'Add_Bias_2p5D',
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Add_Bias_2p5D',
'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D',
'ViTInputSplitter2p5D',
......
This diff is collapsed.
This diff is collapsed.
......@@ -3,7 +3,8 @@
import os
from colossalai.constants import DEPTH_3D
from colossalai.constants import (DEPTH_3D, INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
......@@ -23,6 +24,10 @@ def get_depth_from_env() -> int:
)
def get_parallel_mode_from_env(group):
return getattr(ParallelMode, os.environ[group])
def get_last_group(a, b):
mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A',
......@@ -41,6 +46,11 @@ def get_last_group(a, b):
return ParallelMode.PARALLEL_3D_OUTPUT
def swap_in_out_group():
os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \
os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D]
def dbg_check_shape(tensor: Tensor, shape: tuple):
rank = gpc.get_global_rank()
if rank == 0:
......
This diff is collapsed.
This diff is collapsed.
from .layers import ViTBlock
__all__ = ['ViTBlock']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch import nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class ViTBlock(nn.Module):
"""Vision Transformer block
:param attention_cfg: config of attention layer
:type attention_cfg: dict
:param droppath_cfg: config of drop path
:type droppath_cfg: dict
:param mlp_cfg: config of MLP layer
:type mlp_cfg: dict
:param norm_cfg: config of normlization layer
:type norm_cfg: dict
"""
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
# x_ = x
# x_ = self.norm1(x_)
# if self.checkpoint:
# x_ = checkpoint(self.attn, x_)
# else:
# x_ = self.attn(x_)
# x_ = self.drop_path(x_)
# x = x + x_
#
# x_ = x
# x_ = self.norm2(x_)
# if self.checkpoint:
# x_ = checkpoint(self.mlp, x_)
# else:
# x_ = self.mlp(x_)
# x_ = self.drop_path(x_)
# x = x + x_
return x
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment