Unverified Commit 9bcd2fd4 authored by Jiarui Fang, committed by GitHub

[tensor] a shorter shard and replicate spec (#1245)

parent 2699dfbb
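
In short: the tests drop the `distspec.shard(...)` / `distspec.replicate()` factory calls in favour of the shorter `ShardSpec` / `ReplicaSpec` constructors exported from `colossalai.tensor`. A minimal before/after sketch, built only from the calls visible in this diff; it assumes torch.distributed has already been initialized (e.g. via colossalai.launch) and that the tensor-parallel degree equals the world size:

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec, ReplicaSpec, distspec

# process group covering all ranks, used as the tensor-parallel group
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
t = ColoTensor.from_torch_tensor(torch.randn(4, 5), ColoTensorSpec(pg))

# Before this commit: specs are built through the distspec factory functions.
t.set_dist_spec(distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()]))

# After this commit: the shorter constructors express the same layout.
t.set_dist_spec(ShardSpec([0], [pg.tp_world_size()]))
t.set_dist_spec(ReplicaSpec())    # presumably the short form of distspec.replicate()
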
@@ -5,7 +5,7 @@ from functools import partial
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec
+from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
@@ -13,7 +13,7 @@ import colossalai
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -159,7 +159,7 @@ def run_check_shared_param():
     # They are all Linear, so both row is allowed. This should pass check.
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
     # This should be detected by check because you can not set weight as row while set bias as col.
-    col_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:
......
@@ -4,7 +4,7 @@ import colossalai
 import torch.nn.functional as F
 import torch.multiprocessing as mp
 from functools import partial
-from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
 from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
@@ -47,7 +47,7 @@ def check_element_wise_ops():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
-    x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
+    x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))
     check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())
......
@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensor, ProcessGroup
+from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
 from functools import partial
@@ -55,7 +55,7 @@ def _run_operand(world_size):
     pg = ProcessGroup(tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
-    t.set_dist_spec(distspec.shard([0], [world_size]))
+    t.set_dist_spec(ShardSpec([0], [world_size]))
     t_new = torch.zeros_like(t)
     assert isinstance(t_new, ColoTensor)
     assert t_new.is_sharded()
@@ -69,7 +69,7 @@ def _run_view(world_size):
     rank = gpc.get_global_rank()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(
-        t_ref, ColoTensorSpec(pg, dist_attr=distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])))
+        t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))
     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
@@ -82,7 +82,7 @@ def _run_view(world_size):
 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
     pg = ProcessGroup(tp_degree=world_size)
-    shard_attr = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])
+    shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_dist_spec(distspec.replicate())
......
@@ -17,7 +17,7 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
 def check_param_equal(model, torch_model, pg: ProcessGroup):
@@ -45,7 +45,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
 def init_1d_row_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -53,7 +53,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
 def init_1d_col_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
......
@@ -16,7 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -81,7 +81,7 @@ class MLP(nn.Module):
 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n:
......