Unverified Commit 2b2dc1c8 authored by Frank Lee, committed by GitHub

[pipeline] refactor the pipeline module (#1087)

* [pipeline] refactor the pipeline module

* polish code
parent bad5d4c0
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# model.py: config-driven ResNet, imported as `model` by the test config below
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from colossalai.registry import LAYERS, MODELS
from colossalai.nn.model import ModelFromConfig


@MODELS.register_module
class VanillaResNet(ModelFromConfig):
    """ResNet from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
    """

    def __init__(self,
                 num_cls: int,
                 block_type: str,
                 layers: List[int],
                 norm_layer_type: str = 'BatchNorm2d',
                 in_channels: int = 3,
                 groups: int = 1,
                 width_per_group: int = 64,
                 zero_init_residual: bool = False,
                 replace_stride_with_dilation: Optional[List[bool]] = None,
                 dilations=(1, 1, 1, 1)) -> None:
        super().__init__()
        self.inplanes = 64
        self.zero_init_residual = zero_init_residual
        self.blocks = layers
        self.block_expansion = LAYERS.get_module(block_type).expansion
        self.dilations = dilations
        self.reslayer_common_cfg = dict(type='ResLayer',
                                        block_type=block_type,
                                        norm_layer_type=norm_layer_type,
                                        groups=groups,
                                        base_width=width_per_group)

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]

        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        self.layers_cfg = [
            # conv1
            dict(type='Conv2d',
                 in_channels=in_channels,
                 out_channels=self.inplanes,
                 kernel_size=7,
                 stride=2,
                 padding=3,
                 bias=False),
            # bn1
            dict(type=norm_layer_type, num_features=self.inplanes),
            # relu
            dict(type='ReLU', inplace=True),
            # maxpool
            dict(type='MaxPool2d', kernel_size=3, stride=2, padding=1),
            # layer 1
            dict(inplanes=self.inplanes,
                 planes=64,
                 blocks=self.blocks[0],
                 dilation=self.dilations[0],
                 **self.reslayer_common_cfg),
            # layer 2
            dict(inplanes=64 * self.block_expansion,
                 planes=128,
                 blocks=self.blocks[1],
                 stride=2,
                 dilate=replace_stride_with_dilation[0],
                 dilation=self.dilations[1],
                 **self.reslayer_common_cfg),
            # layer 3
            dict(inplanes=128 * self.block_expansion,
                 planes=256,
                 blocks=self.blocks[2],
                 stride=2,
                 dilate=replace_stride_with_dilation[1],
                 dilation=self.dilations[2],
                 **self.reslayer_common_cfg),
            # layer 4
            dict(inplanes=256 * self.block_expansion,
                 planes=512,
                 blocks=self.blocks[3],
                 stride=2,
                 dilate=replace_stride_with_dilation[2],
                 dilation=self.dilations[3],
                 **self.reslayer_common_cfg),
            # avg pool
            dict(type='AdaptiveAvgPool2d', output_size=(1, 1)),
            # flatten
            dict(type='LambdaWrapper', func=lambda mod, x: torch.flatten(x, 1)),
            # linear
            dict(type='Linear',
                 in_features=512 * self.block_expansion,
                 out_features=num_cls)
        ]

    def forward(self, x: Tensor):
        # self.layers is materialized from self.layers_cfg when the builder
        # constructs the model (see ModelFromConfig and the test below)
        for layer in self.layers:
            x = layer(x)
        return x

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, LAYERS.get_module('ResNetBottleneck')):
                    nn.init.constant_(m.bn3.weight, 0)    # type: ignore[arg-type]
                elif isinstance(m, LAYERS.get_module('ResNetBasicBlock')):
                    nn.init.constant_(m.bn2.weight, 0)    # type: ignore[arg-type]
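Since VanillaResNet only records layer configs rather than live modules, a minimal sketch (not part of this commit) of how a registered model is resolved from such a config dict may help; it mirrors the LAYERS.get_module lookup used above. The dict matches the one in resnet_config.py; note that forward would fail on this bare instance until a builder turns layers_cfg into actual layers.

# Illustrative sketch only: resolving the registered class from a plain config dict.
from colossalai.registry import MODELS

cfg = dict(type='VanillaResNet', block_type='ResNetBasicBlock', layers=[2, 2, 2, 2], num_cls=10)
model_cls = MODELS.get_module(cfg.pop('type'))    # look up the class by its registered name
net = model_cls(**cfg)                            # builds layer configs only; a pipeline builder
                                                  # later materializes self.layers from layers_cfg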
# resnet_config.py: pipeline test config, loaded via CONFIG_PATH in the test below
import os
from pathlib import Path

# importing the model module runs @MODELS.register_module, making 'VanillaResNet'
# resolvable by name; the name `model` is then rebound to the config dict below
import model

BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12
NUM_MICRO_BATCHES = 2

# resnet 18
model = dict(type='VanillaResNet',
             block_type='ResNetBasicBlock',
             layers=[2, 2, 2, 2],
             num_cls=10)

parallel = dict(
    pipeline=dict(size=4),
    tensor=dict(size=1, mode=None)
)
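As a hedged sanity check on these constants (a sketch, not part of the config): NUM_MICRO_BATCHES splits each global batch for the pipeline schedule, and the pipeline size must be covered by the ranks spawned in the test below.

# Arithmetic sketch for the config above.
BATCH_SIZE = 128
NUM_MICRO_BATCHES = 2
PIPELINE_SIZE = 4                                      # parallel['pipeline']['size']

micro_batch_size = BATCH_SIZE // NUM_MICRO_BATCHES     # 128 // 2 == 64 samples per micro-batch
assert micro_batch_size * NUM_MICRO_BATCHES == BATCH_SIZE
world_size = PIPELINE_SIZE * 1                         # pipeline ranks x tensor ranks;
assert world_size == 4                                 # matches mp.spawn(nprocs=4) in test_partition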
import os.path as osp
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

from colossalai.builder.pipeline import build_pipeline_model_from_cfg
from colossalai.core import global_context
from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception

DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, 'resnet_config.py')


def run_partition(rank, world_size, port):
    launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    logger = get_dist_logger()
    logger.info('finished initialization')

    # build this rank's pipeline stage (1 chunk) from the registered model config
    model = build_pipeline_model_from_cfg(global_context.config.model, 1, verbose=True)
    assert isinstance(model, torch.nn.Module)
    logger.info('model is created')

    global_context.destroy()
    logger.info('training finished')
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_partition():
    world_size = 4
    run_func = partial(run_partition, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_partition()
@@ -8,27 +8,45 @@ from pathlib import Path
 import colossalai
 import pytest
 import torch
+import torch.nn as nn
 import torch.multiprocessing as mp
-from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.core import global_context as gpc
-from colossalai.engine.schedule import PipelineSchedule
+from colossalai.context import ParallelMode
 from colossalai.initialize import launch
 from colossalai.utils import free_port, get_dataloader, print_rank_0
 from colossalai.testing import rerun_on_exception
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
+from torchvision.models import resnet18

-BATCH_SIZE = 4
-DIR_PATH = osp.dirname(osp.realpath(__file__))
-CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
+BATCH_SIZE = 8
+
+CONFIG = dict(
+    NUM_MICRO_BATCHES=2,
+    parallel=dict(
+        pipeline=dict(size=2),
+        tensor=dict(size=1, mode=None)
+    )
+)


 def run_schedule(rank, world_size, port):
-    launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # build model
-    model = build_pipeline_model_from_cfg(gpc.config.model, 1)
+    model = resnet18(num_classes=10)
+    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+        model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
+    elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
+
+        class Flatten(nn.Module):
+
+            def forward(self, x):
+                return torch.flatten(x, 1)
+
+        model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)
+
     print_rank_0('model is created')

     train_dataset = CIFAR10(root=Path(os.environ['DATA']),
@@ -69,7 +87,7 @@ def run_schedule(rank, world_size, port):
 @pytest.mark.dist
 @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_pipeline_schedule():
-    world_size = 4
+    world_size = 2
     run_func = partial(run_schedule, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
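A hedged, single-process sketch (not part of the commit) of why this manual two-stage split is sound: composing the stage-0 and stage-1 sub-models reproduces the plain resnet18 forward pass. It assumes a local CPU run with no distributed launch.

# Sketch: verify the two-stage split preserves the resnet18 computation.
import torch
import torch.nn as nn
from torchvision.models import resnet18


class Flatten(nn.Module):

    def forward(self, x):
        return torch.flatten(x, 1)


full = resnet18(num_classes=10).eval()    # eval() so BatchNorm uses fixed running stats
stage0 = nn.Sequential(full.conv1, full.bn1, full.relu, full.maxpool, full.layer1, full.layer2)
stage1 = nn.Sequential(full.layer3, full.layer4, full.avgpool, Flatten(), full.fc)

x = torch.randn(2, 3, 32, 32)             # CIFAR10-sized input, as in the test
with torch.no_grad():
    assert torch.allclose(full(x), stage1(stage0(x)))    # split == unsplit forward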
Four more pipeline test files receive the same one-line fix, re-pointing partition_uniform from colossalai.builder.pipeline to its new home in colossalai.pipeline.utils:

@@ -19,7 +19,7 @@ from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
 def build_pipeline(model):
-    from colossalai.builder.pipeline import partition_uniform
+    from colossalai.pipeline.utils import partition_uniform
     pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
     pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
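For context on what the relocated helper does, here is a hedged re-implementation sketch of uniform layer partitioning. The real colossalai.pipeline.utils.partition_uniform also handles multiple chunks per stage and parameter balancing; this shows only the core idea, and the function name below is mine, not the library's.

# Hypothetical sketch of uniform pipeline partitioning; not colossalai's actual code.
def partition_uniform_sketch(num_layers: int, pipeline_size: int):
    """Split num_layers into pipeline_size contiguous [start, end) ranges."""
    base, rem = divmod(num_layers, pipeline_size)
    parts, start = [], 0
    for rank in range(pipeline_size):
        size = base + (1 if rank < rem else 0)    # early stages absorb the remainder
        parts.append((start, start + size))
        start += size
    return parts

# e.g. the 10 entries of VanillaResNet.layers_cfg over 4 stages:
assert partition_uniform_sketch(10, 4) == [(0, 3), (3, 6), (6, 8), (8, 10)]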