Unverified commit da01c234, authored by Frank Lee, committed by GitHub

Develop/experiments (#59)



* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapt torch amp to tensor parallel (#18)

* fixed compatibility bugs between torch amp and tensor parallel, plus minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine, and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, and positional embedding in 2D; fix random number generation in DDP;
fix convergence on CIFAR-10 and ImageNet-1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundant function in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapt torch amp to tensor parallel (#18)

* fixed compatibility bugs between torch amp and tensor parallel, plus minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine, and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundant function in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapt torch amp to tensor parallel (#18)

* fixed compatibility bugs between torch amp and tensor parallel, plus minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine, and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3D layers to fix slow computation; tested ImageNet performance with 3D; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified APIs of 3D layers (#51)

* Update 2.5D layer code to achieve similar accuracy on the ImageNet-1k dataset

* update api for better usability (#58)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent eb2f8b1f
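For orientation before the diffs: the central API change in this commit replaces init_dist with an explicit launch entry point, and every updated test below follows the same bring-up pattern. A minimal sketch assembled from the call sites in this diff (not a complete program):

from colossalai.initialize import launch, get_default_parser

# illustrative config taken from the 3D test below
CONFIG = dict(parallel=dict(pipeline=1, tensor=dict(mode='3d', size=8)))

# get_default_parser supplies the rank/world_size/host/port/backend arguments used below
parser = get_default_parser()
args = parser.parse_args()

# bring up the distributed environment from a config dict or config file
launch(config=CONFIG,
       rank=args.rank,
       world_size=args.world_size,
       host=args.host,
       port=args.port,
       backend=args.backend)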
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os.path as osp

import pytest
import torch

from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage

NUM_BATCH = 128
NUM_MICRO = 6

BATCH_SIZE = 32
SEQ_LENGTH = 128
HIDDEN_SIZE = 512

DIR_PATH = osp.dirname(osp.realpath(__file__))
NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_amp.py')


def run_no_pipeline(config):
    # build the engine and dataloaders from the config and run a single step
    engine, train_dataloader, test_dataloader = initialize(config)
    logger = get_global_dist_logger()
    rank = torch.distributed.get_rank()

    engine.train()
    output, label, loss = engine.step(iter(train_dataloader))
    logger.info('Rank {} returns: {}'.format(rank, loss.item()))

    gpc.destroy()
    logger.info('Test engine finished')
    report_memory_usage("After testing")


@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_engine():
    run_no_pipeline(NO_PIPE_CONFIG_PATH)


if __name__ == '__main__':
    test_engine()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os.path as osp

import pytest
import torch

from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage

NUM_BATCH = 128
NUM_MICRO = 6

BATCH_SIZE = 32
SEQ_LENGTH = 128
HIDDEN_SIZE = 512

DIR_PATH = osp.dirname(osp.realpath(__file__))
NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py')


def run_no_pipeline(config):
    print('Test no pipeline engine start')
    engine, train_dataloader, test_dataloader = initialize(config)
    logger = get_global_dist_logger()
    rank = torch.distributed.get_rank()

    engine.train()
    output, label, loss = engine.step(iter(train_dataloader))
    logger.info('Rank {} returns: {}'.format(rank, loss.item()))

    gpc.destroy()
    logger.info('Test engine finished')
    report_memory_usage("After testing")


@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_engine():
    run_no_pipeline(NO_PIPE_CONFIG_PATH)


if __name__ == '__main__':
    test_engine()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os.path as osp

import pytest
import torch

from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage

NUM_BATCH = 128
NUM_MICRO = 6

BATCH_SIZE = 32
SEQ_LENGTH = 128
HIDDEN_SIZE = 512

DIR_PATH = osp.dirname(osp.realpath(__file__))
NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_amp.py')


def run_no_pipeline(config):
    print('Test no pipeline engine start')
    engine, train_dataloader, test_dataloader = initialize(config)
    logger = get_global_dist_logger()
    rank = torch.distributed.get_rank()

    engine.train()
    output, label, loss = engine.step(iter(train_dataloader))
    logger.info('Rank {} returns: {}'.format(rank, loss.item()))

    gpc.destroy()
    logger.info('Test engine finished')
    report_memory_usage("After testing")


@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_engine():
    run_no_pipeline(NO_PIPE_CONFIG_PATH)


if __name__ == '__main__':
    test_engine()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os.path as osp

import pytest
import torch

from colossalai import initialize
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger

NUM_BATCH = 128

BATCH_SIZE = 32
SEQ_LENGTH = 128
HIDDEN_SIZE = 512

DIR_PATH = osp.dirname(osp.realpath(__file__))
PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')


def run_pipeline(config):
    engine, train_dataloader, test_dataloader = initialize(config)
    logger = get_global_dist_logger()
    rank = torch.distributed.get_rank()

    engine.train()
    outputs, labels, loss = engine.step(iter(train_dataloader))
    # only the last pipeline stage holds the loss
    if gpc.is_last_rank(ParallelMode.PIPELINE):
        logger.info('Rank {} loss: {}'.format(rank, loss.item()))

    gpc.destroy()
    logger.info('Test engine pipeline finished')


@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_engine():
    run_pipeline(PIPE_CONFIG_PATH)


if __name__ == '__main__':
    test_engine()
#!/usr/bin/env sh
test_file=$1
python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
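For reference, this wrapper forwards SLURM-provided environment variables to the test script, so under a SLURM allocation it would be driven along the lines of "srun sh test.sh test_engine.py" (the exact script names here are illustrative).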
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from pathlib import Path

import pytest
import torch.autograd

import colossalai
from colossalai.builder import build_lr_scheduler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer._parallel_utilities import _gather

CONFIG_PATH = Path(__file__).parent.parent.joinpath('configs/vit_2d.py')


def eval(engine, test_dataloader):
    engine.eval()
    accumulated_loss = 0
    correct_sum = 0
    total_sum = 0
    num_steps = len(test_dataloader)
    data_iter = iter(test_dataloader)

    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.detach().cpu().numpy()

        # gather the 2D-partitioned output back into a full logits tensor
        output = _gather(
            output[0],
            ParallelMode.PARALLEL_2D_ROW,
            1
        )
        output = _gather(
            output,
            ParallelMode.PARALLEL_2D_COL,
            0,
        )
        output = torch.argmax(output, dim=-1)
        correct = torch.sum(label[0] == output)
        correct_sum += correct
        total_sum += label[0].size(0)

    avg_loss = accumulated_loss / num_steps
    return correct_sum, total_sum, avg_loss


def train(engine, train_dataloader, lr_scheduler):
    engine.train()
    accumulated_loss = 0
    num_steps = len(train_dataloader)
    data_iter = iter(train_dataloader)

    for i in range(num_steps):
        output, label, loss = engine.step(data_iter)
        accumulated_loss += loss.squeeze(0).detach().cpu().numpy()

    avg_loss = accumulated_loss / num_steps
    lr_scheduler.step()
    return avg_loss


@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d_parallel_vision_transformer():
    # init dist
    engine, train_dataloader, test_dataloader = colossalai.initialize(CONFIG_PATH)
    lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, engine.optimizer)
    logger = get_global_dist_logger()

    logger.info('start training')
    for epoch in range(gpc.config.num_epochs):
        train_loss = train(engine, train_dataloader, lr_scheduler)
        logger.info(f'epoch {epoch} - train loss: {train_loss}')

        if epoch % 2 == 0:
            correct_sum, total_sum, eval_loss = eval(engine, test_dataloader)
            logger.info(
                f'epoch {epoch} - eval loss: {eval_loss}, total: {total_sum}, '
                f'correct: {correct_sum}, acc: {correct_sum / total_sum}')


if __name__ == '__main__':
    test_2d_parallel_vision_transformer()
#!/usr/bin/env sh
test_file=$1

-python $test_file --local_rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
+python $test_file --rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
@@ -6,8 +6,9 @@ import torch

DEPTH = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
+IMG_SIZE = 16
HIDDEN_SIZE = 8
NUM_CLASSES = 10


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
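All of the layer tests below follow one recipe: broadcast an identical input from rank 0, run the parallel layer and a single-device reference side by side, and compare with check_equal. A minimal sketch of that pattern (the helper name is illustrative, not part of the test suite):

import torch
import torch.distributed as dist

def compare_with_master(parallel_layer, master_layer, x_shape, device, dtype=torch.float32):
    # every rank sees the same input because rank 0 broadcasts it
    x_master = torch.randn(x_shape, dtype=dtype, device=device)
    dist.broadcast(x_master, src=0)
    x = x_master.clone()

    out = parallel_layer(x)               # partitioned computation
    out_master = master_layer(x_master)   # single-device reference
    check_equal(out, out_master)          # allclose within the tolerances above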
@@ -4,8 +4,8 @@

import pytest

from colossalai.core import global_context as gpc
-from colossalai.initialize import init_dist
-from test_layer import check_linear_col, check_linear_row
+from colossalai.initialize import launch, get_default_parser
+from test_layer import *

CONFIG = dict(
    parallel=dict(
@@ -19,20 +19,31 @@ CONFIG = dict(

def check_layer():
-    # print_rank_0('start check_linear_col')
    check_linear_col()
    check_linear_row()
-    # check_attention()
-    # check_mlp()
+    check_attention()
+    check_mlp()
+    check_patch_embedding()
+    check_embed()
+    check_head()


@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
-def test_2d():
-    init_dist(config=CONFIG)
-    gpc.set_seed()
+def test_1d():
+    parser = get_default_parser()
+    args = parser.parse_args()
+    launch(config=CONFIG,
+           rank=args.rank,
+           world_size=args.world_size,
+           host=args.host,
+           port=args.port,
+           backend=args.backend)
    check_layer()
    gpc.destroy()


if __name__ == '__main__':
-    test_2d()
+    test_1d()
from tests.test_layers.test_3d.common import IMG_SIZE
import torch
import torch.distributed as dist
from torch.nn import Parameter
import time

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn import Linear1D_Col, Linear1D_Row
-# TransformerMLP1D, \
-# TransformerSelfAttention1D, TransformerEncoderLayer1D
+from colossalai.nn import Linear1D_Col, Linear1D_Row, TransformerMLP1D, TransformerSelfAttention1D, ViTMLP1D, ViTSelfAttention1D, ViTPatchEmbedding1D, ViTHead1D, ViTTokenFuser1D
from colossalai.utils import get_current_device, print_rank_0
-from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal
+from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE
def check_linear_col():
@@ -142,70 +141,274 @@ def check_linear_row():
    print_rank_0('linear_row no parallel_input backward: pass')
#
# def check_attention():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
#
# i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
#
# layer = TransformerSelfAttention1D(
# 1,
# HIDDEN_SIZE // NUM_ATTENTION_HEADS,
# HIDDEN_SIZE,
# NUM_ATTENTION_HEADS,
# 0.5
# )
#
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = A_master.clone()
# A.requires_grad = True
#
# mask_shape = (BATCH_SIZE, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
#
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# print_rank_0('self attention forward: pass')
#
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
#
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('self attention backward: pass')
#
#
# def check_mlp():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
#
# i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
#
# layer = TransformerMLP1D(
# HIDDEN_SIZE,
# HIDDEN_SIZE,
# 4.0
# )
#
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = A_master.clone()
# A.requires_grad = True
#
# out = layer(A)
# assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# print_rank_0('mlp forward: pass')
#
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
#
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('mlp backward: pass')
class Testvithead(torch.nn.Module):
    # plain torch reference head used to check the parallel ViTHead1D
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        x = x[:, 0]
        x = self.linear(x)
        return x


def check_head():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    head = ViTHead1D(INPUT_SIZE, NUM_CLASSES, dtype=dtype)
    torch.nn.init.zeros_(head.linear.bias)
    torch.nn.init.ones_(head.linear.weight)
    head = head.to(device)

    layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True)
    torch.nn.init.zeros_(layer.linear.bias)
    torch.nn.init.ones_(layer.linear.weight)
    layer = layer.to(device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    fwd_start = time.time()
    out = head(A)
    fwd_end = time.time()
    print_rank_0(
        'head forward: pass | {0} --> {1} | {2:.3f} s'.format(
            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start))

    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = layer(A_master)
    # C = torch.chunk(C_master, DEPTH, dim=0)[i]
    print_rank_0('Rank {} head forward: {}'.format(i, check_equal(out, C_master)))

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape,
                              dtype=dtype,
                              device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    # grad = torch.chunk(grad_master, DEPTH, dim=0)[i]

    # bwd_start = time.time()
    out.backward(grad_master)
    # bwd_end = time.time()
    # print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
    #              logger)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    # if j == 0:
    print_rank_0('Rank {} head backward (input_grad): {}'.format(
        i, check_equal(A_grad, A.grad)))
class Testvitembed(torch.nn.Module):
    def __init__(self, img_size: int, patch_size: int, in_chans: int,
                 embed_size: int, drop_prob: float) -> None:
        super().__init__()
        self.proj = torch.nn.Conv2d(in_chans,
                                    embed_size,
                                    kernel_size=patch_size,
                                    stride=patch_size)
        num_patches = (img_size // patch_size)**2
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_size))
        self.pos_embed = torch.nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_size))
        self.pos_drop = torch.nn.Dropout(drop_prob)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        return x


def check_embed():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = ViTPatchEmbedding1D(IMG_SIZE, 4, HIDDEN_SIZE)
    layer2 = ViTTokenFuser1D(IMG_SIZE, 4, HIDDEN_SIZE)
    torch.nn.init.zeros_(layer.proj.bias)
    torch.nn.init.ones_(layer.proj.weight)
    torch.nn.init.ones_(layer2.cls_token)
    torch.nn.init.ones_(layer2.pos_embed)
    layer = layer.to(device)
    layer2 = layer2.to(device)

    layer_master = Testvitembed(IMG_SIZE, 4, 3, HIDDEN_SIZE, 0.)
    torch.nn.init.zeros_(layer_master.proj.bias)
    torch.nn.init.ones_(layer_master.proj.weight)
    torch.nn.init.ones_(layer_master.cls_token)
    torch.nn.init.ones_(layer_master.pos_embed)
    layer_master = layer_master.to(device)

    A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    fwd_start = time.time()
    out = layer2(layer(A))
    fwd_end = time.time()
    print_rank_0(
        'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(
            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start))
    # out_cls = out[:, 0]
    # out_tensor = out[:, 1:]

    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = layer_master(A_master)
    # if j == 0:
    #     C_cls = C_master[:, 0]
    #     C_cls = torch.chunk(C_cls, DEPTH, dim=0)[i]
    #     C_cls = torch.chunk(C_cls, DEPTH, dim=-1)[k]
    #     logger.info('Rank {} embed forward (cls): {}'.format(
    #         rank, check_equal(out_cls, C_cls)))
    # C = C_master[:, 1:]
    print_rank_0('Rank {} embed forward: {}'.format(i, check_equal(out, C_master)))

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape,
                              dtype=dtype,
                              device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    # cls_grad = grad_master[:, 0]
    # cls_grad = torch.chunk(cls_grad, DEPTH, dim=0)[i]
    # cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[k]
    # grad = grad_master[:, 1:]
    # grad = torch.cat((torch.unsqueeze(cls_grad, 1), grad), dim=1)

    bwd_start = time.time()
    out.backward(grad_master)
    bwd_end = time.time()
    print_rank_0(
        'embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start))

    C_master.backward(grad_master)
    A_grad = A_master.grad
    print_rank_0('Rank {} embed backward (input_grad): {}'.format(i, check_equal(A_grad, A.grad)))
    print_rank_0('Rank {} embed backward (cls_grad): {}'.format(
        i, check_equal(layer_master.cls_token.grad, layer2.cls_token.grad)))
    print_rank_0('Rank {} embed backward (pos_embed_grad): {}'.format(
        i, check_equal(layer_master.pos_embed.grad, layer2.pos_embed.grad)))
    print_rank_0('Rank {} embed backward (proj_weight_grad): {}'.format(
        i, check_equal(layer_master.proj.weight.grad, layer.proj.weight.grad)))
    print_rank_0('Rank {} embed backward (proj_bias_grad): {}'.format(
        i, check_equal(layer_master.proj.bias.grad, layer.proj.bias.grad)))

    return fwd_end - fwd_start, bwd_end - bwd_start
def check_attention():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    NUM_ATTENTION_HEADS = 2

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = ViTSelfAttention1D(
        HIDDEN_SIZE,
        NUM_ATTENTION_HEADS,
        0.5,
        0.5
    ).to(device=device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    mask_shape = (BATCH_SIZE, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
    attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)

    out = layer(A)
    assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    print_rank_0('self attention forward: pass')

    grad_shape = out.shape
    grad = torch.randn(grad_shape, dtype=dtype, device=device)

    out.backward(grad)
    assert A.grad.shape == A.shape
    print_rank_0('self attention backward: pass')


def check_mlp():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = ViTMLP1D(
        HIDDEN_SIZE,
        4.0
    ).to(device=device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    out = layer(A)
    assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    print_rank_0('mlp forward: pass')

    grad_shape = out.shape
    grad = torch.randn(grad_shape, dtype=dtype, device=device)

    out.backward(grad)
    assert A.grad.shape == A.shape
    print_rank_0('mlp backward: pass')


def check_patch_embedding():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = 4
    PATCH_SIZE = 2

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = ViTPatchEmbedding1D(
        INPUT_SIZE,
        PATCH_SIZE,
        HIDDEN_SIZE,
    ).to(device=device)

    A_shape = (BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    out = layer(A)
    print('output size: ', out.size())
    assert out.shape == (BATCH_SIZE, 4, HIDDEN_SIZE)
    print_rank_0('patch embedding forward: pass')

    grad_shape = out.shape
    grad = torch.randn(grad_shape, dtype=dtype, device=device)

    out.backward(grad)
    assert A.grad.shape == A.shape
    print_rank_0('patch embedding backward: pass')
@@ -4,7 +4,7 @@

import pytest

from colossalai.core import global_context as gpc
-from colossalai.initialize import init_dist
+from colossalai.initialize import launch, get_default_parser
from test_layer import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from test_operation import check_AB, check_ABT, check_ATB

@@ -36,8 +36,14 @@ def check_layer():

@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2d():
-    init_dist(config=CONFIG)
-    gpc.set_seed()
+    parser = get_default_parser()
+    args = parser.parse_args()
+    launch(config=CONFIG,
+           rank=args.rank,
+           world_size=args.world_size,
+           host=args.host,
+           port=args.port,
+           backend=args.backend)
    check_operations()
    check_layer()
    gpc.destroy()
......
import pytest

from colossalai.core import global_context as gpc
-from colossalai.initialize import init_dist
+from colossalai.initialize import launch, get_default_parser
from test_layer import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from test_operation import check_AB, check_ABT, check_ATB

@@ -30,8 +30,14 @@ def check_layer():

@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2p5d():
-    init_dist(config=CONFIG)
-    gpc.set_seed()
+    parser = get_default_parser()
+    args = parser.parse_args()
+    launch(config=CONFIG,
+           rank=args.rank,
+           world_size=args.world_size,
+           host=args.host,
+           port=args.port,
+           backend=args.backend)
    check_layer()
    check_operations()
    gpc.destroy()
......
@@ -16,7 +16,7 @@ def check_AB():
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
    dtype = torch.float

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)

@@ -41,11 +41,10 @@ def check_AB():
    out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM)
    out = Matmul_AB_2p5D.apply(
        A, B,
-        TESSERACT_DIM, TESSERACT_DEP, out_shape,
+        TESSERACT_DIM, out_shape,
        i, j, k,
        ParallelMode.PARALLEL_2P5D_ROW,
        ParallelMode.PARALLEL_2P5D_COL,
-        ParallelMode.PARALLEL_2P5D_DEP,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,

@@ -93,7 +92,7 @@ def check_ABT():
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
    dtype = torch.float
    device = get_current_device()

@@ -119,11 +118,10 @@ def check_ABT():
    out = Matmul_ABT_2p5D.apply(
        C, B,
-        TESSERACT_DIM, TESSERACT_DEP, (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM),
+        TESSERACT_DIM, (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM),
        i, j, k,
        ParallelMode.PARALLEL_2P5D_ROW,
        ParallelMode.PARALLEL_2P5D_COL,
-        ParallelMode.PARALLEL_2P5D_DEP,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,

@@ -169,7 +167,7 @@ def check_ATB():
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
    device = get_current_device()
    dtype = torch.float

@@ -195,11 +193,10 @@ def check_ATB():
    out = Matmul_ATB_2p5D.apply(
        A, C,
-        TESSERACT_DIM, TESSERACT_DEP, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
+        TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
        i, j, k,
        ParallelMode.PARALLEL_2P5D_ROW,
        ParallelMode.PARALLEL_2P5D_COL,
-        ParallelMode.PARALLEL_2P5D_DEP,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,
......
@@ -7,9 +7,9 @@ DEPTH = 2
BATCH_SIZE = 512
SEQ_LENGTH = 128
HIDDEN_SIZE = 512
-NUM_CLASSES = 10
+NUM_CLASSES = 1000
NUM_BLOCKS = 6
-IMG_SIZE = 32
+IMG_SIZE = 224


def check_equal(A, B):
-    return torch.allclose(A, B, rtol=1e-5, atol=1e-2)
+    return torch.allclose(A, B, rtol=1e-4, atol=1e-2)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

-from colossalai.initialize import init_dist
+from colossalai.initialize import launch, get_default_parser
from test_layer import *
from test_operation import *
+from colossalai.logging import get_dist_logger

CONFIG = dict(parallel=dict(pipeline=1, tensor=dict(mode='3d', size=8)),
              seed=0)

-def check_operations():
-    check_AB()
-    check_ABT()
-    check_ATB()
-    check_add()
-    check_mul()
-    check_sum()
-    # check_pooler()
+# def check_operations():
+#     check_AB()
+#     check_ABT()
+#     check_ATB()
+#     check_add()
+#     check_mul()
+#     check_sum()

def check_layer():
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    liear_fwd_time, linear_bwd_time = check_linear()
    norm_fwd_time, norm_bwd_time = check_layernorm()
    attn_fwd_time, attn_bwd_time = check_attention()

@@ -40,15 +40,20 @@ def check_layer():

def _test_main():
    # init dist
-    init_dist(CONFIG)
-    logger = get_global_dist_logger()
+    parser = get_default_parser()
+    args = parser.parse_args()
+    launch(config=CONFIG,
+           rank=args.rank,
+           world_size=args.world_size,
+           host=args.host,
+           port=args.port,
+           backend=args.backend)
+    logger = get_dist_logger()
+    logger.info('Distributed environment is initialized.', ranks=[0])

    global_context.set_seed()
    torch.backends.cudnn.benchmark = True

    # check operation
-    check_operations()
+    # check_operations()

    # check layers
    check_layer()
......
import time
import torch
import torch.distributed as dist
from colossalai.communication import all_gather, reduce_scatter, all_reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import init_dist, parse_args
from colossalai.utils import get_current_device, print_rank_0
# ARGS = parse_args()
# size = ARGS.world_size
# rank = ARGS.rank
from colossalai.initialize import parse_args
from colossalai.utils import get_current_device
# init_method = f'tcp://{ARGS.host}:{ARGS.port}'
# dist.init_process_group(backend='nccl', rank=rank, world_size=size, init_method=init_method)
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
init_dist(CONFIG)
ARGS = parse_args()
size = ARGS.world_size
rank = ARGS.local_rank
assert dist.get_rank() == gpc.get_global_rank()
init_method = f'tcp://{ARGS.host}:{ARGS.port}'
dist.init_process_group(backend='nccl', rank=rank, world_size=size, init_method=init_method)
print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size()))
SIZE = 8
tensor = torch.randn(SIZE)
tensor = tensor.to(get_current_device())
dist.all_reduce(tensor)
print('Rank {0}: {1}'.format(rank, tensor.detach().cpu().numpy().tolist()))
print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor))
time.sleep(1)
# tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True)
# tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True)
tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
print_rank_0('After: Rank {0} - {1}'.format(dist.get_rank(), tensor))
op.wait()
print_rank_0('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor))
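The tail of this script exercises the asynchronous collective path: with async_op=True the call returns immediately together with a work handle, and the buffer only reliably holds the reduced value after wait(). Distilled (using the all_reduce wrapper and helpers imported above):

tensor = torch.randn(SIZE).to(get_current_device())
tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
# independent computation can overlap with the communication here
op.wait()
print('reduced value:', tensor)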
@@ -7,39 +7,55 @@ import time

import numpy as np

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.registry import LAYERS, LOSSES
from colossalai.utils import get_current_device, print_rank_0
+from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D

from common import *
def check_linear():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = 2 * HIDDEN_SIZE

-    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
-    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
-    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+    output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+    j = A_rank = global_context.get_local_rank(input_parallel_mode)
+    i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+    k = C_rank = global_context.get_local_rank(output_parallel_mode)

    layer = LAYERS.get_module('Linear3D')(INPUT_SIZE,
                                          OUTPUT_SIZE,
-                                          ParallelMode.PARALLEL_3D_INPUT,
-                                          ParallelMode.PARALLEL_3D_WEIGHT,
+                                          # ParallelMode.PARALLEL_3D_INPUT,
+                                          # ParallelMode.PARALLEL_3D_WEIGHT,
                                          dtype=dtype,
                                          bias=True)
-    torch.nn.init.zeros_(layer.bias)
-    torch.nn.init.ones_(layer.weight)
+    # torch.nn.init.zeros_(layer.bias)
+    # torch.nn.init.ones_(layer.weight)
    layer = layer.to(device)

    layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
-    torch.nn.init.zeros_(layer_master.bias)
-    torch.nn.init.ones_(layer_master.weight)
+    # torch.nn.init.zeros_(layer_master.bias)
+    # torch.nn.init.ones_(layer_master.weight)
    layer_master = layer_master.to(device)

+    weight_master = layer_master.weight.data.transpose(0, 1)
+    torch.distributed.broadcast(weight_master, src=0)
+    weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
+    weight = torch.chunk(weight, DEPTH, dim=-1)[j]
+    layer.weight = torch.nn.Parameter(weight)
+    bias_master = layer_master.bias.data
+    torch.distributed.broadcast(bias_master, src=0)
+    bias = torch.chunk(bias_master, DEPTH)[j]
+    layer.bias = torch.nn.Parameter(bias)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
@@ -89,45 +105,52 @@ def check_linear():

    B_grad = layer_master.weight.grad.transpose(0, 1)
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
-    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+    # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
    logger.info('Rank {} linear backward (weight_grad): {}'.format(
        rank, check_equal(B_grad, layer.weight.grad)))

-    if j == k:
-        bias_grad = layer_master.bias.grad
-        bias_grad = torch.chunk(bias_grad, DEPTH)[j]
-        bias_grad = torch.chunk(bias_grad, DEPTH)[i]
-        logger.info('Rank {} linear backward (bias_grad): {}'.format(
-            rank, check_equal(bias_grad, layer.bias.grad)))
-    else:
-        logger.info('Rank {} linear backward (bias_grad): {}'.format(
-            rank,
-            # np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0))
-            layer.bias.grad is None))
+    bias_grad = layer_master.bias.grad
+    bias_grad = torch.chunk(bias_grad, DEPTH)[j]
+    logger.info('Rank {} linear backward (bias_grad): {}'.format(
+        rank, check_equal(bias_grad, layer.bias.grad)))
+    # logger.info(f'\nRank {rank} Master:\n{layer_master.bias.grad}\nRank {rank} True:\n{bias_grad}\nRank {rank} Out:\n{layer.bias.grad}')

    return fwd_end - fwd_start, bwd_end - bwd_start
def check_layernorm():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE

-    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
-    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
-    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+    output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+    j = A_rank = global_context.get_local_rank(input_parallel_mode)
+    i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+    k = C_rank = global_context.get_local_rank(output_parallel_mode)

    norm = LAYERS.get_module('LayerNorm3D')(INPUT_SIZE,
-                                            ParallelMode.PARALLEL_3D_INPUT,
-                                            ParallelMode.PARALLEL_3D_WEIGHT,
+                                            # ParallelMode.PARALLEL_3D_INPUT,
+                                            # ParallelMode.PARALLEL_3D_WEIGHT,
                                            eps=1e-6,
                                            dtype=dtype)
    norm = norm.to(device)

    norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
    norm_master = norm_master.to(device)

+    weight_master = norm_master.weight.data
+    torch.distributed.broadcast(weight_master, src=0)
+    weight = torch.chunk(weight_master, DEPTH)[k]
+    norm.weight = torch.nn.Parameter(weight)
+    bias_master = norm_master.bias.data
+    torch.distributed.broadcast(bias_master, src=0)
+    bias = torch.chunk(bias_master, DEPTH)[k]
+    norm.bias = torch.nn.Parameter(bias)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
@@ -181,29 +204,15 @@ def check_layernorm():

    logger.info('Rank {} layernorm backward (input_grad): {}'.format(
        rank, check_equal(A_grad, A.grad)))

-    if j == k:
-        bias_grad = norm_master.weight.grad
-        bias_grad = torch.chunk(bias_grad, DEPTH)[j]
-        bias_grad = torch.chunk(bias_grad, DEPTH)[i]
-        logger.info('Rank {} linear backward (weight_grad): {}'.format(
-            rank, check_equal(bias_grad, norm.weight.grad)))
-    else:
-        logger.info('Rank {} linear backward (weight_grad): {}'.format(
-            rank,
-            # np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0))
-            norm.weight.grad is None))
-    if j == k:
-        bias_grad = norm_master.bias.grad
-        bias_grad = torch.chunk(bias_grad, DEPTH)[j]
-        bias_grad = torch.chunk(bias_grad, DEPTH)[i]
-        logger.info('Rank {} linear backward (bias_grad): {}'.format(
-            rank, check_equal(bias_grad, norm.bias.grad)))
-    else:
-        logger.info('Rank {} linear backward (bias_grad): {}'.format(
-            rank,
-            # np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0))
-            norm.bias.grad is None))
+    bias_grad = norm_master.weight.grad
+    bias_grad = torch.chunk(bias_grad, DEPTH)[k]
+    logger.info('Rank {} layernorm backward (weight_grad): {}'.format(
+        rank, check_equal(bias_grad, norm.weight.grad)))
+
+    bias_grad = norm_master.bias.grad
+    bias_grad = torch.chunk(bias_grad, DEPTH)[k]
+    logger.info('Rank {} layernorm backward (bias_grad): {}'.format(
+        rank, check_equal(bias_grad, norm.bias.grad)))

    return fwd_end - fwd_start, bwd_end - bwd_start
@@ -211,14 +220,18 @@ def check_layernorm():

def check_attention():
    rank = torch.distributed.get_rank()
    device = get_current_device()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    NUM_ATTENTION_HEADS = 2

-    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
-    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
-    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+    output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+    j = A_rank = global_context.get_local_rank(input_parallel_mode)
+    i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+    k = C_rank = global_context.get_local_rank(output_parallel_mode)

    layer = LAYERS.get_module('ViTSelfAttention3D')(HIDDEN_SIZE,
                                                    NUM_ATTENTION_HEADS,
@@ -264,13 +277,17 @@ def check_attention():

def check_mlp():
    rank = torch.distributed.get_rank()
    device = get_current_device()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE

-    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
-    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
-    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+    output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+    j = A_rank = global_context.get_local_rank(input_parallel_mode)
+    i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+    k = C_rank = global_context.get_local_rank(output_parallel_mode)

    layer = LAYERS.get_module('ViTMLP3D')(HIDDEN_SIZE,
                                          1,
@@ -320,28 +337,42 @@ class Testvithead(torch.nn.Module):

def check_head():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE

-    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
-    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
-    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+    output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+    j = A_rank = global_context.get_local_rank(input_parallel_mode)
+    i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+    k = C_rank = global_context.get_local_rank(output_parallel_mode)

    head = LAYERS.get_module('ViTHead3D')(INPUT_SIZE,
                                          NUM_CLASSES,
                                          dtype=dtype,
                                          bias=True)
-    torch.nn.init.zeros_(head.linear.bias)
-    torch.nn.init.ones_(head.linear.weight)
+    # torch.nn.init.zeros_(head.linear.bias)
+    # torch.nn.init.ones_(head.linear.weight)
    head = head.to(device)

    layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True)
-    torch.nn.init.zeros_(layer.linear.bias)
-    torch.nn.init.ones_(layer.linear.weight)
+    # torch.nn.init.zeros_(layer.linear.bias)
+    # torch.nn.init.ones_(layer.linear.weight)
    layer = layer.to(device)

+    weight_master = layer.linear.weight.data.transpose(0, 1)
+    torch.distributed.broadcast(weight_master, src=0)
+    weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
+    weight = torch.chunk(weight, DEPTH, dim=-1)[j]
+    head.linear.weight = torch.nn.Parameter(weight)
+    bias_master = layer.linear.bias.data
+    torch.distributed.broadcast(bias_master, src=0)
+    bias = torch.chunk(bias_master, DEPTH)[j]
+    head.linear.bias = torch.nn.Parameter(bias)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
@@ -397,31 +428,43 @@ def check_head():

    B_grad = layer.linear.weight.grad.transpose(0, 1)
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
-    pad_shape = (B_grad.shape[0], math.ceil(B_grad.shape[-1] / DEPTH) * DEPTH -
-                 B_grad.shape[-1])
-    B_grad = torch.cat(
-        [B_grad, torch.zeros(pad_shape, dtype=dtype, device=device)], dim=-1)
-    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+    # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
    logger.info('Rank {} head backward (weight_grad): {}'.format(
        rank, check_equal(B_grad, head.linear.weight.grad)))

-    if j == k:
-        bias_grad = layer.linear.bias.grad
-        bias_grad = torch.chunk(bias_grad, DEPTH)[j]
-        pad_shape = (math.ceil(bias_grad.shape[0] / DEPTH) * DEPTH -
-                     bias_grad.shape[0], )
-        bias_grad = torch.cat(
-            [bias_grad,
-             torch.zeros(pad_shape, dtype=dtype, device=device)])
-        bias_grad = torch.chunk(bias_grad, DEPTH)[i]
-        logger.info('Rank {} head backward (bias_grad): {}'.format(
-            rank, check_equal(bias_grad, head.linear.bias.grad)))
-    else:
-        logger.info('Rank {} head backward (bias_grad): {}'.format(
-            rank,
-            # np.count_nonzero(
-            #     head.linear.bias.grad.detach().cpu().numpy()) == 0))
-            head.linear.bias.grad is None))
+    bias_grad = layer.linear.bias.grad
+    bias_grad = torch.chunk(bias_grad, DEPTH)[j]
+    logger.info('Rank {} head backward (bias_grad): {}'.format(
+        rank, check_equal(bias_grad, head.linear.bias.grad)))
+
+    # B_grad = layer.linear.weight.grad.transpose(0, 1)
+    # B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
+    # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
+    # pad_shape = (B_grad.shape[0], math.ceil(B_grad.shape[-1] / DEPTH) * DEPTH -
+    #              B_grad.shape[-1])
+    # B_grad = torch.cat(
+    #     [B_grad, torch.zeros(pad_shape, dtype=dtype, device=device)], dim=-1)
+    # B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+    # logger.info('Rank {} head backward (weight_grad): {}'.format(
+    #     rank, check_equal(B_grad, head.linear.weight.grad)))
+    # if j == k:
+    #     bias_grad = layer.linear.bias.grad
+    #     bias_grad = torch.chunk(bias_grad, DEPTH)[j]
+    #     pad_shape = (math.ceil(bias_grad.shape[0] / DEPTH) * DEPTH -
+    #                  bias_grad.shape[0], )
+    #     bias_grad = torch.cat(
+    #         [bias_grad,
+    #          torch.zeros(pad_shape, dtype=dtype, device=device)])
+    #     bias_grad = torch.chunk(bias_grad, DEPTH)[i]
+    #     logger.info('Rank {} head backward (bias_grad): {}'.format(
+    #         rank, check_equal(bias_grad, head.linear.bias.grad)))
+    # else:
+    #     logger.info('Rank {} head backward (bias_grad): {}'.format(
+    #         rank,
+    #         # np.count_nonzero(
+    #         #     head.linear.bias.grad.detach().cpu().numpy()) == 0))
+    #         head.linear.bias.grad is None))

    return fwd_end - fwd_start, bwd_end - bwd_start
@@ -452,12 +495,16 @@ class Testvitembed(torch.nn.Module):

def check_embed():
    rank = torch.distributed.get_rank()
    device = get_current_device()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    dtype = torch.float32

-    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
-    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
-    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+    output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+    j = A_rank = global_context.get_local_rank(input_parallel_mode)
+    i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+    k = C_rank = global_context.get_local_rank(output_parallel_mode)

    layer = LAYERS.get_module('ViTPatchEmbedding3D')(IMG_SIZE, 4, 3,
                                                     HIDDEN_SIZE, 0.)
@@ -585,16 +632,20 @@ def check_embed():

def check_loss():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    device = get_current_device()
    dtype = torch.float32

-    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
-    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
-    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
+    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+    output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
+    j = A_rank = global_context.get_local_rank(input_parallel_mode)
+    i = B_rank = global_context.get_local_rank(weight_parallel_mode)
+    k = C_rank = global_context.get_local_rank(output_parallel_mode)

-    criterion = LOSSES.get_module('CrossEntropyLoss3D')(
-        ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT)
+    criterion = LOSSES.get_module('CrossEntropyLoss3D')()
+    # ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT)
    criterion_master = torch.nn.CrossEntropyLoss()

    out_shape = (BATCH_SIZE, NUM_CLASSES)
......
@@ -3,7 +3,7 @@

from colossalai.context import ParallelMode
from colossalai.core import global_context
-from colossalai.logging import get_global_dist_logger
+from colossalai.logging import get_dist_logger
from colossalai.nn.layer.parallel_3d._operation import *
from colossalai.utils import get_current_device

@@ -12,7 +12,7 @@ from common import *

def check_AB():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    dtype = torch.float

    j = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)

@@ -83,7 +83,7 @@ def check_AB():

def check_ABT():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)

@@ -152,7 +152,7 @@ def check_ABT():

def check_ATB():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    device = get_current_device()
    dtype = torch.float

@@ -222,7 +222,7 @@ def check_ATB():

def check_add():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)

@@ -296,7 +296,7 @@ def check_add():

def check_mul():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)

@@ -370,7 +370,7 @@ def check_mul():

def check_sum():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)

@@ -417,7 +417,7 @@ def check_sum():

def check_reduce():
    rank = torch.distributed.get_rank()
-    logger = get_global_dist_logger()
+    logger = get_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

-from colossalai.initialize import init_dist
-from colossalai.logging import get_global_dist_logger
+from colossalai.initialize import launch, get_default_parser
+from colossalai.logging import get_dist_logger
from test_layer import *

CONFIG = dict(

@@ -19,11 +19,17 @@ def check_layer():

def _test_main():
    # init dist
-    init_dist(CONFIG)
-    logger = get_global_dist_logger()
+    parser = get_default_parser()
+    args = parser.parse_args()
+    launch(config=CONFIG,
+           rank=args.rank,
+           world_size=args.world_size,
+           host=args.host,
+           port=args.port,
+           backend=args.backend)
+    logger = get_dist_logger()
+    logger.info('Distributed environment is initialized.', ranks=[0])

    gpc.set_seed()
    torch.backends.cudnn.benchmark = True

    # check layers
......
# from colossal.components.optimizer.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmupLR, FlatAnnealingLR, FlatAnnealingWarmupLR
# from colossal.components.optimizer.lr_scheduler import LinearWarmupLR
# from colossal.components.optimizer.lr_scheduler import MultiStepLR, MultiStepWarmupLR
# from colossal.components.optimizer.lr_scheduler import OneCycleLR
# from colossal.components.optimizer.lr_scheduler import PolynomialLR, PolynomialWarmupLR
import matplotlib.pyplot as plt
import pytest
from torch.optim import SGD
from torchvision.models import resnet18

from colossalai.builder import build_lr_scheduler

NUM_EPOCHS = 5
NUM_STEPS_PER_EPOCH = 10

cfg = {
    'warmup_steps': 5
}


def init_cfg(name, **kwargs):
    # merge the shared base cfg with per-scheduler kwargs into a registry-style config
    return {
        'type': name,
        **cfg,
        **kwargs
    }
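For instance, with the base cfg above, the helper simply expands to a registry-style config dict:

print(init_cfg('MultiStepLR', milestones=[1, 3]))
# {'type': 'MultiStepLR', 'warmup_steps': 5, 'milestones': [1, 3]}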
def run_scheduler(optimizer, scheduler_name, **kwargs):
    # reset the learning rate so every scheduler starts from the same point
    for group in optimizer.param_groups:
        group['lr'] = 0.1
    config = init_cfg(scheduler_name, **kwargs)
    scheduler = build_lr_scheduler(config,
                                   optimizer, NUM_EPOCHS * NUM_STEPS_PER_EPOCH, NUM_STEPS_PER_EPOCH)

    x = []
    y = []
    for epoch in range(NUM_EPOCHS):
        for i in range(NUM_STEPS_PER_EPOCH):
            step = epoch * NUM_STEPS_PER_EPOCH + i
            lr = optimizer.param_groups[0]['lr']
            x.append(step)
            y.append(lr)
            scheduler.step()
    print(y)
    plt.plot(x, y)
    plt.show()


@pytest.mark.skip("This test is skipped as it requires visualization, "
                  "you can visualize the test output plots on your local environment")
def test():
    model = resnet18()
    optimizer = SGD(model.parameters(), lr=1.0)
    run_scheduler(optimizer, 'CosineAnnealingLR')
    run_scheduler(optimizer, 'CosineAnnealingWarmupLR')
    run_scheduler(optimizer, 'FlatAnnealingLR')
    run_scheduler(optimizer, 'FlatAnnealingWarmupLR')
    run_scheduler(optimizer, 'LinearWarmupLR')
    run_scheduler(optimizer, 'MultiStepLR', milestones=[1, 3])
    run_scheduler(optimizer, 'MultiStepWarmupLR', milestones=[1, 3])
    run_scheduler(optimizer, 'MultiStepWarmupLR',
                  milestones=[1, 3], warmup_epochs=1)
    run_scheduler(optimizer, 'PolynomialLR', power=2.0)
    run_scheduler(optimizer, 'PolynomialWarmupLR', power=2.0)
    run_scheduler(optimizer, 'OneCycleLR')


if __name__ == '__main__':
    test()