Unverified commit 9bcd2fd4, authored by Jiarui Fang, committed by GitHub

[tensor] a shorter shard and replicate spec (#1245)

parent 2699dfbb
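
In short: this commit introduces ShardSpec and ReplicaSpec as shorter names for distspec.shard and distspec.replicate. A minimal before/after sketch (assuming a ProcessGroup pg has already been constructed, as in the hunks below):

    from colossalai.tensor import distspec, ShardSpec, ReplicaSpec

    # before this commit: the long factory calls
    old_shard = distspec.shard([0], [pg.tp_world_size()])
    old_replica = distspec.replicate()
    # after this commit: the shorter aliases, same arguments
    new_shard = ShardSpec([0], [pg.tp_world_size()])
    new_replica = ReplicaSpec()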
import torch
from torch.fx.node import map_arg
-from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
+from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
@@ -25,7 +24,7 @@ def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
-spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
# Since you have constructed a spec, why not directly convert the tensor to a ColoTensor?
setattr(weight, "fx_attr", spec)
return weight
......
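
A hedged usage sketch of the weight_split helper above; the module lin is illustrative, and torch.distributed must already be initialized since weight_split reads the world size:

    import torch.nn as nn

    lin = nn.Linear(8, 8)
    # tag the weight with a TP1D shard spec along dim 0; an fx pass later reads "fx_attr"
    weight_split(lin.weight, dim=0)
    print(lin.weight.fx_attr)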
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
-from colossalai.tensor import distspec, ColoTensorSpec
+from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
from ._utils import reduce_input, reduce_grad
@@ -11,7 +11,8 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
-mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))
+mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
# Output:P
partial_output = torch.mm(mat1, mat2)
@@ -20,7 +21,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# input
assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor + alpha * output
-output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(distspec.replicate()))
+output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(ReplicaSpec()))
return output
@@ -28,11 +29,11 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
compute_spec = mat2.compute_spec
-mat1 = mat1.redistribute(distspec.replicate())
+mat1 = mat1.redistribute(ReplicaSpec())
mat1 = reduce_grad(mat1, mat1.get_process_group())
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
-output_spec = ColoTensorSpec(input_tensor.get_process_group(), distspec.shard([-1], [mat2.get_tp_world_size()]),
+output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
......
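
The 1Drow comment above (mat1:S[1] x mat2:S[0] = Output:P) relies on a standard identity: column-sharding mat1 and row-sharding mat2 gives partial products whose sum (the all-reduce) is the full product. A single-process sketch of that identity:

    import torch

    torch.manual_seed(0)
    mat1, mat2 = torch.randn(4, 6), torch.randn(6, 8)
    world_size = 2
    # each "rank" holds one column shard of mat1 and the matching row shard of mat2
    partials = [m1 @ m2 for m1, m2 in zip(mat1.chunk(world_size, dim=1),
                                          mat2.chunk(world_size, dim=0))]
    # summing the partial outputs (the all-reduce step) recovers the full product
    assert torch.allclose(sum(partials), mat1 @ mat2, atol=1e-5)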
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
+from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
@@ -14,7 +14,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P)
# Gather the split lookup table
-input_tensor = input_tensor.redistribute(distspec.replicate())
+input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding(input_tensor,
weight,
@@ -23,7 +24,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
-output_spec = ColoTensorSpec(weight.get_process_group(), distspec.shard([-1], [weight.get_tp_world_size()]),
+output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
@@ -46,7 +47,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# Find index in this shard and mask those not here
# Reduce all
pg = weight.get_process_group()
-input_tensor = input_tensor.redistribute(distspec.replicate())
+input_tensor = input_tensor.redistribute(ReplicaSpec())
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = weight.get_process_group().tp_local_rank()
@@ -74,7 +76,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, weight.get_process_group())
-output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), distspec.replicate()))
+output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
return output
......
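
The 1Drow embedding path above masks ids that fall outside a rank's row shard, zeroes those rows, and all-reduces the partial lookups. A single-process sketch of one rank's contribution (equal shard sizes assumed; names are illustrative):

    import torch

    num_embeddings, world_size, rank = 8, 2, 0
    table = torch.randn(num_embeddings, 4)
    shard = table.chunk(world_size, dim=0)[rank]
    ids = torch.tensor([0, 3, 5, 7])
    per_rank = num_embeddings // world_size
    lo = rank * per_rank
    input_mask = (ids < lo) | (ids >= lo + per_rank)   # ids owned by other ranks
    local_ids = (ids - lo).clamp(0, per_rank - 1)
    partial_output = shard[local_ids]
    partial_output[input_mask, :] = 0.                 # zero rows this rank does not own
    # summing partial_output across ranks (the reduce_input step) yields the full lookup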
@@ -2,7 +2,7 @@ import torch.nn.functional as F
from typing import Optional
from torch import Tensor
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
+from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
# embedding_bag_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P)
# Gather the split lookup table
pg = weight.get_process_group()
-input_tensor = input_tensor.redistribute(distspec.replicate())
+input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding_bag(input_tensor,
weight,
@@ -33,8 +33,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx)
-output_spec = ColoTensorSpec(pg, distspec.shard([-1], [weight.get_tp_world_size()]),
-ComputeSpec(ComputePattern.TP1D))
+output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if weight.compute_spec.output_replicate:
......
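
The 1Dcol embedding_bag splits the lookup table along embedding_dim, so concatenating per-shard outputs on the last dim reproduces the full result. A single-process sketch (2D input, default mean mode):

    import torch
    import torch.nn.functional as F

    table = torch.randn(10, 8)
    bags = torch.tensor([[1, 2, 4], [4, 3, 9]])    # 2D input: fixed-length bags
    parts = [F.embedding_bag(bags, w) for w in table.chunk(2, dim=-1)]
    assert torch.allclose(torch.cat(parts, dim=-1), F.embedding_bag(bags, table))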
from typing import List, Optional
import torch.nn.functional as F
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec
+from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -16,7 +16,7 @@ def colo_layernorm(
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
bias = convert_to_colo_tensor(bias, weight.get_process_group())
-input_tensor = input_tensor.redistribute(distspec.replicate())
+input_tensor = input_tensor.redistribute(ReplicaSpec())
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
......
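
layer_norm reduces over the whole normalized dimension, which is why the op above first redistributes the input to ReplicaSpec(); normalizing shards independently would use the wrong statistics. A quick single-process check:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8)
    full = F.layer_norm(x, (8,))
    # per-shard statistics differ from the global ones, hence the replicate step
    per_shard = torch.cat([F.layer_norm(h, (4,)) for h in x.chunk(2, dim=-1)], dim=-1)
    assert not torch.allclose(full, per_shard)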
@@ -3,8 +3,7 @@ from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
-from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
+from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -12,7 +11,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Reduce(Output) + bias = res
# Input:S[1]
pg = weight.get_process_group()
-input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))
+input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))
# Output:P
partial_output = F.linear(input_tensor, weight)
@@ -24,7 +23,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
output = output + bias
-output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
+output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
return output
@@ -33,13 +32,15 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Gather(Output)
# Input:B
compute_spec = weight.compute_spec
-input_tensor = input_tensor.redistribute(distspec.replicate())
+input_tensor = input_tensor.redistribute(ReplicaSpec())
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
output_parallel = F.linear(input_parallel, weight, bias)
output = ColoTensor.from_torch_tensor(output_parallel,
spec=ColoTensorSpec(weight.get_process_group(),
-distspec.shard([-1], [weight.get_tp_world_size()]),
+ShardSpec([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D)))
if compute_spec.output_replicate:
return output.to_replicate()
......
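
The two linear paths above implement the usual 1D tensor-parallel identities. A single-process sketch of both, with weight shape (out_features, in_features) as in nn.Linear:

    import torch
    import torch.nn.functional as F

    x, w = torch.randn(3, 8), torch.randn(6, 8)
    world_size = 2
    # 1Drow: input sharded on dim -1, weight sharded on its input dim; partials are all-reduced
    row = sum(F.linear(xs, ws) for xs, ws in zip(x.chunk(world_size, -1),
                                                 w.chunk(world_size, -1)))
    # 1Dcol: input replicated, weight sharded on dim 0; shard outputs are gathered (concatenated)
    col = torch.cat([F.linear(x, ws) for ws in w.chunk(world_size, 0)], dim=-1)
    full = F.linear(x, w)
    assert torch.allclose(row, full, atol=1e-5) and torch.allclose(col, full, atol=1e-5)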
@@ -7,16 +7,6 @@ class ColoModule(object):
def __init__(self):
self._shard_params: List[str] = []
-# Example:
-# {ComputePattern.TP1D:
-# 'default':
-# 'weight':
-# distspec.shard(xxxxx)
-# 'bias':
-# distspec.shard(xxxxx)
-# 'row': ...
-# 'col': ...
-# }
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
def _register_shard_params(self, params: List[str]):
......
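
For reference, the example comment deleted above described the _allowed_patterns layout; restated here with the new alias, inferred from the ColoLinear registrations below rather than taken from the commit:

    # {ComputePattern.TP1D: {
    #     'row': {'weight': ShardSpec([-1], [pg.tp_world_size()]), 'bias': None},
    #     'col': {'weight': ShardSpec([0], [pg.tp_world_size()]),
    #             'bias': ShardSpec([0], [pg.tp_world_size()])},
    # }}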
from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoEmbedding(ColoModule):
@@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
-'weight': distspec.shard([0], [pg.tp_world_size()]),
+'weight': ShardSpec([0], [pg.tp_world_size()]),
},
mode='row',
)
@@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
-'weight': distspec.shard([-1], [pg.tp_world_size()]),
+'weight': ShardSpec([-1], [pg.tp_world_size()]),
},
mode='col',
)
......
from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoLinear(ColoModule):
@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
-'weight': distspec.shard([-1], [pg.tp_world_size()]),
+'weight': ShardSpec([-1], [pg.tp_world_size()]),
'bias': None
},
mode='row',
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns(
compute_pattern=_compute_pattern,
dist_specs={
-'weight': distspec.shard([0], [pg.tp_world_size()]),
-'bias': distspec.shard([0], [pg.tp_world_size()])
+'weight': ShardSpec([0], [pg.tp_world_size()]),
+'bias': ShardSpec([0], [pg.tp_world_size()])
},
mode='col',
)
......
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
+from .distspec import shard as ShardSpec
+from .distspec import replicate as ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter
@@ -11,5 +14,5 @@ from . import distspec
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
-'ColoTensorSpec', 'TensorSpec'
+'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
]
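
Because __init__.py binds the new names directly to the existing factories, the aliases are the very same callables, not wrappers:

    from colossalai.tensor import distspec, ShardSpec, ReplicaSpec

    assert ShardSpec is distspec.shard
    assert ReplicaSpec is distspec.replicate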
......
@@ -5,7 +5,7 @@ import torch
from functools import lru_cache
from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional, Set, Callable
@@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
-spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()).
+spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
The signature of the function has to be consistent with that of __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways.
1. directly init.
>>> pg = ProcessGroup()
->>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
+>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
>>> # If initialized in a sharded model, the tensor passed in is one shard of the global tensor.
->>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
+>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
2. use static method from_torch_tensor
->>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
+>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
"""
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
@@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor):
# If not set spec, use a DP process group and replicate dist spec
if spec is None:
self.has_initialized = False
-self.dist_spec = distspec.replicate()
+self.dist_spec = ReplicaSpec()
self.compute_spec = None
self.process_group = ProcessGroup()
else:
@@ -194,13 +194,14 @@
"""to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE
"""
-self._redistribute(dist_spec=distspec.replicate())
+self._redistribute(dist_spec=ReplicaSpec())
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting dist spec of the tensor to REPLICATE
"""
-return self.redistribute(distspec.replicate())
+return self.redistribute(ReplicaSpec())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
@@ -234,7 +235,7 @@
"""
if self.is_replicate():
return super().view(*args)
-replicated_t = self.redistribute(dist_spec=distspec.replicate())
+replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None):
......
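
The view() override above falls back to a replicated copy because a shard's local payload only has the local shape. A plain-torch illustration of why reshaping a shard directly cannot work:

    import torch

    full = torch.arange(8).reshape(2, 4)
    shard = full.chunk(2, dim=0)[0]   # rank-0 payload: shape (1, 4), only 4 elements
    # shard.view(8) would raise, since only the replicated tensor holds all 8 elements;
    # hence the redistribute(dist_spec=ReplicaSpec()) before view(*args)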
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
-from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup
+from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec
from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding
......
@@ -4,7 +4,7 @@ import pytest
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor, ProcessGroup
-from colossalai.tensor import distspec
+from colossalai.tensor import ShardSpec
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
@@ -37,13 +37,13 @@ class Conv1D(nn.Module):
def init_1d_row(weight, bias, pg: ProcessGroup):
-spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
def init_1d_col(weight, bias, pg: ProcessGroup):
-spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec)
......
@@ -4,10 +4,9 @@ import torch.distributed as dist
import pytest
import colossalai
import torch.multiprocessing as mp
-from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
-from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
from functools import partial
@@ -18,10 +17,10 @@ def run():
depth = int(math.sqrt(size))
assert depth == math.sqrt(size)
x = torch.rand(8, 8).cuda()
-old_dist_spec = distspec.replicate()
-row_spec = distspec.shard([0], [size])
-col_spec = distspec.shard([-1], [size])
-mat_spec = distspec.shard([0, 1], [depth, depth])
+old_dist_spec = ReplicaSpec()
+row_spec = ShardSpec([0], [size])
+col_spec = ShardSpec([-1], [size])
+mat_spec = ShardSpec([0, 1], [depth, depth])
row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group)
assert torch.equal(x.chunk(size, 0)[rank], row_shard)
assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group))
@@ -40,8 +39,8 @@ def check_mem():
x = torch.rand(32, 32).cuda()
orig_mem = x.numel() * x.element_size()
assert torch.cuda.memory_allocated() == orig_mem
-old_dist_spec = distspec.replicate()
-row_spec = distspec.shard([0], [size])
+old_dist_spec = ReplicaSpec()
+row_spec = ShardSpec([0], [size])
x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg)
assert x.size(0) == 32 // size and x.size(1) == 32
assert torch.cuda.memory_allocated() == orig_mem // size
......
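
What the test above checks, restated in plain torch: row-sharding via chunk and gathering via cat are inverses, and a rank's shard is exactly chunk(size, 0)[rank]:

    import torch

    size, rank = 4, 1
    x = torch.rand(8, 8)
    row_shard = x.chunk(size, dim=0)[rank]                          # what _shard_as produces per rank
    assert row_shard.size(0) == 8 // size
    assert torch.equal(torch.cat(x.chunk(size, dim=0), dim=0), x)   # _gather inverts it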
import torch
-from colossalai.tensor import distspec, ColoParameter
+from colossalai.tensor import ShardSpec, ColoParameter
from torch.nn import functional as F
from functools import partial
@@ -14,7 +14,7 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_col(weight, pg: ProcessGroup):
-spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
......
import torch
-from colossalai.tensor import ColoTensor, distspec
+from colossalai.tensor import ColoTensor, ShardSpec
from torch.nn import functional as F
from functools import partial
@@ -14,13 +14,13 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight, pg: ProcessGroup):
-spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
def init_1d_col(weight, pg: ProcessGroup):
-spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
......
@@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
@@ -20,7 +20,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_spec(model, pg: ProcessGroup):
-tensor_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
@@ -28,7 +28,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
def init_1d_col_spec(model, pg: ProcessGroup):
-spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n):
......
import torch
-from colossalai.tensor import ColoTensor, distspec
+from colossalai.tensor import ColoTensor, ShardSpec
from functools import partial
@@ -15,13 +15,13 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight, bias, pg: ProcessGroup):
-spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
def init_1d_col(weight, bias, pg: ProcessGroup):
-spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec)
......
@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
from colossalai.utils import get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
-from colossalai.tensor import distspec, ComputeSpec, ComputePattern
+from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern
def check_cross_entropy():
@@ -22,7 +22,7 @@ def check_cross_entropy():
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
-input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
+input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
output = F.cross_entropy(input_t, target)
......
@@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import distspec, ColoTensorSpec, ComputePattern, \
+from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.nn.optimizer import ColoOptimizer
@@ -19,28 +19,28 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
-spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_process_group(pg)
weight.set_tensor_spec(*spec)
def init_1d_col_linear(weight, pg):
-spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_process_group(pg)
weight.set_tensor_spec(*spec)
def init_1d_row_embedding(weight, pg):
-spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_process_group(pg)
weight.set_tensor_spec(*spec)
def init_1d_col_embedding(weight, pg):
-spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_process_group(pg)
weight.set_tensor_spec(*spec)
......
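
The shard dims chosen by the four initializers above follow the parameter layouts: row-parallel linear shards the input dim, row-parallel embedding shards the vocab dim. A quick check of those layouts:

    import torch.nn as nn

    lin, emb = nn.Linear(8, 4), nn.Embedding(10, 4)
    print(lin.weight.shape)   # (4, 8)  -> init_1d_row_linear uses ShardSpec([-1], ...)
    print(emb.weight.shape)   # (10, 4) -> init_1d_row_embedding uses ShardSpec([0], ...)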