Unverified commit a0064407, authored by Jiarui Fang, committed by GitHub

reorganize colotensor directory (#1062)

* reorganize colotensor directory

* polish code
parent 3d10be33
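In short, the diff below moves the module-wrapping helpers (register_colo_module, init_colo_module, check_colo_module, plus ColoLinear and ColoEmbedding) from colossalai.tensor to colossalai.nn, moves ColoOptimizer to colossalai.nn.optimizer, and has some files import ColoTensor from its concrete submodule. A before/after sketch of the import changes, restating only what the hunks below show:

```python
# Old import locations (removed in this commit)
from colossalai.tensor import ColoTensor
from colossalai.tensor import register_colo_module, init_colo_module, check_colo_module
from colossalai.tensor import ColoOptimizer

# New import locations (as used by the updated files below)
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.nn import register_colo_module, init_colo_module, check_colo_module
from colossalai.nn.optimizer import ColoOptimizer
```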
import torch
+from colossalai.tensor.colo_tensor import ColoTensor

from typing import Iterator, Tuple, Union
import torch.nn as nn
-from colossalai.tensor import ColoTensor
+from colossalai.tensor.colo_tensor import ColoTensor
# The function is credited to PyTorch Team
......
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
-from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
-    ColoLinear, ColoEmbedding
+from colossalai.tensor import ColoTensor, ColoParameter
+from colossalai.nn import register_colo_module, init_colo_module, \
+    ColoLinear, ColoEmbedding
import types
from torch import nn
-from typing import Iterator, Tuple, Union, Optional
+from typing import Iterator, Tuple, Union
# find named_params includes replica
@@ -24,6 +25,7 @@ def _named_params_with_replica(
        name = mod_prefix + ('.' if mod_prefix else '') + name
        yield name, val


def ColoModulize(module):
    """
    Replacing the parameters() and named_parameters() with our customized ones
......
from colossalai.utils import free_port, ColoInitContext, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, init_colo_module
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from functools import partial
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
+from colossalai.nn import init_colo_module
from colossalai.nn.parallel import ColoDDP
import colossalai
@@ -11,12 +14,14 @@ import torch
import torch.multiprocessing as mp
import pytest
class Net(torch.nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.embed = torch.nn.Embedding(20, 4)
        self.proj = torch.nn.Linear(4, 8)

    def forward(self, x):
        # move input to cpu and restore output
        current_dev = x.device
@@ -27,6 +32,7 @@ class Net(torch.nn.Module):
        x = self.proj(x)
        return x
def run_hybrid_device(use_ddp):
    with ColoInitContext(device=get_current_device()):
        model = Net()
@@ -36,7 +42,6 @@ def run_hybrid_device(use_ddp):
        model = ColoDDP(model)
        real_model = model.module
    print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
    #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
    parallel_action = ParallelAction(ComputePattern.TP1D)
@@ -49,11 +54,12 @@ def run_hybrid_device(use_ddp):
    print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
    #print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
    data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
    out = model(data)
    out.sum().backward()
def run_dist(rank, world_size, port, use_ddp):
    if use_ddp and world_size == 1:
        return
@@ -62,6 +68,7 @@ def run_dist(rank, world_size, port, use_ddp):
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_hybrid_device(use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@@ -71,5 +78,6 @@ def _test_hybrid_device(world_size, use_ddp):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    _test_hybrid_device(1, False)
\ No newline at end of file
@@ -10,9 +10,10 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
-    ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager
+    ParallelAction, ColoTensor, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.nn.optimizer import ColoOptimizer
from functools import partial
from _utils import set_seed
......
from copy import copy
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
-import torch
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor import ColoTensor, distspec
+import pytest
from functools import partial
-import colossalai
-import pytest
import torch
import torch.multiprocessing as mp
+import torch.nn.functional as F
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.nn import init_colo_module, check_colo_module
+from _utils import tensor_equal, tensor_shard_equal, set_seed
+import colossalai
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import distspec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, check_colo_module
-from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
def run_model_with_spec(mode, model_name):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -27,7 +31,7 @@ def run_model_with_spec(mode, model_name):
    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=False)

    if rank == 0:
        model_seq = model_builder(checkpoint=False)
        model_seq = model_seq.cuda()
@@ -103,15 +107,16 @@ def run_model_with_spec(mode, model_name):
        if i > 3:
            break
def run_linear_with_spec(mode):
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(4, 8)

    model_handy = copy(model)
    parallel_action = ParallelAction(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, mode=mode)

    x = torch.rand(2, 4).cuda()
    out = model(x)
    colo_out = model_handy(x)
@@ -122,6 +127,7 @@ def run_linear_with_spec(mode):
    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
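For readers skimming the fragments above, the pattern exercised by run_linear_with_spec can be condensed into a self-contained sketch. It restates only what is visible in the hunks (parameter creation under ColoInitContext, a ParallelAction with ComputePattern.TP1D, and in-place sharding via init_colo_module in 'col' or 'row' mode) and assumes colossalai.launch has already set up a 1D tensor-parallel group, as in run_dist below. The helper name shard_linear_sketch is ours, and the loss/backward steps hidden in the collapsed hunk are not reproduced.

```python
from copy import copy

import torch

from colossalai.nn import init_colo_module
from colossalai.tensor import ComputePattern, ParallelAction
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


def shard_linear_sketch(mode: str = 'col'):
    # Parameters created under ColoInitContext become ColoTensors,
    # so init_colo_module can later shard them in place.
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(4, 8)

    model_handy = copy(model)    # unsharded reference copy

    # 1D tensor parallelism; mode selects column- or row-wise sharding.
    parallel_action = ParallelAction(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, mode=mode)

    x = torch.rand(2, 4).cuda()
    out = model(x)              # sharded module
    colo_out = model_handy(x)   # reference module, for comparison
    return out, colo_out
```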
def run_check_shared_param():
    from transformers import BertForMaskedLM, BertConfig
    hidden_dim = 8
@@ -157,12 +163,14 @@ def run_check_shared_param():
    except Exception as e:
        assert 'incorrectly sharded' in str(e)


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_linear_with_spec('col')
    run_linear_with_spec('row')
def run_dist_model(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -170,11 +178,13 @@ def run_dist_model(rank, world_size, port):
        run_model_with_spec('col', model_name)
        run_model_with_spec('row', model_name)


def run_dist_check(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_check_shared_param()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
@@ -182,6 +192,7 @@ def test_module_linear_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
@@ -189,6 +200,7 @@ def test_module_model(world_size):
    run_func = partial(run_dist_model, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
@@ -196,5 +208,6 @@ def test_module_check(world_size):
    run_func = partial(run_dist_check, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_module_check(2)
\ No newline at end of file