Unverified Commit 9bcd2fd4 authored by Jiarui Fang, committed by GitHub

[tensor] a shorter shard and replicate spec (#1245)

parent 2699dfbb
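
This commit shortens the distributed-tensor spec constructors: ShardSpec and ReplicaSpec become importable directly from colossalai.tensor, replacing the longer distspec.shard(...) and distspec.replicate() spellings, and the tests below are updated accordingly. A minimal before/after sketch, assuming an initialized torch.distributed process group and the ColoTensor / ColoTensorSpec API exercised by these tests (the no-argument ReplicaSpec() call is inferred from the new imports and is not shown explicitly in the hunks):

import torch
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ProcessGroup,
                               ShardSpec, ReplicaSpec, distspec)

# Assumed setup: torch.distributed is already initialized and the whole
# world is used as a single tensor-parallel group.
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())

# Old spelling: build the dist spec through the distspec module.
old_shard = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])

# New, shorter spelling introduced by this commit.
new_shard = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
new_replica = ReplicaSpec()    # assumed shorthand for distspec.replicate()

# Both spellings describe the same placement: shard along dim 0 across the
# tensor-parallel group, then convert back to a replicated layout.
t = ColoTensor.from_torch_tensor(torch.randn(4, 5), ColoTensorSpec(pg, dist_attr=new_shard))
t.set_dist_spec(new_replica)
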
@@ -5,7 +5,7 @@ from functools import partial
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec
+from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
@@ -13,7 +13,7 @@ import colossalai
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -159,7 +159,7 @@ def run_check_shared_param():
     # They are all Linear, so both row is allowed. This should pass check.
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
     # This should be detected by check because you can not set weight as row while set bias as col.
-    col_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:
...
@@ -4,7 +4,7 @@ import colossalai
 import torch.nn.functional as F
 import torch.multiprocessing as mp
 from functools import partial
-from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
 from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
@@ -47,7 +47,7 @@ def check_element_wise_ops():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
-    x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
+    x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))
     check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())
...
@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensor, ProcessGroup
+from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
 from functools import partial
@@ -55,7 +55,7 @@ def _run_operand(world_size):
     pg = ProcessGroup(tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
-    t.set_dist_spec(distspec.shard([0], [world_size]))
+    t.set_dist_spec(ShardSpec([0], [world_size]))
     t_new = torch.zeros_like(t)
     assert isinstance(t_new, ColoTensor)
     assert t_new.is_sharded()
@@ -69,7 +69,7 @@ def _run_view(world_size):
     rank = gpc.get_global_rank()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(
-        t_ref, ColoTensorSpec(pg, dist_attr=distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])))
+        t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))
     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
@@ -82,7 +82,7 @@ def _run_view(world_size):
 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
     pg = ProcessGroup(tp_degree=world_size)
-    shard_attr = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])
+    shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_dist_spec(distspec.replicate())
...
@@ -17,7 +17,7 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
 def check_param_equal(model, torch_model, pg: ProcessGroup):
@@ -45,7 +45,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
 def init_1d_row_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -53,7 +53,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
 def init_1d_col_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
...
@@ -16,7 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -81,7 +81,7 @@ class MLP(nn.Module):
 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n:
...
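
For reference, the spec tuples rebuilt in the last two files pair a ShardSpec with a ComputeSpec(ComputePattern.TP1D) and differ only in the dimension being sharded. A small sketch that collects both forms, using only constructors that appear in the hunks above (the helper name make_1d_specs is illustrative, not part of the change):

from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec


def make_1d_specs(pg: ProcessGroup):
    # Row-style 1-D tensor parallelism: shard along dim 0 across the
    # tensor-parallel group.
    row_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    # Column-style 1-D tensor parallelism: shard along the last dim instead.
    col_spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    return row_spec, col_spec
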