Unverified commit 0fedef4f, authored by アマデウス, committed by GitHub
Browse files

Layer integration (#83)



* integrated parallel layers for ease of building models

* integrated 2.5d layers

* cleaned up code and unit tests

* added a log-metric-by-step hook; updated the ImageNet benchmark; fixed some bugs

* reworked initialization; cleaned up code
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent 5c3843dc
......@@ -6,7 +6,7 @@ import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch, get_default_parser
from colossalai.initialize import launch
from functools import partial
from checks_1d.check_layer_1d import *
......@@ -14,7 +14,7 @@ CONFIG = dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(
size=2,
size=4,
mode='1d'
)
),
......@@ -31,11 +31,6 @@ def check_layer(rank, world_size):
check_linear_col()
check_linear_row()
check_attention()
check_mlp()
check_patch_embedding()
check_embed()
check_head()
gpc.destroy()
torch.cuda.empty_cache()
......@@ -43,7 +38,7 @@ def check_layer(rank, world_size):
@pytest.mark.dist
def test_1d():
world_size = 2
world_size = 4
run_func = partial(check_layer, world_size=world_size)
mp.spawn(run_func, nprocs=world_size)
......
......@@ -3,16 +3,16 @@ from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import Linear2D, LayerNorm2D, TransformerSelfAttention2D, TransformerMLP2D, TransformerLayer2D
from colossalai.nn import Linear2D, LayerNorm2D, Classifier2D
from colossalai.utils import get_current_device, print_rank_0
from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal
from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal, NUM_CLASSES
def check_linear():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
OUTPUT_SIZE = HIDDEN_SIZE
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
......@@ -38,12 +38,13 @@ def check_linear():
B_shape = (OUTPUT_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[j]
B = torch.chunk(B_master, DEPTH, dim=-1)[j]
B = torch.chunk(B, DEPTH, dim=-1)[i]
B = B.clone()
B.requires_grad = True
layer.weight = Parameter(W)
layer.bias = Parameter(B)
layer.weight.data.copy_(W)
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
......@@ -56,6 +57,7 @@ def check_linear():
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
# print(f'Rank {gpc.get_global_rank()} A:\n{A}\nRank {gpc.get_global_rank()} W:\n{W}\nRank {gpc.get_global_rank()} b:\n{B}\nRank {gpc.get_global_rank()} C:\n{C}\nRank {gpc.get_global_rank()} out:\n{out}')
check_equal(out, C)
print_rank_0('linear forward: pass')
......@@ -64,8 +66,10 @@ def check_linear():
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
......@@ -78,116 +82,102 @@ def check_linear():
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
if i == 0:
check_equal(B_grad, layer.bias.grad)
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear backward: pass')
def check_layernorm():
def check_classifier():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
OUTPUT_SIZE = NUM_CLASSES
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layernorm = LayerNorm2D(INPUT_SIZE)
layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layernorm(A)
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[j]
W = torch.chunk(W, DEPTH, dim=-1)[i]
W = W.clone()
layer.weight.data.copy_(W)
# W.requires_grad = True
B_shape = (OUTPUT_SIZE,)
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, DEPTH, dim=0)[j]
B = B_master.clone()
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
E_master = torch.sum(A_master, dim=-1, keepdim=True)
E_master /= INPUT_SIZE
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
V_master /= INPUT_SIZE
V_master = V_master - E_master * E_master
V_master = 1.0 / torch.sqrt(V_master + EPS)
C_master = (A_master - E_master) * V_master
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
# C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('layer norm forward: pass')
print_rank_0('classifier forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
# grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
def check_attention():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = TransformerSelfAttention2D(
HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
print_rank_0('self attention forward: pass')
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
B_grad = B_master.grad
# B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('self attention backward: pass')
print_rank_0('classifier backward: pass')
def check_mlp():
def check_layernorm():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = TransformerMLP2D(
HIDDEN_SIZE,
dropout_prob=0.5,
act_func='gelu',
)
layernorm = LayerNorm2D(INPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
......@@ -197,52 +187,144 @@ def check_mlp():
A = A.clone()
A.requires_grad = True
out = layer(A)
assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
print_rank_0('mlp forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('mlp backward: pass')
def check_transformerlayer():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
out = layernorm(A)
layer = TransformerLayer2D(
HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
act_func='gelu',
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5)
A_master = A_master.clone()
A_master.requires_grad = True
E_master = torch.sum(A_master, dim=-1, keepdim=True)
E_master /= INPUT_SIZE
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
V_master /= INPUT_SIZE
V_master = V_master - E_master * E_master
V_master = 1.0 / torch.sqrt(V_master + EPS)
C_master = (A_master - E_master) * V_master
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
check_equal(out, C)
print_rank_0('layer norm forward: pass')
mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
out.backward(grad)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
print_rank_0('transformerlayer forward: pass')
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('transformerlayer backward: pass')
# def check_attention():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
# layer = TransformerSelfAttention2D(
# HIDDEN_SIZE,
# NUM_ATTENTION_HEADS,
# attention_dropout_prob=0.5,
# hidden_dropout_prob=0.5,
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, DEPTH, dim=0)[i]
# A = torch.chunk(A, DEPTH, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
# print_rank_0('self attention forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('self attention backward: pass')
# def check_mlp():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
# layer = TransformerMLP2D(
# HIDDEN_SIZE,
# dropout_prob=0.5,
# act_func='gelu',
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, DEPTH, dim=0)[i]
# A = torch.chunk(A, DEPTH, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# out = layer(A)
# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
# print_rank_0('mlp forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('mlp backward: pass')
# def check_transformerlayer():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
# layer = TransformerLayer2D(HIDDEN_SIZE,
# NUM_ATTENTION_HEADS,
# act_func='gelu',
# attention_dropout_prob=0.5,
# hidden_dropout_prob=0.5)
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, DEPTH, dim=0)[i]
# A = torch.chunk(A, DEPTH, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
# print_rank_0('transformerlayer forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('transformerlayer backward: pass')
......@@ -5,7 +5,7 @@ import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
from colossalai.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from .common import check_equal, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, DEPTH
......
......@@ -7,7 +7,7 @@ DEPTH = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) == True
......@@ -6,9 +6,9 @@ import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch, get_default_parser
from checks_2d.check_layer_2d import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
from colossalai.initialize import launch
from checks_2d.check_layer_2d import *
from checks_2d.check_operation_2d import *
from functools import partial
......@@ -32,10 +32,7 @@ def check_operations():
def check_layer():
check_linear()
check_layernorm()
check_attention()
check_mlp()
check_transformerlayer()
check_classifier()
def check_layer_and_operation(rank, world_size):
launch(config=CONFIG,
......@@ -45,7 +42,7 @@ def check_layer_and_operation(rank, world_size):
port=29921,
backend='nccl')
check_operations()
# check_operations()
check_layer()
gpc.destroy()
torch.cuda.empty_cache()
......
import torch
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (Linear2p5D, LayerNorm2p5D, TransformerSelfAttention2p5D, TransformerMLP2p5D,
TransformerLayer2p5D)
from colossalai.nn import Linear2p5D, LayerNorm2p5D, Classifier2p5D
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from .common import *
......@@ -71,8 +71,10 @@ def check_linear():
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
......@@ -92,116 +94,99 @@ def check_linear():
print_rank_0('linear backward: pass')
def check_layernorm():
def check_classifier():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
OUTPUT_SIZE = NUM_CLASSES
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
layernorm = LayerNorm2p5D(
INPUT_SIZE,
dtype=dtype)
layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layernorm(A)
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
# W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i]
W = W.clone()
layer.weight.data.copy_(W)
# W.requires_grad = True
B_shape = (OUTPUT_SIZE,)
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
B = B_master.clone()
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
E_master = torch.sum(A_master, dim=-1, keepdim=True)
E_master /= INPUT_SIZE
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
V_master /= INPUT_SIZE
V_master = V_master - E_master * E_master
V_master = 1.0 / torch.sqrt(V_master + EPS)
C_master = (A_master - E_master) * V_master
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
# C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('layer norm forward: pass')
print_rank_0('classifier forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
# grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
def check_attention():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = TransformerSelfAttention2p5D(
HIDDEN_SIZE, NUM_ATTENTION_HEADS,
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5,
dtype=dtype,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
print_rank_0('self attention forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('self attention backward: pass')
B_grad = B_master.grad
# B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier backward: pass')
def check_mlp():
def check_layernorm():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = TransformerMLP2p5D(
HIDDEN_SIZE,
mlp_ratio=1,
dropout_prob=0.5,
act_func='gelu',
dtype=dtype,
)
layernorm = LayerNorm2p5D(
INPUT_SIZE,
dtype=dtype)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
......@@ -211,55 +196,152 @@ def check_mlp():
A = A.clone()
A.requires_grad = True
out = layer(A)
assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
print_rank_0('mlp forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('mlp backward: pass')
def check_transformerlayer():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
out = layernorm(A)
layer = TransformerLayer2p5D(
HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
act_func='gelu',
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5,
dtype=dtype,
)
A_master = A_master.clone()
A_master.requires_grad = True
E_master = torch.sum(A_master, dim=-1, keepdim=True)
E_master /= INPUT_SIZE
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
V_master /= INPUT_SIZE
V_master = V_master - E_master * E_master
V_master = 1.0 / torch.sqrt(V_master + EPS)
C_master = (A_master - E_master) * V_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
check_equal(out, C)
print_rank_0('layer norm forward: pass')
mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
print_rank_0('transformerlayer forward: pass')
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('transformerlayer backward: pass')
# def check_attention():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
# layer = TransformerSelfAttention2p5D(
# HIDDEN_SIZE, NUM_ATTENTION_HEADS,
# attention_dropout_prob=0.5,
# hidden_dropout_prob=0.5,
# dtype=dtype,
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
# print_rank_0('self attention forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('self attention backward: pass')
# def check_mlp():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
# layer = TransformerMLP2p5D(
# HIDDEN_SIZE,
# mlp_ratio=1,
# dropout_prob=0.5,
# act_func='gelu',
# dtype=dtype,
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# out = layer(A)
# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
# print_rank_0('mlp forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('mlp backward: pass')
# def check_transformerlayer():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
# layer = TransformerLayer2p5D(
# HIDDEN_SIZE,
# NUM_ATTENTION_HEADS,
# act_func='gelu',
# attention_dropout_prob=0.5,
# hidden_dropout_prob=0.5,
# dtype=dtype,
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
# print_rank_0('transformerlayer forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('transformerlayer backward: pass')
\ No newline at end of file
......@@ -5,7 +5,8 @@ TESSERACT_DEP = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 3
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True
\ No newline at end of file
......@@ -4,7 +4,7 @@ import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from checks_2p5d.check_layer_2p5d import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from checks_2p5d.check_layer_2p5d import check_linear, check_layernorm, check_classifier
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
from functools import partial
......@@ -12,7 +12,7 @@ from functools import partial
CONFIG = dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=8, mode='2.5d', depth=2),
tensor=dict(size=4, mode='2.5d', depth=1),
),
)
......@@ -26,9 +26,7 @@ def check_operations():
def check_layer():
check_linear()
check_layernorm()
check_attention()
check_mlp()
check_transformerlayer()
check_classifier()
def check_layer_and_operation(rank, world_size):
......@@ -47,7 +45,7 @@ def check_layer_and_operation(rank, world_size):
@pytest.mark.dist
def test_2p5d():
world_size = 8
world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size)
mp.spawn(run_func, nprocs=world_size)
......
import time
import torch
import torch.distributed as dist
from colossalai.communication import all_gather, reduce_scatter, all_reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import init_dist, parse_args
from colossalai.utils import get_current_device, print_rank_0
# Manual smoke test for an asynchronous collective: every rank creates a random
# tensor, launches an async all_reduce over ParallelMode.GLOBAL, and prints the
# tensor before the call, right after the call, and after waiting on the handle.
#
# Alternative manual process-group setup, kept for reference only; the script
# currently initializes through init_dist(CONFIG) below.
# ARGS = parse_args()
# size = ARGS.world_size
# rank = ARGS.rank
# init_method = f'tcp://{ARGS.host}:{ARGS.port}'
# dist.init_process_group(backend='nccl', rank=rank, world_size=size, init_method=init_method)
# Parallel layout handed to init_dist: data=8, pipeline=1, tensor mode=None with size=1.
# NOTE(review): presumably this requires launching exactly 8 processes -- confirm against init_dist.
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
init_dist(CONFIG)
# Sanity check: the rank reported by torch.distributed must match the rank
# tracked by the global context.
assert dist.get_rank() == gpc.get_global_rank()
print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size()))
# Number of elements in the 1-D test tensor.
SIZE = 8
tensor = torch.randn(SIZE)
tensor = tensor.to(get_current_device())
# Printed by every rank (plain print, not print_rank_0) so each rank's input is visible.
print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor))
# NOTE(review): the sleep presumably lets the per-rank "Before" output flush
# before the collective starts -- confirm intent.
time.sleep(1)
# Other collectives that can be swapped in for the all_reduce call below.
# tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True)
# tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True)
# With async_op=True the call returns the tensor together with a handle `op`.
tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
# This print runs before op.wait(), so the values shown here may not yet be the
# reduced result; it is the "in flight" snapshot of the demo.
print_rank_0('After: Rank {0} - {1}'.format(dist.get_rank(), tensor))
# Wait on the handle before reading the final tensor.
# NOTE(review): `op` is presumably a torch.distributed work handle -- confirm.
op.wait()
print_rank_0('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor))
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import time
import numpy as np
from colossalai.context.parallel_mode import ParallelMode
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D)
from colossalai.core import global_context
from colossalai.logging import get_dist_logger
from colossalai.registry import LAYERS, LOSSES
from colossalai.utils import get_current_device, print_rank_0
from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VanillaClassifier,
VanillaPatchEmbedding)
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.utils import get_current_device, print_rank_0
from .common import *
import torch
def check_linear():
......@@ -32,29 +31,20 @@ def check_linear():
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
layer = LAYERS.get_module('Linear3D')(INPUT_SIZE,
OUTPUT_SIZE,
# ParallelMode.PARALLEL_3D_INPUT,
# ParallelMode.PARALLEL_3D_WEIGHT,
dtype=dtype,
bias=True)
# torch.nn.init.zeros_(layer.bias)
# torch.nn.init.ones_(layer.weight)
layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True)
layer = layer.to(device)
layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
# torch.nn.init.zeros_(layer_master.bias)
# torch.nn.init.ones_(layer_master.weight)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data.transpose(0, 1)
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
layer.weight = torch.nn.Parameter(weight)
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[j]
layer.bias = torch.nn.Parameter(bias)
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
......@@ -67,10 +57,10 @@ def check_linear():
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'linear forward: {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
'linear forward: {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
......@@ -80,9 +70,7 @@ def check_linear():
logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape,
dtype=dtype,
device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
......@@ -90,30 +78,25 @@ def check_linear():
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start),
logger)
print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start), logger)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} linear backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
logger.info('Rank {} linear backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
B_grad = layer_master.weight.grad.transpose(0, 1)
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
# B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
logger.info('Rank {} linear backward (weight_grad): {}'.format(
rank, check_equal(B_grad, layer.weight.grad)))
logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
logger.info('Rank {} linear backward (bias_grad): {}'.format(
rank, check_equal(bias_grad, layer.bias.grad)))
# logger.info(f'\nRank {rank} Master:\n{layer_master.bias.grad}\nRank {rank} True:\n{bias_grad}\nRank {rank} Out:\n{layer.bias.grad}')
logger.info('Rank {} linear backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
......@@ -133,11 +116,7 @@ def check_layernorm():
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
norm = LAYERS.get_module('LayerNorm3D')(INPUT_SIZE,
# ParallelMode.PARALLEL_3D_INPUT,
# ParallelMode.PARALLEL_3D_WEIGHT,
eps=1e-6,
dtype=dtype)
norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype)
norm = norm.to(device)
norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
norm_master = norm_master.to(device)
......@@ -145,11 +124,11 @@ def check_layernorm():
weight_master = norm_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH)[k]
norm.weight = torch.nn.Parameter(weight)
norm.weight.data.copy_(weight)
bias_master = norm_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[k]
norm.bias = torch.nn.Parameter(bias)
norm.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
......@@ -162,10 +141,11 @@ def check_layernorm():
fwd_start = time.time()
out = norm(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
......@@ -173,14 +153,7 @@ def check_layernorm():
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} layernorm forward: {}'.format(rank,
check_equal(out, C)))
# time.sleep(rank)
# logger.info('Rank {0} master:\n{1}\nRank {0} out:\n{2}\nRank {0} true:\n{3}\n'.
# format(rank,
# C_master.detach().cpu().numpy().tolist(),
# out.detach().cpu().numpy().tolist(),
# C.detach().cpu().numpy().tolist()))
logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
......@@ -191,93 +164,32 @@ def check_layernorm():
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0(
'layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
print_rank_0('layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} layernorm backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
logger.info('Rank {} layernorm backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
bias_grad = norm_master.weight.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} layernorm backward (weight_grad): {}'.format(
rank, check_equal(bias_grad, norm.weight.grad)))
logger.info('Rank {} layernorm backward (weight_grad): {}'.format(rank, check_equal(bias_grad, norm.weight.grad)))
bias_grad = norm_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} layernorm backward (bias_grad): {}'.format(
rank, check_equal(bias_grad, norm.bias.grad)))
logger.info('Rank {} layernorm backward (bias_grad): {}'.format(rank, check_equal(bias_grad, norm.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_attention():
def check_classifier():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
layer = LAYERS.get_module('ViTSelfAttention3D')(HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
0.,
0.1,
dtype=dtype,
bias=True)
layer = layer.to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
A = torch.chunk(A, DEPTH, dim=0)[j]
A = A.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH,
SEQ_LENGTH // DEPTH, SEQ_LENGTH // DEPTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
fwd_start = time.time()
out = layer(A)
fwd_end = time.time()
print_rank_0(
'self attention forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
bwd_start = time.time()
out.backward(grad)
bwd_end = time.time()
print_rank_0(
'self attention backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
return fwd_end - fwd_start, bwd_end - bwd_start
def check_mlp():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
......@@ -289,89 +201,19 @@ def check_mlp():
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
layer = LAYERS.get_module('ViTMLP3D')(HIDDEN_SIZE,
1,
0.1,
'gelu',
dtype=dtype,
bias=True)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
A = torch.chunk(A, DEPTH, dim=0)[j]
A = A.clone()
A.requires_grad = True
fwd_start = time.time()
out = layer(A)
fwd_end = time.time()
print_rank_0(
'mlp forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
bwd_start = time.time()
out.backward(grad)
bwd_end = time.time()
print_rank_0('mlp backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
return fwd_end - fwd_start, bwd_end - bwd_start
class Testvithead(torch.nn.Module):
    """Single-device reference head used to validate the parallel head.

    The module takes the slice at index 0 along dimension 1 of its input and
    feeds it through one ``torch.nn.Linear`` layer stored as ``self.linear``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: whether the linear layer carries a bias term.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        """Return ``self.linear`` applied to ``x[:, 0]``."""
        selected = x[:, 0]
        return self.linear(selected)
def check_head():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
head = LAYERS.get_module('ViTHead3D')(INPUT_SIZE,
NUM_CLASSES,
dtype=dtype,
bias=True)
# torch.nn.init.zeros_(head.linear.bias)
# torch.nn.init.ones_(head.linear.weight)
head = head.to(device)
layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True)
# torch.nn.init.zeros_(layer.linear.bias)
# torch.nn.init.ones_(layer.linear.weight)
layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True)
layer = layer.to(device)
weight_master = layer.linear.weight.data.transpose(0, 1)
layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True, dtype=dtype)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
head.linear.weight = torch.nn.Parameter(weight)
bias_master = layer.linear.bias.data
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[j]
head.linear.bias = torch.nn.Parameter(bias)
layer.bias.data.copy_(bias_master)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
......@@ -383,113 +225,52 @@ def check_head():
A.requires_grad = True
fwd_start = time.time()
out = head(A)
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'head forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
'head forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer(A_master)
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = torch.chunk(C, DEPTH, dim=0)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} head forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape,
dtype=dtype,
device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = torch.chunk(grad, DEPTH, dim=0)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
# if j == 0:
logger.info('Rank {} head backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
# else:
# logger.info('Rank {} head backward (input_grad): {}'.format(
# # rank, check_equal(A_grad, A.grad)))
# rank,
# A.grad is None))
B_grad = layer.linear.weight.grad.transpose(0, 1)
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
# B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
logger.info('Rank {} head backward (weight_grad): {}'.format(
rank, check_equal(B_grad, head.linear.weight.grad)))
logger.info('Rank {} head backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
bias_grad = layer.linear.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
logger.info('Rank {} head backward (bias_grad): {}'.format(
rank, check_equal(bias_grad, head.linear.bias.grad)))
# B_grad = layer.linear.weight.grad.transpose(0, 1)
# B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
# B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
# pad_shape = (B_grad.shape[0], math.ceil(B_grad.shape[-1] / DEPTH) * DEPTH -
# B_grad.shape[-1])
# B_grad = torch.cat(
# [B_grad, torch.zeros(pad_shape, dtype=dtype, device=device)], dim=-1)
# B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
# logger.info('Rank {} head backward (weight_grad): {}'.format(
# rank, check_equal(B_grad, head.linear.weight.grad)))
# if j == k:
# bias_grad = layer.linear.bias.grad
# bias_grad = torch.chunk(bias_grad, DEPTH)[j]
# pad_shape = (math.ceil(bias_grad.shape[0] / DEPTH) * DEPTH -
# bias_grad.shape[0], )
# bias_grad = torch.cat(
# [bias_grad,
# torch.zeros(pad_shape, dtype=dtype, device=device)])
# bias_grad = torch.chunk(bias_grad, DEPTH)[i]
# logger.info('Rank {} head backward (bias_grad): {}'.format(
# rank, check_equal(bias_grad, head.linear.bias.grad)))
# else:
# logger.info('Rank {} head backward (bias_grad): {}'.format(
# rank,
# # np.count_nonzero(
# # head.linear.bias.grad.detach().cpu().numpy()) == 0))
# head.linear.bias.grad is None))
return fwd_end - fwd_start, bwd_end - bwd_start
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k:
logger.info('Rank {} head backward (weight_grad): {}'.format(rank,
check_equal(B_grad, layer.weight.grad)))
else:
logger.info('Rank {} head backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
bias_grad = layer_master.bias.grad
logger.info('Rank {} head backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
class Testvitembed(torch.nn.Module):
    """Single-device reference patch embedding used to validate the parallel one.

    A strided ``Conv2d`` (kernel and stride both equal to ``patch_size``)
    projects the input, the result is flattened into a sequence, a learnable
    ``cls_token`` is prepended, ``pos_embed`` is added and dropout is applied.

    Args:
        img_size: side length used to derive the number of patches.
        patch_size: kernel size and stride of the projection convolution.
        in_chans: number of input channels of the convolution.
        embed_size: number of output channels / embedding width.
        drop_prob: probability passed to the final ``Dropout``.
    """

    def __init__(self, img_size: int, patch_size: int, in_chans: int,
                 embed_size: int, drop_prob: float) -> None:
        super().__init__()
        self.proj = torch.nn.Conv2d(in_chans, embed_size, kernel_size=patch_size, stride=patch_size)
        patches_per_side = img_size // patch_size
        num_patches = patches_per_side ** 2
        # Both learnable tensors start at zero; pos_embed has one extra slot
        # along dim 1 for the prepended cls_token.
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_size))
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + 1, embed_size))
        self.pos_drop = torch.nn.Dropout(drop_prob)

    def forward(self, x):
        """Embed ``x`` and return the sequence with the class token at index 0 of dim 1."""
        projected = self.proj(x)
        # Collapse everything from dim 2 onwards, then swap dims 1 and 2.
        tokens = projected.flatten(2).transpose(1, 2)
        batch = tokens.shape[0]
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        return self.pos_drop(tokens + self.pos_embed)
return fwd_end - fwd_start, bwd_end - bwd_start
def check_embed():
......@@ -506,21 +287,25 @@ def check_embed():
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
layer = LAYERS.get_module('ViTPatchEmbedding3D')(IMG_SIZE, 4, 3,
HIDDEN_SIZE, 0.)
torch.nn.init.zeros_(layer.proj.bias)
torch.nn.init.ones_(layer.proj.weight)
layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer.cls_token)
torch.nn.init.ones_(layer.pos_embed)
layer = layer.to(device)
layer_master = Testvitembed(IMG_SIZE, 4, 3, HIDDEN_SIZE, 0.)
torch.nn.init.zeros_(layer_master.proj.bias)
torch.nn.init.ones_(layer_master.proj.weight)
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer_master.cls_token)
torch.nn.init.ones_(layer_master.pos_embed)
layer_master = layer_master.to(device)
proj_weight_master = layer_master.weight.data
torch.distributed.broadcast(proj_weight_master, src=0)
proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[k]
layer.weight.data.copy_(proj_weight)
proj_bias_master = layer_master.bias.data
torch.distributed.broadcast(proj_bias_master, src=0)
proj_bias = torch.chunk(proj_bias_master, DEPTH)[k]
layer.bias.data.copy_(proj_bias)
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
......@@ -529,103 +314,55 @@ def check_embed():
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
# out_cls = out[:, 0]
# out_tensor = out[:, 1:]
'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
# if j == 0:
# C_cls = C_master[:, 0]
# C_cls = torch.chunk(C_cls, DEPTH, dim=0)[i]
# C_cls = torch.chunk(C_cls, DEPTH, dim=-1)[k]
# logger.info('Rank {} embed forward (cls): {}'.format(
# rank, check_equal(out_cls, C_cls)))
# C = C_master[:, 1:]
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape,
dtype=dtype,
device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
# cls_grad = grad_master[:, 0]
# cls_grad = torch.chunk(cls_grad, DEPTH, dim=0)[i]
# cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[k]
# grad = grad_master[:, 1:]
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
# grad = torch.cat((torch.unsqueeze(cls_grad, 1), grad), dim=1)
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0(
'embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
print_rank_0('embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
# A_grad = A_master.grad
# logger.info('Rank {} embed backward (input_grad): {}'.format(
# rank, check_equal(A_grad, A.grad)))
# time.sleep(0.1 * rank)
# logger.info(
# 'Rank {0} master:\n{1}\nRank {0} out:\n{2}\nRank {0} true:\n{3}\n'.
# format(rank,
# A_master.grad.detach().cpu().numpy().tolist(),
# A.grad.detach().cpu().numpy().tolist(),
# A_grad.detach().cpu().numpy().tolist()), ranks=[0])
cls_grad_master = layer_master.cls_token.grad
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
# if j == 0:
logger.info('Rank {} embed backward (cls_grad): {}'.format(
rank, check_equal(cls_grad, layer.cls_token.grad)))
# else:.
# logger.info('Rank {} embed backward (cls_grad): {}'.format(
# rank,
# layer.cls_token.grad is None or np.count_nonzero(
# layer.cls_token.grad.detach().cpu().numpy()) == 0))
logger.info('Rank {} embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad)))
pos_grad_master = layer_master.pos_embed.grad
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(
rank, check_equal(pos_grad, layer.pos_embed.grad)))
# if i == 0:
# pos_cls_grad = pos_grad[:, 0]
# pos_tensor_grad = pos_grad[:, 1:]
# pos_tensor_grad = torch.chunk(pos_tensor_grad, DEPTH, dim=1)[j]
# if j == 0:
# logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(
# rank,
# check_equal(
# torch.cat(
# (torch.unsqueeze(pos_cls_grad, 1), pos_tensor_grad),
# dim=1), layer.pos_embed.grad)))
# else:
# logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(
# rank, check_equal(pos_tensor_grad, layer.pos_embed.grad[:,
# 1:])))
# else:
# logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(
# rank, layer.pos_embed.grad is None))
B_grad = layer_master.proj.weight.grad
logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(rank, check_equal(pos_grad, layer.pos_embed.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(
rank, check_equal(B_grad, layer.proj.weight.grad)))
if j == k:
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, check_equal(B_grad,
layer.weight.grad)))
else:
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, layer.weight.grad is None))
bias_grad = layer_master.proj.bias.grad
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} embed backward (proj_bias_grad): {}'.format(
rank, check_equal(bias_grad, layer.proj.bias.grad)))
logger.info('Rank {} embed backward (proj_bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
......@@ -644,19 +381,15 @@ def check_loss():
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
criterion = LOSSES.get_module('CrossEntropyLoss3D')()
# ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT)
criterion = CrossEntropyLoss3D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ),
dtype=torch.long,
device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
out = torch.chunk(out, DEPTH, dim=-1)[k]
out = torch.chunk(out, DEPTH, dim=0)[j]
out = out.clone()
out.requires_grad = True
......@@ -665,27 +398,23 @@ def check_loss():
loss = criterion(out, target_master)
fwd_end = time.time()
print_rank_0(
'loss forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), logger)
'loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start),
logger)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
logger.info('Rank {} CrossEntropyLoss forward: {}'.format(
rank, check_equal(loss, loss_master)))
logger.info('Rank {} CrossEntropyLoss forward: {}'.format(rank, check_equal(loss, loss_master)))
bwd_start = time.time()
loss.backward()
bwd_end = time.time()
print_rank_0('loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
print_rank_0('loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k]
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
logger.info('Rank {} CrossEntropyLoss backward: {}'.format(
rank, check_equal(out_grad, out.grad)))
logger.info('Rank {} CrossEntropyLoss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from colossalai.context import ParallelMode
from colossalai.core import global_context
from colossalai.logging import get_dist_logger
from colossalai.nn.layer.parallel_3d._operation import *
from colossalai.utils import get_current_device
from .common import *
def check_AB():
    """Compare ``Matmul_AB_3D`` with ``torch.matmul`` on this rank's shards.

    Full-size operands are generated and broadcast from rank 0, sliced with
    successive ``torch.chunk`` calls, and passed to ``Matmul_AB_3D.apply``. The
    forward output and both operand gradients are compared, via
    ``check_equal``, with the same slices of the plain ``torch.matmul``
    result; each comparison is logged.

    NOTE(review): assumes the process group and the 3D parallel context are
    already initialised by the caller - confirm.
    """

    def _shard(tensor, *steps):
        # Apply torch.chunk(..., DEPTH, dim)[index] once per (dim, index) step, in order.
        for dim, index in steps:
            tensor = torch.chunk(tensor, DEPTH, dim=dim)[index]
        return tensor

    rank = torch.distributed.get_rank()
    logger = get_dist_logger()
    dtype = torch.float

    input_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    weight_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    output_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)

    # (dim, index) chunking recipes for the left operand, right operand and product.
    a_steps = ((0, weight_rank), (-1, output_rank), (0, input_rank))
    b_steps = ((0, output_rank), (-1, input_rank), (-1, weight_rank))
    c_steps = ((0, weight_rank), (-1, input_rank), (0, output_rank))

    A_master = torch.randn((BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE), dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(A_master, src=0)
    A = _shard(A_master, *a_steps).clone()
    A.requires_grad = True

    B_master = torch.randn((HIDDEN_SIZE, 4 * HIDDEN_SIZE), dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(B_master, src=0)
    B = _shard(B_master, *b_steps).clone()
    B.requires_grad = True

    out = Matmul_AB_3D.apply(A, B, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
                             ParallelMode.PARALLEL_3D_WEIGHT,
                             ParallelMode.PARALLEL_3D_OUTPUT)

    # Reference computation on fresh leaf copies of the full operands.
    A_master = A_master.clone()
    A_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, B_master)
    C = _shard(C_master, *c_steps)

    # check forward correctness
    forward_ok = check_equal(out, C)
    logger.info('Rank {} AB forward: {}'.format(rank, forward_ok))

    grad_master = torch.randn(C_master.shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = _shard(grad_master, *c_steps)

    out.backward(grad)
    C_master.backward(grad_master)

    # check backward correctness
    A_grad = _shard(A_master.grad, *a_steps)
    logger.info('Rank {} AB backward (A_grad): {}'.format(rank, check_equal(A_grad, A.grad)))

    B_grad = _shard(B_master.grad, *b_steps)
    logger.info('Rank {} AB backward (B_grad): {}'.format(rank, check_equal(B_grad, B.grad)))
def check_ABT():
    """Compare ``Matmul_ABT_3D`` with ``torch.matmul(C, B^T)`` on this rank's shards.

    Full-size operands are generated and broadcast from rank 0, sliced with
    successive ``torch.chunk`` calls, and passed to ``Matmul_ABT_3D.apply``.
    The forward output and both operand gradients are compared, via
    ``check_equal``, with the matching slices of the plain PyTorch result;
    each comparison is logged.

    NOTE(review): assumes the process group and the 3D parallel context are
    already initialised by the caller - confirm.
    """

    def _shard(tensor, *steps):
        # Apply torch.chunk(..., DEPTH, dim)[index] once per (dim, index) step, in order.
        for dim, index in steps:
            tensor = torch.chunk(tensor, DEPTH, dim=dim)[index]
        return tensor

    rank = torch.distributed.get_rank()
    logger = get_dist_logger()
    dtype = torch.float

    input_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    weight_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    output_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)

    device = get_current_device()

    # (dim, index) chunking recipes: first operand (C), second operand (B), product (A).
    c_steps = ((0, weight_rank), (-1, input_rank), (0, output_rank))
    b_steps = ((0, output_rank), (-1, input_rank), (-1, weight_rank))
    a_steps = ((0, weight_rank), (-1, output_rank), (0, input_rank))

    C_master = torch.randn((BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE), dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = _shard(C_master, *c_steps).clone()
    C.requires_grad = True

    B_master = torch.randn((HIDDEN_SIZE, 4 * HIDDEN_SIZE), dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = _shard(B_master, *b_steps).clone()
    B.requires_grad = True

    out = Matmul_ABT_3D.apply(C, B, DEPTH, ParallelMode.PARALLEL_3D_OUTPUT,
                              ParallelMode.PARALLEL_3D_WEIGHT,
                              ParallelMode.PARALLEL_3D_INPUT)

    # Reference computation on fresh leaf copies of the full operands.
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    A_master = torch.matmul(C_master, B_master.transpose(0, 1))
    A = _shard(A_master, *a_steps)
    forward_ok = check_equal(out, A)
    logger.info('Rank {} ABT forward: {}'.format(rank, forward_ok))

    grad_master = torch.randn(A_master.shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = _shard(grad_master, *a_steps)

    # backward
    out.backward(grad)
    A_master.backward(grad_master)

    # The "A_grad" label refers to the first operand of the product, which is C here.
    C_grad = _shard(C_master.grad, *c_steps)
    logger.info('Rank {} ABT backward (A_grad): {}'.format(rank, check_equal(C_grad, C.grad)))

    B_grad = _shard(B_master.grad, *b_steps)
    logger.info('Rank {} ABT backward (B_grad): {}'.format(rank, check_equal(B_grad, B.grad)))
def check_ATB():
    """Compare ``Matmul_ATB_3D`` with a flattened ``torch.matmul(A^T, C)`` on this rank's shards.

    Full-size operands are generated and broadcast from rank 0, sliced with
    successive ``torch.chunk`` calls, and passed to ``Matmul_ATB_3D.apply``.
    The reference product collapses all leading dimensions of both operands
    with ``view(-1, last_dim)`` before multiplying. The forward output and
    both operand gradients are compared via ``check_equal`` and logged.

    NOTE(review): assumes the process group and the 3D parallel context are
    already initialised by the caller - confirm.
    """

    def _shard(tensor, *steps):
        # Apply torch.chunk(..., DEPTH, dim)[index] once per (dim, index) step, in order.
        for dim, index in steps:
            tensor = torch.chunk(tensor, DEPTH, dim=dim)[index]
        return tensor

    rank = torch.distributed.get_rank()
    logger = get_dist_logger()
    device = get_current_device()
    dtype = torch.float

    input_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    weight_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    output_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)

    # (dim, index) chunking recipes: first operand (A), second operand (C), product (B).
    a_steps = ((0, weight_rank), (-1, output_rank), (0, input_rank))
    c_steps = ((0, weight_rank), (-1, input_rank), (0, output_rank))
    b_steps = ((0, output_rank), (-1, input_rank), (-1, weight_rank))

    A_master = torch.randn((BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE), dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = _shard(A_master, *a_steps).clone()
    A.requires_grad = True

    C_master = torch.randn((BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE), dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = _shard(C_master, *c_steps).clone()
    C.requires_grad = True

    out = Matmul_ATB_3D.apply(A, C, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
                              ParallelMode.PARALLEL_3D_OUTPUT,
                              ParallelMode.PARALLEL_3D_WEIGHT)

    # Reference computation on fresh leaf copies of the full operands.
    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = C_master.clone()
    C_master.requires_grad = True
    a_flat_t = A_master.view(-1, A_master.shape[-1]).transpose(0, 1)
    c_flat = C_master.view(-1, C_master.shape[-1])
    B_master = torch.matmul(a_flat_t, c_flat)
    B = _shard(B_master, *b_steps)
    forward_ok = check_equal(out, B)
    logger.info('Rank {} ATB forward: {}'.format(rank, forward_ok))

    grad_master = torch.randn(B_master.shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = _shard(grad_master, *b_steps)

    out.backward(grad)
    B_master.backward(grad_master)

    A_grad = _shard(A_master.grad, *a_steps)
    logger.info('Rank {} ATB backward (A_grad): {}'.format(rank, check_equal(A_grad, A.grad)))

    # The "B_grad" label refers to the second operand of the product, which is C here.
    C_grad = _shard(C_master.grad, *c_steps)
    logger.info('Rank {} ATB backward (B_grad): {}'.format(rank, check_equal(C_grad, C.grad)))
def check_add():
    """Check ``Add_3D`` against an unsharded ``activation + bias`` reference.

    A master activation and a master bias are drawn and broadcast from rank 0,
    this rank's shards are fed through ``Add_3D`` (forward and backward), and
    the outcome of every comparison with the matching shard of the unsharded
    computation is logged.
    """
    rank = torch.distributed.get_rank()
    logger = get_dist_logger()
    dtype = torch.float
    # Local coordinates of this rank in the three 3D tensor-parallel groups.
    j = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    device = get_current_device()

    def shard_activation(tensor):
        # Split dim 0 by i, the last dim by k, then dim 0 once more by j.
        for dim, index in ((0, i), (-1, k), (0, j)):
            tensor = torch.chunk(tensor, DEPTH, dim=dim)[index]
        return tensor

    def shard_bias(tensor):
        # The 1-D bias is split by j first and by i second.
        for index in (j, i):
            tensor = torch.chunk(tensor, DEPTH)[index]
        return tensor

    # Sharded operands handed to the op under test.
    master_input = torch.randn((BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE), dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(master_input, src=0)
    local_input = shard_activation(master_input).clone()
    local_input.requires_grad = True

    master_bias = torch.randn((HIDDEN_SIZE, ), dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(master_bias, src=0)
    local_bias = shard_bias(master_bias).clone()
    local_bias.requires_grad = True

    out = Add_3D.apply(local_input, local_bias, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
                       ParallelMode.PARALLEL_3D_WEIGHT, ParallelMode.PARALLEL_3D_OUTPUT)

    # Unsharded reference computed on fresh leaf copies of the master tensors.
    master_input = master_input.clone()
    master_input.requires_grad = True
    master_bias = master_bias.clone()
    master_bias.requires_grad = True
    master_output = master_input + master_bias
    expected_output = shard_activation(master_output)
    forward_ok = check_equal(out, expected_output)
    logger.info('Rank {} Add forward: {}'.format(rank, forward_ok))

    # Same upstream gradient on every rank; each rank backpropagates its shard.
    master_grad = torch.randn(master_output.shape, dtype=dtype, device=device)
    torch.distributed.broadcast(master_grad, src=0)
    local_grad = shard_activation(master_grad)
    out.backward(local_grad)
    master_output.backward(master_grad)

    expected_input_grad = shard_activation(master_input.grad)
    input_grad_ok = check_equal(expected_input_grad, local_input.grad)
    logger.info('Rank {} Add backward (A_grad): {}'.format(rank, input_grad_ok))

    if j == k:
        # Ranks with matching input/output coordinates are compared with the
        # reference bias gradient shard.
        bias_grad_ok = check_equal(shard_bias(master_bias.grad), local_bias.grad)
    else:
        # Every other rank is expected to hold no bias gradient at all.
        bias_grad_ok = local_bias.grad is None
    logger.info('Rank {} Add backward (b_grad): {}'.format(rank, bias_grad_ok))
def check_mul():
    """Check ``Mul_3D`` against an unsharded ``torch.mul(activation, bias)`` reference.

    A master activation and a master bias are drawn and broadcast from rank 0,
    this rank's shards are fed through ``Mul_3D`` (forward and backward), and
    the outcome of every comparison with the matching shard of the unsharded
    computation is logged.
    """
    rank = torch.distributed.get_rank()
    logger = get_dist_logger()
    dtype = torch.float
    # Local coordinates of this rank in the three 3D tensor-parallel groups.
    j = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    device = get_current_device()

    def shard_activation(tensor):
        # Split dim 0 by i, the last dim by k, then dim 0 once more by j.
        for dim, index in ((0, i), (-1, k), (0, j)):
            tensor = torch.chunk(tensor, DEPTH, dim=dim)[index]
        return tensor

    def shard_bias(tensor):
        # The 1-D bias is split by j first and by i second.
        for index in (j, i):
            tensor = torch.chunk(tensor, DEPTH)[index]
        return tensor

    # Sharded operands handed to the op under test.
    master_input = torch.randn((BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE), dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(master_input, src=0)
    local_input = shard_activation(master_input).clone()
    local_input.requires_grad = True

    master_bias = torch.randn((HIDDEN_SIZE, ), dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(master_bias, src=0)
    local_bias = shard_bias(master_bias).clone()
    local_bias.requires_grad = True

    out = Mul_3D.apply(local_input, local_bias, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
                       ParallelMode.PARALLEL_3D_WEIGHT, ParallelMode.PARALLEL_3D_OUTPUT)

    # Unsharded reference computed on fresh leaf copies of the master tensors.
    master_input = master_input.clone()
    master_input.requires_grad = True
    master_bias = master_bias.clone()
    master_bias.requires_grad = True
    master_output = torch.mul(master_input, master_bias)
    expected_output = shard_activation(master_output)
    forward_ok = check_equal(out, expected_output)
    logger.info('Rank {} Mul forward: {}'.format(rank, forward_ok))

    # Same upstream gradient on every rank; each rank backpropagates its shard.
    master_grad = torch.randn(master_output.shape, dtype=dtype, device=device)
    torch.distributed.broadcast(master_grad, src=0)
    local_grad = shard_activation(master_grad)
    out.backward(local_grad)
    master_output.backward(master_grad)

    expected_input_grad = shard_activation(master_input.grad)
    input_grad_ok = check_equal(expected_input_grad, local_input.grad)
    logger.info('Rank {} Mul backward (A_grad): {}'.format(rank, input_grad_ok))

    if j == k:
        # Ranks with matching input/output coordinates are compared with the
        # reference bias gradient shard.
        bias_grad_ok = check_equal(shard_bias(master_bias.grad), local_bias.grad)
    else:
        # Every other rank is expected to hold no bias gradient at all.
        bias_grad_ok = local_bias.grad is None
    logger.info('Rank {} Mul backward (b_grad): {}'.format(rank, bias_grad_ok))
def check_sum():
    """Check ``Sum_3D`` over the last dimension against ``torch.sum(..., dim=-1)``.

    A master tensor is drawn and broadcast from rank 0, this rank's shard is
    reduced with ``Sum_3D`` (forward and backward), and the outcome of each
    comparison with the matching shard of the unsharded sum is logged.
    """
    rank = torch.distributed.get_rank()
    logger = get_dist_logger()
    dtype = torch.float
    # Local coordinates of this rank in the three 3D tensor-parallel groups.
    j = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    device = get_current_device()

    def shard_input(tensor):
        # Split dim 0 by i, the last dim by k, then dim 0 once more by j.
        for dim, index in ((0, i), (-1, k), (0, j)):
            tensor = torch.chunk(tensor, DEPTH, dim=dim)[index]
        return tensor

    def shard_summed(tensor):
        # The last dim is gone after the sum; only dim 0 is split, by i then j.
        for index in (i, j):
            tensor = torch.chunk(tensor, DEPTH, dim=0)[index]
        return tensor

    # Sharded operand handed to the op under test.
    master_input = torch.randn((BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE), dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(master_input, src=0)
    local_input = shard_input(master_input).clone()
    local_input.requires_grad = True
    out_tensor = Sum_3D.apply(local_input, -1, DEPTH, ParallelMode.PARALLEL_3D_OUTPUT)

    # Unsharded reference computed on a fresh leaf copy of the master tensor.
    master_input = master_input.clone()
    master_input.requires_grad = True
    master_output = torch.sum(master_input, dim=-1)
    expected_output = shard_summed(master_output)
    forward_ok = check_equal(out_tensor, expected_output)
    logger.info('Rank {} Sum forward: {}'.format(rank, forward_ok))

    master_grad = torch.randn(master_output.shape, dtype=dtype, device=device)
    torch.distributed.broadcast(master_grad, src=0)
    local_grad = shard_summed(master_grad)
    # The local upstream gradient is divided by DEPTH before backward.
    # NOTE(review): presumably compensates for Sum_3D's backward combining
    # gradients across the DEPTH ranks of the output group -- confirm.
    out_tensor.backward(local_grad / DEPTH)
    master_output.backward(master_grad)

    expected_input_grad = shard_input(master_input.grad)
    backward_ok = check_equal(expected_input_grad, local_input.grad)
    logger.info('Rank {} Sum backward: {}'.format(rank, backward_ok))
def check_reduce():
    """Check chained ``Reduce_3D`` calls against ``torch.sum`` of the full tensor.

    A ``(DEPTH * DEPTH, DEPTH)`` master tensor is drawn and broadcast from
    rank 0 and this rank's shard is squeezed. The value is then reduced over
    the output, input and weight groups in turn, and the result is compared
    with the sum of the whole master tensor. The backward pass is checked the
    same way, and both outcomes are logged.
    """
    rank = torch.distributed.get_rank()
    logger = get_dist_logger()
    dtype = torch.float
    # Local coordinates of this rank in the three 3D tensor-parallel groups.
    j = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    device = get_current_device()

    def shard_and_squeeze(tensor):
        # Split dim 0 by i, the last dim by k, dim 0 again by j, then squeeze.
        for dim, index in ((0, i), (-1, k), (0, j)):
            tensor = torch.chunk(tensor, DEPTH, dim=dim)[index]
        return torch.squeeze(tensor)

    # Scalar operand handed to the op under test.
    master_values = torch.randn((DEPTH * DEPTH, DEPTH), dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(master_values, src=0)
    local_value = shard_and_squeeze(master_values).clone()
    local_value.requires_grad = True

    # Reduce over the output, input and weight groups, in that order.
    out_scalar = local_value
    for mode in (ParallelMode.PARALLEL_3D_OUTPUT, ParallelMode.PARALLEL_3D_INPUT,
                 ParallelMode.PARALLEL_3D_WEIGHT):
        out_scalar = Reduce_3D.apply(out_scalar, 0, DEPTH, mode)

    # Unsharded reference computed on a fresh leaf copy of the master tensor.
    master_values = master_values.clone()
    master_values.requires_grad = True
    expected_total = torch.sum(master_values)
    forward_ok = check_equal(out_scalar, expected_total)
    logger.info('Rank {} Reduce forward: {}'.format(rank, forward_ok))

    master_grad = torch.randn(expected_total.shape, dtype=dtype, device=device)
    torch.distributed.broadcast(master_grad, src=0)
    out_scalar.backward(master_grad)
    expected_total.backward(master_grad)

    expected_grad = shard_and_squeeze(master_values.grad)
    backward_ok = check_equal(expected_grad, local_value.grad)
    logger.info('Rank {} Reduce backward: {}'.format(rank, backward_ok))
......@@ -4,12 +4,14 @@
import torch
# Shared sizes for the 3D-parallel checks. Each name is bound exactly once;
# the earlier, larger values (512 / 128 / 512 / 1000 / 6 / 224) were dead
# assignments overwritten a few lines later and have been dropped.
DEPTH = 2  # number of chunks used by the torch.chunk calls in the checks
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
NUM_BLOCKS = 2
IMG_SIZE = 16
def check_equal(A, B):
    """Require that two tensors agree within the test tolerances.

    Args:
        A: Tensor produced by the code under test.
        B: Reference tensor to compare against.

    Returns:
        ``True`` when ``torch.allclose(A, B, rtol=1e-3, atol=1e-2)`` holds.

    Raises:
        AssertionError: If the tensors differ by more than the tolerances.
    """
    eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
    # Raise explicitly instead of using ``assert`` so the check is not
    # compiled away when Python runs with ``-O``.
    if not eq:
        raise AssertionError('tensors differ beyond rtol=1e-3, atol=1e-2')
    return eq
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.initialize import launch, get_default_parser
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from checks_3d.check_layer_3d import *
from checks_3d.check_operation_3d import *
from colossalai.logging import get_dist_logger
from functools import partial
# NOTE: the operation-level checks are kept for reference but are not wired
# into this test.
# def check_operations():
#     check_AB()
#     check_ABT()
#     check_ATB()
#     check_add()
#     check_mul()
#     check_sum()

# Launch configuration passed to ``launch``: pipeline size 1, 3D tensor
# parallelism over 8 ranks, fixed seed. This is the only definition; the
# earlier ``seed=0`` variant was overwritten before use and has been dropped.
CONFIG = dict(
    parallel=dict(
        pipeline=1,
        tensor=dict(mode='3d', size=8),
    ),
    seed=42,
)
def check_layer():
    """Run the 3D-parallel layer checks that are currently enabled.

    Calls ``check_linear``, ``check_layernorm`` and ``check_classifier`` once
    each, in that order. The stale timing body that preceded these calls
    (misspelled ``liear_fwd_time``, unpacked return values, a second run of
    the linear and layernorm checks) has been removed.
    """
    check_linear()
    check_layernorm()
    check_classifier()
    # NOTE: the embedding and loss checks are currently disabled.
    # check_embed()
    # check_loss()
def check_layer_and_operation(rank, world_size):
    """Initialise the distributed context, run the layer checks, then clean up.

    Args:
        rank: Rank of the calling process, forwarded to ``launch``.
        world_size: Total number of processes, forwarded to ``launch``.
    """
    # Launch exactly once; the block previously carried two identical
    # ``launch`` calls back to back.
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29923, backend='nccl')

    check_layer()

    # Tear down the global context and release cached CUDA memory.
    gpc.destroy()
    torch.cuda.empty_cache()
......
import colossalai
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from pathlib import Path
from torchvision import transforms
from torch.optim import Adam
import torch.nn as nn
from colossalai.amp.amp_type import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import get_dataloader
from torchvision.models import resnet18
from colossalai.utils import MultiTimer, get_dataloader
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from functools import partial
from torchvision.models import resnet18
# Batch size passed to ``get_dataloader`` when building the dataloaders.
BATCH_SIZE = 16
# Images are resized to IMG_SIZE x IMG_SIZE by ``transforms.Resize``.
IMG_SIZE = 32
......@@ -23,50 +23,32 @@ NUM_EPOCHS = 200
# Launch configuration passed to ``colossalai.launch``: the ``fp16`` section
# selects ``AMP_TYPE.TORCH``. Single definition; the dangling reformatted
# ``fp16=...`` fragment after the original statement has been folded in.
CONFIG = dict(
    # Config
    fp16=dict(mode=AMP_TYPE.TORCH),
)
def run_trainer_no_pipeline(rank, world_size):
colossalai.launch(
config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=29930,
backend='nccl'
)
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29930, backend='nccl')
# build model
model = resnet18(num_classes=10)
# build dataloaders
train_dataset = CIFAR10(
root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
test_dataset = CIFAR10(
root=Path(os.environ['DATA']),
train=False,
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose([
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]))
test_dataset = CIFAR10(root=Path(os.environ['DATA']),
train=False,
download=True,
transform=transforms.Compose([
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]))
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
......@@ -74,38 +56,31 @@ def run_trainer_no_pipeline(rank, world_size):
pin_memory=True,
drop_last=True)
test_dataloader = get_dataloader(dataset=test_dataset,
batch_size=BATCH_SIZE,
pin_memory=True,
drop_last=True)
test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True)
# build optimizer
optimizer = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
engine, train_dataloader, *args = colossalai.initialize(
model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader
)
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
logger = get_dist_logger()
logger.info("engine is built", ranks=[0])
trainer = Trainer(engine=engine,
logger=logger)
timer = MultiTimer()
trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("trainer is built", ranks=[0])
logger.info("start training", ranks=[0])
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=NUM_EPOCHS,
max_steps=100,
display_progress=True,
test_interval=5
)
trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=NUM_EPOCHS,
max_steps=100,
display_progress=True,
test_interval=5)
gpc.destroy()
torch.cuda.empty_cache()
......
import colossalai
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from pathlib import Path
from torchvision import transforms
from torch.optim import Adam
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import PipelineSchedule
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import get_dataloader
from colossalai.engine.schedule import PipelineSchedule
from torchvision.models import resnet18
from colossalai.utils import MultiTimer, get_dataloader
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from functools import partial
from torchvision.models import resnet18
# Batch size passed to ``get_dataloader`` when building the dataloaders.
BATCH_SIZE = 16
# Images are resized to IMG_SIZE x IMG_SIZE by ``transforms.Resize``.
IMG_SIZE = 32
# Passed as ``epochs`` to ``trainer.fit``.
NUM_EPOCHS = 200

# Launch configuration passed to ``colossalai.launch``: two pipeline stages,
# matching the two local pipeline ranks (0 and 1) the model is split across.
# Defined once; the duplicate equivalent definition has been dropped.
CONFIG = dict(parallel=dict(pipeline=2))
def run_trainer_with_pipeline(rank, world_size):
colossalai.launch(
config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=29931,
backend='nccl'
)
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29931, backend='nccl')
# build model
model = resnet18(num_classes=10)
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
model = nn.Sequential(
model.conv1,
model.bn1,
model.relu,
model.maxpool,
model.layer1,
model.layer2
)
model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
from functools import partial
class Flatten(nn.Module):
def forward(self, x):
return torch.flatten(x, 1)
model = nn.Sequential(
model.layer3,
model.layer4,
model.avgpool,
Flatten(),
model.fc
)
model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)
# build dataloaders
train_dataset = CIFAR10(
root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
test_dataset = CIFAR10(
root=Path(os.environ['DATA']),
train=False,
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose([
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]))
test_dataset = CIFAR10(root=Path(os.environ['DATA']),
train=False,
download=True,
transform=transforms.Compose([
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]))
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
......@@ -100,40 +66,32 @@ def run_trainer_with_pipeline(rank, world_size):
pin_memory=True,
drop_last=True)
test_dataloader = get_dataloader(dataset=test_dataset,
batch_size=BATCH_SIZE,
pin_memory=True,
drop_last=True)
test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True)
# build optimizer
optimizer = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
engine, train_dataloader, *args = colossalai.initialize(
model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader
)
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
logger = get_dist_logger()
logger.info("engine is built", ranks=[0])
pipe_schedule = PipelineSchedule(num_microbatches=4)
trainer = Trainer(engine=engine,
schedule=pipe_schedule,
logger=logger)
timer = MultiTimer()
trainer = Trainer(engine=engine, schedule=pipe_schedule, logger=logger, timer=timer)
logger.info("trainer is built", ranks=[0])
logger.info("start training", ranks=[0])
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=NUM_EPOCHS,
max_steps=100,
display_progress=True,
test_interval=5
)
trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=NUM_EPOCHS,
max_steps=100,
display_progress=True,
test_interval=5)
gpc.destroy()
torch.cuda.empty_cache()
......
......@@ -17,60 +17,3 @@ NUM_ATTENTION_HEADS = 8
SUMMA_DIM = 2
NUM_CLASSES = 10
DEPTH = 6
# Nested configuration describing a 2D-parallel Vision Transformer. Every
# sub-dict names a component through its ``type`` key and lists that
# component's keyword values.
# NOTE(review): appears to be consumed through ``build_model(model_cfg)`` --
# confirm against the callers.
model_cfg = {
    'type': 'VisionTransformerFromConfig',
    'tensor_splitting_cfg': {
        'type': 'ViTInputSplitter2D',
    },
    'embedding_cfg': {
        'type': 'ViTPatchEmbedding2D',
        'img_size': IMG_SIZE,
        'patch_size': PATCH_SIZE,
        'embed_dim': DIM,
    },
    'token_fusion_cfg': {
        'type': 'ViTTokenFuser2D',
        'img_size': IMG_SIZE,
        'patch_size': PATCH_SIZE,
        'embed_dim': DIM,
        'drop_rate': 0.1,
    },
    # Normalisation settings at the top level of the model config.
    'norm_cfg': {
        'type': 'LayerNorm2D',
        'normalized_shape': DIM,
        'eps': 1e-6,
    },
    # Settings for one transformer block.
    'block_cfg': {
        'type': 'ViTBlock',
        'attention_cfg': {
            'type': 'ViTSelfAttention2D',
            'hidden_size': DIM,
            'num_attention_heads': NUM_ATTENTION_HEADS,
            'attention_dropout_prob': 0.,
            'hidden_dropout_prob': 0.1,
        },
        'droppath_cfg': {
            'type': 'VanillaViTDropPath',
        },
        'mlp_cfg': {
            'type': 'ViTMLP2D',
            'in_features': DIM,
            'dropout_prob': 0.1,
            'mlp_ratio': 1,
        },
        'norm_cfg': {
            'type': 'LayerNorm2D',
            'normalized_shape': DIM,
            'eps': 1e-6,
        },
    },
    'head_cfg': {
        'type': 'ViTHead2D',
        'hidden_size': DIM,
        'num_classes': NUM_CLASSES,
    },
    'embed_dim': DIM,
    'depth': DEPTH,
    'drop_path_rate': 0.,
}
......@@ -2,37 +2,30 @@
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.autograd
import torch.multiprocessing as mp
import colossalai
import torch
from colossalai.builder import build_model
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.utils import get_dataloader
from colossalai.nn.layer._parallel_utilities import _gather
from colossalai.nn import CrossEntropyLoss2D
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
from components import *
from functools import partial
# Launch configuration passed to ``colossalai.launch``: pipeline size 1, 2D
# tensor parallelism over 4 ranks, no fp16 mode, ZeRO level 2. Defined once;
# the duplicate equivalent definition has been dropped.
CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=4, mode='2d'),
    ),
    fp16=dict(mode=None),
    zero=dict(level=2),
)
def train_epoch(engine, train_dataloader):
......@@ -48,31 +41,19 @@ def train_epoch(engine, train_dataloader):
def run_2d_parallel_vision_transformer_level_2(rank, world_size):
colossalai.launch(
config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=29950,
backend='nccl'
)
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29950, backend='nccl')
# build model
model = build_model(model_cfg)
model.build_from_cfg()
model = vit_lite_depth7_patch4_32(tensor_parallel='2d')
# build dataloaders
train_dataset = CIFAR10(
root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose([
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]))
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
batch_size=BATCH_SIZE,
......@@ -81,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size):
# build optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss2D()
criterion = CrossEntropyLoss(tensor_parallel='2d')
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer,
......
......@@ -2,38 +2,30 @@
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.autograd
import torch.multiprocessing as mp
import colossalai
import torch
from colossalai.core import global_context as gpc
from colossalai.builder import build_model
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.utils import get_dataloader
from colossalai.nn.layer._parallel_utilities import _gather
from colossalai.nn import CrossEntropyLoss2D
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
from functools import partial
from components import *
from components import *
# Launch configuration passed to ``colossalai.launch``: pipeline size 1, 2D
# tensor parallelism over 4 ranks, no fp16 mode, ZeRO level 3. Defined once;
# the duplicate equivalent definition has been dropped.
CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=4, mode='2d'),
    ),
    fp16=dict(mode=None),
    zero=dict(level=3),
)
def train_epoch(engine, train_dataloader):
......@@ -49,31 +41,19 @@ def train_epoch(engine, train_dataloader):
def run_2d_parallel_vision_transformer_level_3(rank, world_size):
colossalai.launch(
config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=29951,
backend='nccl'
)
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=29951, backend='nccl')
# build model
model = build_model(model_cfg)
model.build_from_cfg()
model = vit_lite_depth7_patch4_32(tensor_parallel='2d')
# build dataloaders
train_dataset = CIFAR10(
root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose([
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]))
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
batch_size=BATCH_SIZE,
......@@ -82,7 +62,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
# build optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss2D()
criterion = CrossEntropyLoss(tensor_parallel='2d')
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer,
......@@ -108,6 +88,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
@pytest.mark.dist
@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
def test_3d_vit_zero_level_3():
world_size = 8
run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment