Unverified commit cd9c28e0, authored by Frank Lee, committed by GitHub

added CI for unit testing (#69)

parent 45355a62
name: Build

on:
  pull_request:
    types: [review_requested]
    branches:
      - "*"

jobs:
  build:
    name: Build and test Colossal-AI
    runs-on: [self-hosted, gpu]
    container:
      image: nvcr.io/nvidia/pytorch:21.07-py3
      options: --gpus all --rm --ipc=host -v /data/cifar-10:/data/cifar-10
    timeout-minutes: 1200
    if: github.event.pull_request.draft == false && github.base_ref == 'main' && github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
    steps:
      - name: Setup Environment
        run: |
          export https_proxy=http://172.17.0.1:7890 http_proxy=http://172.17.0.1:7890 all_proxy=socks5://172.17.0.1:7890
      - name: Install dependencies
        run: |
          python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
          python3 -m pip install -U pip setuptools wheel --user
          pip install pytest tensorboard deepspeed apex
      - uses: actions/checkout@v2
      - name: Install Colossal-AI
        run: |
          pip install -v --no-cache-dir --global-option="--cuda_ext" .
      - name: Unit Testing
        run: |
          pytest tests
        env:
          DATA: /data/cifar-10
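
Note: the final step simply runs `pytest tests`, and the tests below rely on custom marks such as `cpu` and `dist`. If those marks are not already registered in the repository's pytest configuration, a minimal conftest.py along the following lines would keep the run free of unknown-mark warnings (a hypothetical sketch, not part of this commit; the marker descriptions are assumptions):

    # conftest.py -- hypothetical sketch, not included in this commit
    def pytest_configure(config):
        # register the custom marks used by the test suite so `pytest tests` runs cleanly
        config.addinivalue_line("markers", "cpu: test runs on CPU-only runners")
        config.addinivalue_line("markers", "dist: test needs multiple GPUs / a distributed launch")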
@@ -5,7 +5,6 @@ import os
 import os.path as osp
 import torch
-from torch.utils.tensorboard import SummaryWriter
 from typing import List
 from decimal import Decimal
 from colossalai.context import ParallelMode
@@ -100,6 +99,7 @@ class TensorboardHook(BaseHook):
                  priority: int = 10,
                  ) -> None:
         super().__init__(priority=priority)
+        from torch.utils.tensorboard import SummaryWriter
         # create log dir
         if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
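
The change above moves the SummaryWriter import from module level into TensorboardHook.__init__, so importing the hooks module no longer requires tensorboard to be installed; it is only needed once the hook is actually constructed. A minimal sketch of the same deferred-import pattern (the class and method names here are illustrative, not the real hook):

    class MetricsLogger:
        """Writes scalars to tensorboard, importing it only when instantiated."""

        def __init__(self, log_dir: str) -> None:
            # deferred import: tensorboard is only required if this logger is used
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=log_dir)

        def log_scalar(self, name: str, value: float, step: int) -> None:
            self.writer.add_scalar(name, value, step)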
......
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-from functools import partial
-from pathlib import Path
 import pytest
+import torch
 import torch.multiprocessing as mp
 from colossalai import launch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from functools import partial
+from pathlib import Path
 CONFIG_PATH = Path(__file__).parent.joinpath('configs/parallel_2d_init.py').absolute()
@@ -75,6 +75,7 @@ def init_2d(rank, world_size, backend, port, host):
     check_2d_parallel_rank(rank)
     check_pipeline_parallel_rank(rank)
     gpc.destroy()
+    torch.cuda.empty_cache()
 @pytest.mark.cpu
@@ -86,7 +87,7 @@ def test_2d_init():
     test_fn = partial(init_2d,
                       world_size=world_size,
                       backend='gloo',
-                      port='29500',
+                      port='29900',
                       host='localhost'
                       )
     mp.spawn(test_fn, nprocs=world_size)
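
The initializer tests all follow the same shape: bind everything except the rank with functools.partial, spawn one process per rank with torch.multiprocessing.spawn, run the rank checks, then clean up with gpc.destroy() and torch.cuda.empty_cache(). A stripped-down sketch of that pattern (the worker below is a stand-in, not the real init_2d):

    from functools import partial

    import torch
    import torch.multiprocessing as mp


    def run_worker(rank, world_size, port):
        # stand-in for init_2d/init_2halfd/init_3d: launch, run rank checks, clean up
        print(f'rank {rank}/{world_size} would launch on port {port} and run its checks')
        torch.cuda.empty_cache()  # no-op when CUDA was never initialized


    def spawn_test(world_size=4, port='29900'):
        test_fn = partial(run_worker, world_size=world_size, port=port)
        # mp.spawn passes the rank as the first positional argument to test_fn
        mp.spawn(test_fn, nprocs=world_size)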
......
@@ -5,6 +5,7 @@ from functools import partial
 from pathlib import Path
 import pytest
+import torch
 import torch.multiprocessing as mp
 from colossalai.context.parallel_mode import ParallelMode
@@ -98,6 +99,7 @@ def init_2halfd(rank, world_size, backend, port, host):
     check_tensor_parallel_rank(rank)
     check_2p5d_parallel_rank(rank)
     gpc.destroy()
+    torch.cuda.empty_cache()
 @pytest.mark.cpu
@@ -109,7 +111,7 @@ def test_2halfd_init():
     test_fn = partial(init_2halfd,
                       world_size=world_size,
                       backend='gloo',
-                      port='29501',
+                      port='29901',
                       host='localhost'
                       )
     mp.spawn(test_fn, nprocs=world_size)
......
@@ -5,8 +5,10 @@ from functools import partial
 from pathlib import Path
 import pytest
+import torch
 import torch.multiprocessing as mp
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
@@ -90,6 +92,7 @@ def init_3d(rank, world_size, backend, port, host):
     check_data_parallel_rank(rank)
     check_pipeline_parallel_rank(rank)
     gpc.destroy()
+    torch.cuda.empty_cache()
 @pytest.mark.cpu
@@ -101,7 +104,7 @@ def test_3d_init():
     test_fn = partial(init_3d,
                       world_size=world_size,
                       backend='gloo',
-                      port='29502',
+                      port='29902',
                       host='localhost'
                       )
     mp.spawn(test_fn, nprocs=world_size)
......
@@ -6,7 +6,7 @@ from functools import partial
 from pathlib import Path
 import pytest
-import torch.cuda
+import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.utils.data import DataLoader
@@ -49,7 +49,7 @@ def run_data_sampler(rank, world_size):
         rank=rank,
         world_size=world_size,
         backend='gloo',
-        port='29503',
+        port='29903',
         host='localhost'
     )
     colossalai.launch(**dist_args)
@@ -73,6 +73,7 @@ def run_data_sampler(rank, world_size):
     if gpc.get_local_rank(ParallelMode.DATA) != 0:
         assert not torch.equal(img,
                                img_to_compare), 'Same image was distributed across ranks but expected it to be different'
+    torch.cuda.empty_cache()
 @pytest.mark.cpu
......
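
In the sampler test above, each data-parallel rank draws a batch and compares it against rank 0's batch; with the data-parallel sampler attached, non-zero ranks are expected to see different data. A minimal sketch of how such a cross-rank comparison can be made (the helper name is an assumption, not taken from the test file):

    import torch
    import torch.distributed as dist


    def batch_matches_rank0(img: torch.Tensor) -> bool:
        # copy the local batch, then overwrite the copy with rank 0's batch via broadcast
        img_to_compare = img.clone()
        dist.broadcast(img_to_compare, src=0)
        # rank 0 trivially matches itself; other ranks report whether their batch is identical
        return torch.equal(img, img_to_compare)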
@@ -6,7 +6,7 @@ from functools import partial
 from pathlib import Path
 import pytest
-import torch.cuda
+import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torchvision import transforms
@@ -52,11 +52,10 @@ def run_data_sampler(rank, world_size):
         rank=rank,
         world_size=world_size,
         backend='gloo',
-        port='29499',
+        port='29904',
         host='localhost'
     )
     colossalai.launch(**dist_args)
-    print('finished initialization')
     dataset_cfg = gpc.config.train_data.dataset
     dataloader_cfg = gpc.config.train_data.dataloader
@@ -88,6 +87,7 @@ def run_data_sampler(rank, world_size):
     # this should be false if data parallel sampler to given to the dataloader
     assert torch.equal(img,
                        img_to_compare), 'Same image was distributed across ranks and expected it to be the same'
+    torch.cuda.empty_cache()
 @pytest.mark.cpu
......
+import pytest
 from pathlib import Path
 from colossalai.amp.amp_type import AMP_TYPE
 from colossalai.context.parallel_mode import ParallelMode
@@ -34,7 +35,9 @@ CONFIG = dict(
 )
-def main():
+@pytest.mark.dist
+@pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually")
+def test_hybrid_parallel():
     parser = colossalai.get_default_parser()
     args = parser.parse_args()
     colossalai.launch_from_slurm(config=CONFIG,
......
#!/usr/bin/env sh
test_file=$1
python $test_file --world_size $SLURM_NPROCS --host $HOST --port 29500 --rank $SLURM_PROCID
\ No newline at end of file
@@ -8,6 +8,7 @@ import torch
 import os.path as osp
 from pathlib import Path
 import torch.nn as nn
+import torch.multiprocessing as mp
 from torchvision import transforms
 from torch.optim import Adam
@@ -15,9 +16,9 @@ from colossalai.core import global_context as gpc
 from colossalai.amp import AMP_TYPE
 from colossalai.logging import get_dist_logger
 from colossalai.utils import report_memory_usage, get_dataloader
-from colossalai.initialize import get_default_parser
 from torchvision.models import resnet18
 from torchvision.datasets import CIFAR10
+from functools import partial
 # Config
@@ -37,18 +38,15 @@ CONFIG = dict(
 )
-def run_no_pipeline():
-    parser = get_default_parser()
-    args = parser.parse_args()
+def run_engine(rank, world_size):
     # init dist env
     colossalai.launch(
         config=CONFIG,
-        rank=args.rank,
-        world_size=args.world_size,
-        host=args.host,
-        port=args.port,
-        backend=args.backend
+        rank=rank,
+        world_size=world_size,
+        host='localhost',
+        port=29910,
+        backend='nccl'
     )
     # build model
@@ -69,8 +67,6 @@ def run_no_pipeline():
     train_dataloader = get_dataloader(dataset=train_dataset,
                                       shuffle=True,
                                       batch_size=BATCH_SIZE,
-                                      num_workers=1,
-                                      pin_memory=True,
                                       drop_last=True)
     # build optimizer
@@ -102,12 +98,14 @@ def run_no_pipeline():
     gpc.destroy()
     logger.info('Test engine finished')
     report_memory_usage("After testing")
+    torch.cuda.empty_cache()
-@pytest.mark.skip("This test should be invoked using the test.sh provided")
 @pytest.mark.dist
 def test_engine():
-    run_no_pipeline()
+    world_size = 4
+    run_func = partial(run_engine, world_size=world_size)
+    mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
......
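
Each of the engine tests is refactored the same way: the argparse/SLURM entry point (run_no_pipeline plus get_default_parser) becomes a run_engine(rank, world_size) worker that calls colossalai.launch with hard-coded localhost/port/nccl arguments, and the pytest function spawns it across four processes instead of being skipped. A condensed sketch of the new shape of these tests (CONFIG and the port are placeholders; the real tests also build a model, dataloader and optimizer where the comment stands):

    from functools import partial

    import pytest
    import torch
    import torch.multiprocessing as mp

    import colossalai

    CONFIG = dict()  # placeholder; the real tests define a full training config here


    def run_engine(rank, world_size):
        # every spawned process initializes its own distributed context
        colossalai.launch(config=CONFIG,
                          rank=rank,
                          world_size=world_size,
                          host='localhost',
                          port=29910,
                          backend='nccl')
        # ... build model, dataloader and optimizer, run a few steps, gpc.destroy() ...
        torch.cuda.empty_cache()


    @pytest.mark.dist
    def test_engine():
        world_size = 4
        run_func = partial(run_engine, world_size=world_size)
        mp.spawn(run_func, nprocs=world_size)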
@@ -5,6 +5,7 @@ import torch
 import os.path as osp
 from pathlib import Path
 import torch.nn as nn
+import torch.multiprocessing as mp
 from torchvision import transforms
 from torch.optim import Adam
@@ -15,6 +16,7 @@ from colossalai.utils import report_memory_usage, get_dataloader
 from colossalai.initialize import get_default_parser
 from torchvision.models import resnet18
 from torchvision.datasets import CIFAR10
+from functools import partial
 # Config
@@ -36,18 +38,15 @@ CONFIG = dict(
 )
-def run_no_pipeline():
-    parser = get_default_parser()
-    args = parser.parse_args()
+def run_engine(rank, world_size):
     # init dist env
     colossalai.launch(
         config=CONFIG,
-        rank=args.rank,
-        world_size=args.world_size,
-        host=args.host,
-        port=args.port,
-        backend=args.backend
+        rank=rank,
+        world_size=world_size,
+        host='localhost',
+        port=29911,
+        backend='nccl'
     )
     # build model
@@ -68,8 +67,6 @@ def run_no_pipeline():
     train_dataloader = get_dataloader(dataset=train_dataset,
                                       shuffle=True,
                                       batch_size=BATCH_SIZE,
-                                      num_workers=1,
-                                      pin_memory=True,
                                       drop_last=True)
     # build optimizer
@@ -101,12 +98,14 @@ def run_no_pipeline():
     gpc.destroy()
     logger.info('Test engine finished')
     report_memory_usage("After testing")
+    torch.cuda.empty_cache()
-@pytest.mark.skip("This test should be invoked using the test.sh provided")
 @pytest.mark.dist
 def test_engine():
-    run_no_pipeline()
+    world_size = 4
+    run_func = partial(run_engine, world_size=world_size)
+    mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
......
@@ -5,6 +5,7 @@ import torch
 import os.path as osp
 from pathlib import Path
 import torch.nn as nn
+import torch.multiprocessing as mp
 from torchvision import transforms
 from torch.optim import Adam
@@ -15,6 +16,7 @@ from colossalai.utils import report_memory_usage, get_dataloader
 from colossalai.initialize import get_default_parser
 from torchvision.models import resnet18
 from torchvision.datasets import CIFAR10
+from functools import partial
 # Config
@@ -33,18 +35,15 @@ CONFIG = dict(
 )
-def run_no_pipeline():
-    parser = get_default_parser()
-    args = parser.parse_args()
+def run_engine(rank, world_size):
     # init dist env
     colossalai.launch(
         config=CONFIG,
-        rank=args.rank,
-        world_size=args.world_size,
-        host=args.host,
-        port=args.port,
-        backend=args.backend
+        rank=rank,
+        world_size=world_size,
+        host='localhost',
+        port=29912,
+        backend='nccl'
    )
     # build model
@@ -65,8 +64,6 @@ def run_no_pipeline():
     train_dataloader = get_dataloader(dataset=train_dataset,
                                       shuffle=True,
                                       batch_size=BATCH_SIZE,
-                                      num_workers=1,
-                                      pin_memory=True,
                                       drop_last=True)
     # build optimizer
@@ -98,12 +95,14 @@ def run_no_pipeline():
     gpc.destroy()
     logger.info('Test engine finished')
     report_memory_usage("After testing")
+    torch.cuda.empty_cache()
-@pytest.mark.skip("This test should be invoked using the test.sh provided")
 @pytest.mark.dist
 def test_engine():
-    run_no_pipeline()
+    world_size = 4
+    run_func = partial(run_engine, world_size=world_size)
+    mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
......
@@ -5,6 +5,7 @@ import torch
 import os.path as osp
 from pathlib import Path
 import torch.nn as nn
+import torch.multiprocessing as mp
 from torchvision import transforms
 from torch.optim import Adam
@@ -15,6 +16,7 @@ from colossalai.utils import report_memory_usage, get_dataloader
 from colossalai.initialize import get_default_parser
 from torchvision.models import resnet18
 from torchvision.datasets import CIFAR10
+from functools import partial
 # Config
@@ -34,18 +36,15 @@ CONFIG = dict(
 )
-def run_no_pipeline():
-    parser = get_default_parser()
-    args = parser.parse_args()
+def run_engine(rank, world_size):
     # init dist env
     colossalai.launch(
         config=CONFIG,
-        rank=args.rank,
-        world_size=args.world_size,
-        host=args.host,
-        port=args.port,
-        backend=args.backend
+        rank=rank,
+        world_size=world_size,
+        host='localhost',
+        port=29913,
+        backend='nccl'
     )
     # build model
@@ -66,8 +65,6 @@ def run_no_pipeline():
     train_dataloader = get_dataloader(dataset=train_dataset,
                                       shuffle=True,
                                       batch_size=BATCH_SIZE,
-                                      num_workers=1,
-                                      pin_memory=True,
                                       drop_last=True)
     # build optimizer
@@ -99,12 +96,14 @@ def run_no_pipeline():
     gpc.destroy()
     logger.info('Test engine finished')
     report_memory_usage("After testing")
+    torch.cuda.empty_cache()
-@pytest.mark.skip("This test should be invoked using the test.sh provided")
 @pytest.mark.dist
 def test_engine():
-    run_no_pipeline()
+    world_size = 4
+    run_func = partial(run_engine, world_size=world_size)
+    mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
......
#!/usr/bin/env sh
test_file=$1
python $test_file --rank $SLURM_PROCID --world_size $SLURM_NPROCS --host $HOST --port 29500
\ No newline at end of file
-from tests.test_layers.test_3d.common import IMG_SIZE
 import torch
 import torch.distributed as dist
 from torch.nn import Parameter
@@ -7,7 +6,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn import Linear1D_Col, Linear1D_Row, TransformerMLP1D, TransformerSelfAttention1D, ViTMLP1D, ViTSelfAttention1D, ViTPatchEmbedding1D, ViTHead1D, ViTTokenFuser1D
 from colossalai.utils import get_current_device, print_rank_0
-from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE
+from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE
 def check_linear_col():
......
@@ -2,10 +2,13 @@
 # -*- encoding: utf-8 -*-
 import pytest
+import torch
+import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch, get_default_parser
-from test_layer import *
+from functools import partial
+from checks_1d.check_layer_1d import *
 CONFIG = dict(
     parallel=dict(
@@ -18,8 +21,14 @@ CONFIG = dict(
 )
-def check_layer():
-    # print_rank_0('start check_linear_col')
+def check_layer(rank, world_size):
+    launch(config=CONFIG,
+           rank=rank,
+           world_size=world_size,
+           host='localhost',
+           port=29920,
+           backend='nccl')
     check_linear_col()
     check_linear_row()
     check_attention()
@@ -28,21 +37,15 @@ def check_layer():
     check_embed()
     check_head()
+    gpc.destroy()
+    torch.cuda.empty_cache()
 @pytest.mark.dist
-@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
 def test_1d():
-    parser = get_default_parser()
-    args = parser.parse_args()
-    launch(config=CONFIG,
-           rank=args.rank,
-           world_size=args.world_size,
-           host=args.host,
-           port=args.port,
-           backend=args.backend)
-    check_layer()
-    gpc.destroy()
+    world_size = 2
+    run_func = partial(check_layer, world_size=world_size)
+    mp.spawn(run_func, nprocs=world_size)
 if __name__ == '__main__':
......
@@ -5,7 +5,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn import Linear2D, LayerNorm2D, TransformerSelfAttention2D, TransformerMLP2D, TransformerLayer2D
 from colossalai.utils import get_current_device, print_rank_0
-from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal
+from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal
 def check_linear():
......