Unverified Commit 554aa959 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[legacy] move communication and nn to legacy and refactor logger (#4671)

* [legacy] move communication to legacy (#4640)

* [legacy] refactor logger and clean up legacy codes (#4654)

* [legacy] make logger independent to gpc

* [legacy] make optim independent to registry

* [legacy] move test engine to legacy

* [legacy] move nn to legacy (#4656)

* [legacy] move nn to legacy

* [checkpointio] fix save hf config

* [test] remove useledd rpc pp test

* [legacy] fix nn init

* [example] skip tutorial hybriad parallel example

* [devops] test doc check

* [devops] test doc check
parent 536397cc
......@@ -2,7 +2,7 @@ import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import TransformerSelfAttentionRing
from colossalai.legacy.nn import TransformerSelfAttentionRing
from colossalai.utils import get_current_device
......
......@@ -5,6 +5,7 @@ import torch.distributed as dist
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence import RingAV, RingQK
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
......@@ -42,7 +43,7 @@ def check_ring_qk(rank, world_size):
a = torch.matmul(q, k.transpose(2, 1))
# compute distributed attention scores
ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply
ring_qk = RingQK.apply
sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)
# check master and distributed attention scores
......@@ -95,7 +96,7 @@ def check_ring_av(rank, world_size):
out = torch.matmul(a, v)
# compute distributed attention scores
ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply
ring_av = RingAV.apply
sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length)
# print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')
......
......@@ -5,7 +5,10 @@ import pytest
import torch
import torch.distributed as dist
from colossalai.communication import (
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.legacy.communication import (
recv_backward,
recv_forward,
recv_obj_meta,
......@@ -15,9 +18,6 @@ from colossalai.communication import (
send_forward_recv_backward,
send_obj_meta,
)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
......
import os
import time
import pytest
import torch
import torch.nn as nn
from rpc_test_utils import parse_args, rpc_run
from titans.dataloader.cifar10 import build_cifar
from torchvision.models import resnet50
from tqdm import tqdm
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.pipeline.rpc import OneFOneBPipelineEngine
def flatten(x):
return torch.flatten(x, 1)
def partition(pp_rank: int, chunk: int, stage_num: int):
pipelinable = PipelinableContext()
# build model partitions
with pipelinable:
# input : [B, 3, 32, 32]
_ = resnet50()
pipelinable.policy = "customized"
exec_seq = [
'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
]
pipelinable.to_layer_list(exec_seq)
partition = pipelinable.partition(chunk, stage_num, pp_rank)
return partition
def run_master(args):
batch_size = args.batch_size
chunk = args.chunk
device = args.device
world_size = args.world_size
stage_num = world_size
num_microbatches = args.num_microbatches
# build dataloader
root = os.environ.get('DATA', './data')
train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
criterion = nn.CrossEntropyLoss()
pp_engine = OneFOneBPipelineEngine(partition_fn=partition,
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
chunk=chunk,
criterion=criterion,
checkpoint=False)
pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
s = time.time()
for bx, by in tqdm(train_dataloader):
pp_engine.forward_backward(bx, labels=by, forward_only=False)
cost_time = time.time() - s
print("total cost time :", cost_time)
print("cost time per batch:", cost_time / len(train_dataloader))
@pytest.mark.skip("Test for performance, no need for CI")
def main():
args = parse_args()
# this is due to limitation of partition function
args.world_size = 2
args.chunk = 1
rpc_run(args, run_master)
if __name__ == '__main__':
main()
......@@ -7,7 +7,7 @@ import pytest
import torch
import torch.nn as nn
import colossalai.nn as col_nn
import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
......
......@@ -7,7 +7,7 @@ import pytest
import torch
import torch.nn as nn
import colossalai.nn as col_nn
import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
......
......@@ -7,7 +7,7 @@ import pytest
import torch
import torch.nn as nn
import colossalai.nn as col_nn
import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
......
......@@ -7,7 +7,7 @@ import pytest
import torch
import torch.nn as nn
import colossalai.nn as col_nn
import colossalai.legacy.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment