Commit 404ecbdc authored by zbian

Migrated project

parent 2ebaefc5
import os.path as osp
import pytest
import torch
from torch.utils.data import DataLoader
from colossalai.builder import build_dataset, ModelInitializer
from colossalai.core import global_context
from colossalai.initialize import init_dist
from colossalai.logging import get_global_dist_logger
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_partition():
init_dist(CONFIG_PATH)
logger = get_global_dist_logger()
logger.info('finished initialization')
# build model
model = ModelInitializer(global_context.config.model, 1, verbose=True).model_initialize()
logger.info('model is created')
dataset = build_dataset(global_context.config.train_data.dataset)
dataloader = DataLoader(dataset=dataset, **global_context.config.train_data.dataloader)
logger.info('train data is created')
global_context.destroy()
torch.cuda.synchronize()
logger.info('test finished')
if __name__ == '__main__':
test_partition()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os.path as osp
import pytest
from colossalai.core import global_context as gpc
from colossalai.initialize import initialize
from colossalai.logging import get_global_dist_logger
NUM_BATCH = 128
BATCH_SIZE = 32
SEQ_LENGTH = 128
HIDDEN_SIZE = 512
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_schedule():
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
logger = get_global_dist_logger()
schedule.zero_grad()
output, label, losses = schedule.forward_backward_step(forward_only=False)
schedule.step()
logger.info('losses: {}'.format([loss.item() for loss in losses]))
gpc.destroy()
logger.info('training finished')
if __name__ == '__main__':
test_schedule()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os.path as osp
import pytest
import torch
from colossalai import initialize
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
NUM_BATCH = 128
BATCH_SIZE = 32
SEQ_LENGTH = 128
HIDDEN_SIZE = 512
DIR_PATH = osp.dirname(osp.realpath(__file__))
PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
def run_pipeline(config):
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
outputs, labels, loss = engine.step()
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info('rank {}: loss is {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine pipeline finished')
@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_engine():
run_pipeline(PIPE_CONFIG_PATH)
if __name__ == '__main__':
test_engine()
import os
from pathlib import Path
from colossalai.engine import AMP_TYPE
BATCH_SIZE = 512
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
train_data = dict(
dataset=dict(
type='CIFAR10Dataset',
root=Path(os.environ['DATA']),
transform_pipeline=[
dict(type='RandomCrop', size=IMG_SIZE, padding=4),
dict(type='RandomHorizontalFlip'),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]
),
dataloader=dict(
batch_size=BATCH_SIZE,
pin_memory=True,
num_workers=4,
shuffle=True
)
)
test_data = dict(
dataset=dict(
type='CIFAR10Dataset',
root=Path(os.environ['DATA']),
train=False,
transform_pipeline=[
dict(type='Resize', size=IMG_SIZE),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]
),
]
),
dataloader=dict(
batch_size=BATCH_SIZE,
pin_memory=True,
num_workers=4,
shuffle=True
)
)
optimizer = dict(
type='Adam',
lr=0.001,
weight_decay=0
)
loss = dict(
type='CrossEntropyLoss2D',
)
model = dict(
type='VisionTransformerFromConfig',
tensor_splitting_cfg=dict(
type='ViTInputSplitter2D',
),
embedding_cfg=dict(
type='ViTPatchEmbedding2D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
),
token_fusion_cfg=dict(
type='ViTTokenFuser2D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
drop_rate=0.1
),
norm_cfg=dict(
type='LayerNorm2D',
normalized_shape=DIM,
eps=1e-6,
),
block_cfg=dict(
type='ViTBlock',
attention_cfg=dict(
type='ViTSelfAttention2D',
hidden_size=DIM,
num_attention_heads=NUM_ATTENTION_HEADS,
attention_dropout_prob=0.,
hidden_dropout_prob=0.1,
),
droppath_cfg=dict(
type='VanillaViTDropPath',
),
mlp_cfg=dict(
type='ViTMLP2D',
in_features=DIM,
dropout_prob=0.1,
mlp_ratio=1
),
norm_cfg=dict(
type='LayerNorm2D',
normalized_shape=DIM,
eps=1e-6,
),
),
head_cfg=dict(
type='ViTHead2D',
hidden_size=DIM,
num_classes=NUM_CLASSES,
),
embed_dim=DIM,
depth=DEPTH,
drop_path_rate=0.,
)
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2d'),
)
fp16 = dict(
mode=AMP_TYPE.PARALLEL,
initial_scale=2 ** 4
)
lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
)
num_epochs = 60
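# With parallel = dict(tensor=dict(size=4, mode='2d')), the four tensor-parallel ranks
# form a SUMMA_DIM x SUMMA_DIM = 2 x 2 grid; the 2D layers configured above shard both
# their weights and activations across the two mesh axes (DIM // SUMMA_DIM per rank).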
#!/usr/bin/env sh
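# usage: sh test.sh <test_file.py>
# assumes a SLURM allocation: SLURM_PROCID/SLURM_NPROCS supply the rank and world size,
# and HOST must be set to the rendezvous address of rank 0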
test_file=$1
python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from pathlib import Path
import pytest
import torch.autograd
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer._parallel_utilities import _gather
CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')
def evaluate(engine):
engine.eval()
accumulated_loss = 0
correct_sum = 0
total_sum = 0
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
accumulated_loss += loss.detach().cpu().numpy()
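# the logits are partitioned over the 2D mesh: gather dim 1 across the row group,
# then dim 0 across the column group, to rebuild the full (batch, num_classes) output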
output = _gather(
output[0],
ParallelMode.PARALLEL_2D_ROW,
1
)
output = _gather(
output,
ParallelMode.PARALLEL_2D_COL,
0,
)
output = torch.argmax(output, dim=-1)
correct = torch.sum(label[0] == output)
correct_sum += correct
total_sum += label[0].size(0)
avg_loss = accumulated_loss / engine.schedule.num_steps
return correct_sum, total_sum, avg_loss
def train(engine):
engine.train()
accumulated_loss = 0
for i in range(engine.schedule.num_steps):
output, label, loss = engine.step()
accumulated_loss += loss.squeeze(0).detach().cpu().numpy()
avg_loss = accumulated_loss / engine.schedule.num_steps
return avg_loss
@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d_parallel_vision_transformer():
# init dist
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize(
CONFIG_PATH)
logger = get_global_dist_logger()
engine = Engine(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
logger.info('start training')
for epoch in range(gpc.config.num_epochs):
train_loss = train(engine)
logger.info(f'epoch {epoch} - train loss: {train_loss}')
if epoch % 2 == 0:
correct_sum, total_sum, eval_loss = evaluate(engine)
logger.info(
f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
f'correct: {correct_sum}, acc: {correct_sum / total_sum}')
if __name__ == '__main__':
test_2d_parallel_vision_transformer()
#!/usr/bin/env sh
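# usage: sh test.sh <test_file.py>
# assumes a SLURM allocation: SLURM_PROCID/SLURM_NPROCS supply the rank and world size,
# and HOST must be set to the rendezvous address of rank 0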
test_file=$1
python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
DEPTH = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
from colossalai.core import global_context as gpc
from colossalai.initialize import init_dist
from test_layer import check_linear_col, check_linear_row
CONFIG = dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(
size=2,
mode='1d'
)
),
)
def check_layer():
check_linear_col()
check_linear_row()
# check_attention()
# check_mlp()
@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_1d():
init_dist(config=CONFIG)
gpc.set_seed()
check_layer()
gpc.destroy()
if __name__ == '__main__':
test_1d()
import torch
import torch.distributed as dist
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import Linear1D_Col, Linear1D_Row
# TransformerMLP1D, \
# TransformerSelfAttention1D, TransformerEncoderLayer1D
from colossalai.utils import get_current_device, print_rank_0
from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal
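# Each check below follows the same recipe: broadcast full-size master tensors from rank 0,
# slice out this rank's shard, run the parallel layer on the shard, and compare the result
# (forward output and gradients) against a serial reference computed on the masters.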
def check_linear_col():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE, gather_output=True)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=0)[i]
W = W.clone()
W.requires_grad = True
B_shape = (OUTPUT_SIZE,)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
dist.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
B = B.clone()
B.requires_grad = True
layer.weight = Parameter(W)
layer.bias = Parameter(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = C_master.clone()
check_equal(out, C)
print_rank_0('linear_col gather_output forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
dist.broadcast(grad_master, src=0)
grad = grad_master.detach()
out.backward(grad)
C_master.backward(grad)
A_grad = A_master.grad
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear_col gather_output backward: pass')
def check_linear_row():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, parallel_input=False)
A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[i]
W = W.clone()
W.requires_grad = True
B_shape = (INPUT_SIZE,)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
dist.broadcast(B_master, src=0)
B = B_master.clone()
B.requires_grad = True
layer.weight = Parameter(W)
layer.bias = Parameter(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = C_master.clone()
check_equal(out, C)
print_rank_0('linear_row no parallel_input forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
dist.broadcast(grad_master, src=0)
grad = grad_master.detach()
out.backward(grad)
C_master.backward(grad)
A_grad = A_master.grad
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear_row no parallel_input backward: pass')
#
# def check_attention():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
#
# i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
#
# layer = TransformerSelfAttention1D(
# 1,
# HIDDEN_SIZE // NUM_ATTENTION_HEADS,
# HIDDEN_SIZE,
# NUM_ATTENTION_HEADS,
# 0.5
# )
#
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = A_master.clone()
# A.requires_grad = True
#
# mask_shape = (BATCH_SIZE, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
#
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# print_rank_0('self attention forward: pass')
#
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
#
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('self attention backward: pass')
#
#
# def check_mlp():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
#
# i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
#
# layer = TransformerMLP1D(
# HIDDEN_SIZE,
# HIDDEN_SIZE,
# 4.0
# )
#
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = A_master.clone()
# A.requires_grad = True
#
# out = layer(A)
# assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# print_rank_0('mlp forward: pass')
#
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
#
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('mlp backward: pass')
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
DEPTH = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
from colossalai.core import global_context as gpc
from colossalai.initialize import init_dist
from test_layer import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from test_operation import check_AB, check_ABT, check_ATB
CONFIG = dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(
size=4,
mode='2d'
)
),
)
def check_operations():
check_AB()
check_ABT()
check_ATB()
def check_layer():
check_linear()
check_layernorm()
check_attention()
check_mlp()
check_transformerlayer()
@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d():
init_dist(config=CONFIG)
gpc.set_seed()
check_operations()
check_layer()
gpc.destroy()
if __name__ == '__main__':
test_2d()
import torch
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import Linear2D, LayerNorm2D, TransformerSelfAttention2D, TransformerMLP2D, TransformerLayer2D
from colossalai.utils import get_current_device, print_rank_0
from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal
def check_linear():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = Linear2D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=0)[i]
W = torch.chunk(W, DEPTH, dim=-1)[j]
W = W.clone()
W.requires_grad = True
B_shape = (OUTPUT_SIZE,)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[j]
B = B.clone()
B.requires_grad = True
layer.weight = Parameter(W)
layer.bias = Parameter(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master) + B_master
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('linear forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear backward: pass')
def check_layernorm():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layernorm = LayerNorm2D(INPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layernorm(A)
A_master = A_master.clone()
A_master.requires_grad = True
E_master = torch.sum(A_master, dim=-1, keepdim=True)
E_master /= INPUT_SIZE
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
V_master /= INPUT_SIZE
V_master = V_master - E_master * E_master
V_master = 1.0 / torch.sqrt(V_master + EPS)
C_master = (A_master - E_master) * V_master
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('layer norm forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
def check_attention():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = TransformerSelfAttention2D(
HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
print_rank_0('self attention forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('self attention backward: pass')
def check_mlp():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = TransformerMLP2D(
HIDDEN_SIZE,
dropout_prob=0.5,
act_func='gelu',
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layer(A)
assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
print_rank_0('mlp forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('mlp backward: pass')
def check_transformerlayer():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = TransformerLayer2D(
HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
act_func='gelu',
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
print_rank_0('transformerlayer forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('transformerlayer backward: pass')
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from common import check_equal, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, DEPTH
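# These checks exercise the 2D matmul autograd functions: masters are broadcast in full,
# cut into DEPTH x DEPTH blocks indexed by the rank's 2D coordinates, multiplied with the
# parallel kernel, and compared against a plain torch.matmul on the unsharded masters.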
def check_AB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
B = torch.chunk(B, DEPTH, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH)
out = Matmul_AB_2D.apply(
A, B,
DEPTH,
out_shape,
i, j,
ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size
)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, B_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
# check forward correctness
check_equal(out, C)
print_rank_0('AB forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
# check backward correctness
check_equal(A_grad, A.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
# check backward correctness
check_equal(B_grad, B.grad)
print_rank_0('AB backward: pass')
def check_ABT():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
device = get_current_device()
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = C.clone()
C.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
B = torch.chunk(B, DEPTH, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out = Matmul_ABT_2D.apply(
C, B,
DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH),
i, j,
ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
C_master = C_master.clone()
C_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
A_master = torch.matmul(C_master, B_master.transpose(0, 1))
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
check_equal(out, A)
print_rank_0('ABT forward: pass')
grad_shape = A_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
# backward
out.backward(grad)
A_master.backward(grad_master)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
check_equal(C_grad, C.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
check_equal(B_grad, B.grad)
print_rank_0('ABT backward: pass')
def check_ATB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
device = get_current_device()
dtype = torch.float
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = C.clone()
C.requires_grad = True
out = Matmul_ATB_2D.apply(
A, C,
DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH),
i, j,
ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size
)
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = C_master.clone()
C_master.requires_grad = True
B_master = torch.matmul(
A_master.view(-1, A_master.shape[-1]).transpose(0, 1),
C_master.view(-1, C_master.shape[-1]))
B = torch.chunk(B_master, DEPTH, dim=0)[i]
B = torch.chunk(B, DEPTH, dim=-1)[j]
check_equal(out, B)
print_rank_0('ATB forward: pass')
grad_shape = B_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
out.backward(grad)
B_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
check_equal(C_grad, C.grad)
print_rank_0('ATB backward: pass')
import torch
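# 2.5D (Tesseract) layout: TESSERACT_DIM x TESSERACT_DIM x TESSERACT_DEP = 2 x 2 x 2
# processes, matching tensor=dict(size=8, mode='2.5d', depth=2) in the test config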
TESSERACT_DIM = 2
TESSERACT_DEP = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
#!/bin/bash
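# launches 8 processes on this node with torch.distributed.launch; --host, --port and
# --world_size are arguments parsed by test_2p5d.py itself, not by the launcher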
python -m torch.distributed.launch --nproc_per_node 8 test_2p5d.py --host $HOST --port 29516 --world_size 8
import pytest
from colossalai.core import global_context as gpc
from colossalai.initialize import init_dist
from test_layer import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from test_operation import check_AB, check_ABT, check_ATB
CONFIG = dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=8, mode='2.5d', depth=2),
),
)
def check_operations():
check_AB()
check_ABT()
check_ATB()
def check_layer():
check_linear()
check_layernorm()
check_attention()
check_mlp()
check_transformerlayer()
@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2p5d():
init_dist(config=CONFIG)
gpc.set_seed()
check_layer()
check_operations()
gpc.destroy()
if __name__ == '__main__':
test_2p5d()
import torch
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (Linear2p5D, LayerNorm2p5D, TransformerSelfAttention2p5D, TransformerMLP2p5D,
TransformerLayer2p5D)
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from common import TESSERACT_DIM, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, check_equal
def check_linear():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = Linear2p5D(
INPUT_SIZE,
OUTPUT_SIZE,
dtype=dtype,
skip_bias_add=False)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
W = torch.chunk(W_master, TESSERACT_DIM, dim=0)[i]
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[j]
W = W.clone()
W.requires_grad = True
B_shape = (OUTPUT_SIZE,)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
B = B.clone()
B.requires_grad = True
layer.weight = Parameter(W)
layer.bias = Parameter(B)
out = layer(A)
bias = layer.bias
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master) + B_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('linear forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear backward: pass')
def check_layernorm():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layernorm = LayerNorm2p5D(
INPUT_SIZE,
dtype=dtype)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layernorm(A)
A_master = A_master.clone()
A_master.requires_grad = True
E_master = torch.sum(A_master, dim=-1, keepdim=True)
E_master /= INPUT_SIZE
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
V_master /= INPUT_SIZE
V_master = V_master - E_master * E_master
V_master = 1.0 / torch.sqrt(V_master + EPS)
C_master = (A_master - E_master) * V_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('layer norm forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
def check_attention():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = TransformerSelfAttention2p5D(
HIDDEN_SIZE, NUM_ATTENTION_HEADS,
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5,
dtype=dtype,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
print_rank_0('self attention forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('self attention backward: pass')
def check_mlp():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = TransformerMLP2p5D(
HIDDEN_SIZE,
mlp_ratio=1,
dropout_prob=0.5,
act_func='gelu',
dtype=dtype,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layer(A)
assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
print_rank_0('mlp forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('mlp backward: pass')
def check_transformerlayer():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = TransformerLayer2p5D(
HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
act_func='gelu',
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5,
dtype=dtype,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
print_rank_0('transformerlayer forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('transformerlayer backward: pass')
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, \
Matmul_ATB_2p5D
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from common import TESSERACT_DIM, TESSERACT_DEP, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, check_equal
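# Same structure as the 2D operation checks, extended with a depth axis: each rank is
# identified by (col i, row j, dep k) and shards are cut with TESSERACT_DIM-way chunks.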
def check_AB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM)
out = Matmul_AB_2p5D.apply(
A, B,
TESSERACT_DIM, TESSERACT_DEP, out_shape,
i, j, k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, B_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
# check forward correctness
check_equal(out, C)
print_rank_0('AB forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
# check backward correctness
check_equal(A_grad, A.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
# check backward correctness
check_equal(B_grad, B.grad)
print_rank_0('AB backward: pass')
def check_ABT():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
device = get_current_device()
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
C = C.clone()
C.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out = Matmul_ABT_2p5D.apply(
C, B,
TESSERACT_DIM, TESSERACT_DEP, (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM),
i, j, k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
C_master = C_master.clone()
C_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
A_master = torch.matmul(C_master, B_master.transpose(0, 1))
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
check_equal(out, A)
print_rank_0('ABT forward: pass')
grad_shape = A_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
# backward
out.backward(grad)
A_master.backward(grad_master)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(C_grad, C.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(B_grad, B.grad)
print_rank_0('ABT backward: pass')
def check_ATB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
device = get_current_device()
dtype = torch.float
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
C = C.clone()
C.requires_grad = True
out = Matmul_ATB_2p5D.apply(
A, C,
TESSERACT_DIM, TESSERACT_DEP, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
i, j, k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size)
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = C_master.clone()
C_master.requires_grad = True
B_master = torch.matmul(
A_master.view(-1, A_master.shape[-1]).transpose(0, 1),
C_master.view(-1, C_master.shape[-1]))
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
check_equal(out, B)
print_rank_0('ATB forward: pass')
grad_shape = B_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
B_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(C_grad, C.grad)
print_rank_0('ATB backward: pass')
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
DEPTH = 2
BATCH_SIZE = 512
SEQ_LENGTH = 128
HIDDEN_SIZE = 512
NUM_CLASSES = 10
NUM_BLOCKS = 6
IMG_SIZE = 32
def check_equal(A, B):
return torch.allclose(A, B, rtol=1e-5, atol=1e-2)