Unverified commit 9bcd2fd4, authored by Jiarui Fang and committed by GitHub
Browse files

[tensor] a shorter shard and replicate spec (#1245)

parent 2699dfbb
import torch import torch
from torch.fx.node import map_arg from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter: def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
...@@ -25,7 +24,7 @@ def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter ...@@ -25,7 +24,7 @@ def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size) pg = ProcessGroup(tp_degree=world_size)
spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
# Since you have constructed a Spec, why not directly convert the tensor to a ColoTensor? # Since you have constructed a Spec, why not directly convert the tensor to a ColoTensor?
setattr(weight, "fx_attr", spec) setattr(weight, "fx_attr", spec)
return weight return weight
......
import torch import torch
from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
from colossalai.tensor import distspec, ColoTensorSpec from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, Number, convert_to_colo_tensor from ._utils import GeneralTensor, Number, convert_to_colo_tensor
from ._utils import reduce_input, reduce_grad from ._utils import reduce_input, reduce_grad
...@@ -11,7 +11,8 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso ...@@ -11,7 +11,8 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# mat1:S[1] x mat2:S[0] = Output:P # mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res # beta * input + alpha * All-Reduce(Output) = res
mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()])) mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
# Output:P # Output:P
partial_output = torch.mm(mat1, mat2) partial_output = torch.mm(mat1, mat2)
...@@ -20,7 +21,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso ...@@ -20,7 +21,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# input # input
assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op' assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor + alpha * output output = beta * input_tensor + alpha * output
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(distspec.replicate())) output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(ReplicaSpec()))
return output return output
...@@ -28,11 +29,11 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso ...@@ -28,11 +29,11 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
alpha: Number) -> ColoTensor: alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1] # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
compute_spec = mat2.compute_spec compute_spec = mat2.compute_spec
mat1 = mat1.redistribute(distspec.replicate()) mat1 = mat1.redistribute(ReplicaSpec())
mat1 = reduce_grad(mat1, mat1.get_process_group()) mat1 = reduce_grad(mat1, mat1.get_process_group())
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha) output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
output_spec = ColoTensorSpec(input_tensor.get_process_group(), distspec.shard([-1], [mat2.get_tp_world_size()]), output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D)) ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
......
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
...@@ -14,7 +14,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, ...@@ -14,7 +14,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor: sparse: bool = False) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather the split lookup table # Gather the split lookup table
input_tensor = input_tensor.redistribute(distspec.replicate())
input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding(input_tensor, output_parallel = F.embedding(input_tensor,
weight, weight,
...@@ -23,7 +24,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, ...@@ -23,7 +24,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
norm_type=norm_type, norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq, scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse) sparse=sparse)
output_spec = ColoTensorSpec(weight.get_process_group(), distspec.shard([-1], [weight.get_tp_world_size()]), output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D)) ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
...@@ -46,7 +47,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, ...@@ -46,7 +47,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# Find index in this shard and mask those not here # Find index in this shard and mask those not here
# Reduce all # Reduce all
pg = weight.get_process_group() pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(distspec.replicate())
input_tensor = input_tensor.redistribute(ReplicaSpec())
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = weight.get_process_group().tp_local_rank() tensor_parallel_rank = weight.get_process_group().tp_local_rank()
...@@ -74,7 +76,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, ...@@ -74,7 +76,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
partial_output[input_mask, :] = 0. partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs. # Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, weight.get_process_group()) output = reduce_input(partial_output, weight.get_process_group())
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), distspec.replicate())) output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
return output return output
......
...@@ -2,7 +2,7 @@ import torch.nn.functional as F ...@@ -2,7 +2,7 @@ import torch.nn.functional as F
from typing import Optional from typing import Optional
from torch import Tensor from torch import Tensor
from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor from ._utils import GeneralTensor, convert_to_colo_tensor
...@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor, ...@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
# embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather the split lookup table # Gather the split lookup table
pg = weight.get_process_group() pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(distspec.replicate()) input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding_bag(input_tensor, output_parallel = F.embedding_bag(input_tensor,
weight, weight,
...@@ -33,8 +33,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor, ...@@ -33,8 +33,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
per_sample_weights=per_sample_weights, per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset, include_last_offset=include_last_offset,
padding_idx=padding_idx) padding_idx=padding_idx)
output_spec = ColoTensorSpec(pg, distspec.shard([-1], [weight.get_tp_world_size()]), output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if weight.compute_spec.output_replicate: if weight.compute_spec.output_replicate:
......
from typing import List, Optional from typing import List, Optional
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec
from ._utils import GeneralTensor, convert_to_colo_tensor from ._utils import GeneralTensor, convert_to_colo_tensor
...@@ -16,7 +16,7 @@ def colo_layernorm( ...@@ -16,7 +16,7 @@ def colo_layernorm(
assert isinstance(weight, ColoTensor) assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group()) input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
bias = convert_to_colo_tensor(bias, weight.get_process_group()) bias = convert_to_colo_tensor(bias, weight.get_process_group())
input_tensor = input_tensor.redistribute(distspec.replicate()) input_tensor = input_tensor.redistribute(ReplicaSpec())
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps) output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group())) output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
......
...@@ -3,8 +3,7 @@ from typing import Optional ...@@ -3,8 +3,7 @@ from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import reduce_input, reduce_grad from ._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor': def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
...@@ -12,7 +11,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option ...@@ -12,7 +11,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Reduce(Output) + bias = res # All-Reduce(Output) + bias = res
# Input:S[1] # Input:S[1]
pg = weight.get_process_group() pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()])) input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))
# Output:P # Output:P
partial_output = F.linear(input_tensor, weight) partial_output = F.linear(input_tensor, weight)
...@@ -24,7 +23,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option ...@@ -24,7 +23,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op' assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
output = output + bias output = output + bias
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate())) output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
return output return output
...@@ -33,13 +32,15 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option ...@@ -33,13 +32,15 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Gather(Output) # All-Gather(Output)
# Input:B # Input:B
compute_spec = weight.compute_spec compute_spec = weight.compute_spec
input_tensor = input_tensor.redistribute(distspec.replicate())
input_tensor = input_tensor.redistribute(ReplicaSpec())
input_parallel = reduce_grad(input_tensor, weight.get_process_group()) input_parallel = reduce_grad(input_tensor, weight.get_process_group())
output_parallel = F.linear(input_parallel, weight, bias) output_parallel = F.linear(input_parallel, weight, bias)
output = ColoTensor.from_torch_tensor(output_parallel, output = ColoTensor.from_torch_tensor(output_parallel,
spec=ColoTensorSpec(weight.get_process_group(), spec=ColoTensorSpec(weight.get_process_group(),
distspec.shard([-1], [weight.get_tp_world_size()]), ShardSpec([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))) ComputeSpec(ComputePattern.TP1D)))
if compute_spec.output_replicate: if compute_spec.output_replicate:
return output.to_replicate() return output.to_replicate()
......
...@@ -7,16 +7,6 @@ class ColoModule(object): ...@@ -7,16 +7,6 @@ class ColoModule(object):
def __init__(self): def __init__(self):
self._shard_params: List[str] = [] self._shard_params: List[str] = []
# Example:
# {ComputePattern.TP1D:
# 'default':
# 'weight':
# distspec.shard(xxxxx)
# 'bias':
# distspec.shard(xxxxx)
# 'row': ...
# 'col': ...
# }
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {} self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
def _register_shard_params(self, params: List[str]): def _register_shard_params(self, params: List[str]):
......
from .colo_module import ColoModule from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec, ProcessGroup from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
class ColoEmbedding(ColoModule): class ColoEmbedding(ColoModule):
...@@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule): ...@@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns( self._register_allowed_patterns(
compute_pattern=_compute_pattern, compute_pattern=_compute_pattern,
dist_specs={ dist_specs={
'weight': distspec.shard([0], [pg.tp_world_size()]), 'weight': ShardSpec([0], [pg.tp_world_size()]),
}, },
mode='row', mode='row',
) )
...@@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule): ...@@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule):
self._register_allowed_patterns( self._register_allowed_patterns(
compute_pattern=_compute_pattern, compute_pattern=_compute_pattern,
dist_specs={ dist_specs={
'weight': distspec.shard([-1], [pg.tp_world_size()]), 'weight': ShardSpec([-1], [pg.tp_world_size()]),
}, },
mode='col', mode='col',
) )
......
from .colo_module import ColoModule from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec, ProcessGroup from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoLinear(ColoModule): class ColoLinear(ColoModule):
...@@ -19,7 +19,7 @@ class ColoLinear(ColoModule): ...@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns( self._register_allowed_patterns(
compute_pattern=_compute_pattern, compute_pattern=_compute_pattern,
dist_specs={ dist_specs={
'weight': distspec.shard([-1], [pg.tp_world_size()]), 'weight': ShardSpec([-1], [pg.tp_world_size()]),
'bias': None 'bias': None
}, },
mode='row', mode='row',
...@@ -29,8 +29,8 @@ class ColoLinear(ColoModule): ...@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
self._register_allowed_patterns( self._register_allowed_patterns(
compute_pattern=_compute_pattern, compute_pattern=_compute_pattern,
dist_specs={ dist_specs={
'weight': distspec.shard([0], [pg.tp_world_size()]), 'weight': ShardSpec([0], [pg.tp_world_size()]),
'bias': distspec.shard([0], [pg.tp_world_size()]) 'bias': ShardSpec([0], [pg.tp_world_size()])
}, },
mode='col', mode='col',
) )
......
from .process_group import ProcessGroup from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec from .tensor_spec import ColoTensorSpec
from .distspec import shard as ShardSpec
from .distspec import replicate as ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter from .colo_parameter import ColoParameter
...@@ -11,5 +14,5 @@ from . import distspec ...@@ -11,5 +14,5 @@ from . import distspec
__all__ = [ __all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter', 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
'ColoTensorSpec', 'TensorSpec' 'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
] ]
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from functools import lru_cache from functools import lru_cache
from colossalai.tensor import ColoTensorSpec from colossalai.tensor import ColoTensorSpec
from colossalai.tensor import distspec, ProcessGroup from colossalai.tensor import ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional, Set, Callable from typing import Optional, Set, Callable
...@@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor): ...@@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor. """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
Args: Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor. data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()). spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
The signature of the function has to be consistent with the __new__ except for the 1st arg. The signature of the function has to be consistent with the __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways. The class should be initialized with a torch tensor in the following ways.
1. directly init. 1. directly init.
>>> pg = ProcessGroup() >>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())) >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
>>> # If initialized in a shard model, the tensor passed in is one shard of the global tensor. >>> # If initialized in a shard model, the tensor passed in is one shard of the global tensor.
>>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size), >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0], >>> dims=[0],
>>> num_partitions=[world_size]) >>> num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(pg, shard_spec) >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
2. use static method from_torch_tensor 2. use static method from_torch_tensor
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())) >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
""" """
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor': def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
...@@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor): ...@@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor):
# If not set spec, use a DP process group and replicate dist spec # If not set spec, use a DP process group and replicate dist spec
if spec is None: if spec is None:
self.has_initialized = False self.has_initialized = False
self.dist_spec = distspec.replicate() self.dist_spec = ReplicaSpec()
self.compute_spec = None self.compute_spec = None
self.process_group = ProcessGroup() self.process_group = ProcessGroup()
else: else:
...@@ -194,13 +194,14 @@ class ColoTensor(torch.Tensor): ...@@ -194,13 +194,14 @@ class ColoTensor(torch.Tensor):
"""to_replicate_ """to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE an inline member function, converting dist spec of the tensor to REPLICATE
""" """
self._redistribute(dist_spec=distspec.replicate()) self._redistribute(dist_spec=ReplicaSpec())
def to_replicate(self) -> 'ColoTensor': def to_replicate(self) -> 'ColoTensor':
"""to_replicate """to_replicate
converting dist spec of the tensor to REPLICATE converting dist spec of the tensor to REPLICATE
""" """
return self.redistribute(distspec.replicate()) return self.redistribute(ReplicaSpec())
@staticmethod @staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor': def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
...@@ -234,7 +235,7 @@ class ColoTensor(torch.Tensor): ...@@ -234,7 +235,7 @@ class ColoTensor(torch.Tensor):
""" """
if self.is_replicate(): if self.is_replicate():
return super().view(*args) return super().view(*args)
replicated_t = self.redistribute(dist_spec=distspec.replicate()) replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return replicated_t.view(*args) return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None): def size_global(self, args: Optional[int] = None):
......
from .utils import InsertPostInitMethodToModuleSubClasses from .utils import InsertPostInitMethodToModuleSubClasses
import torch import torch
from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec
from colossalai.nn.parallel.layers import register_colo_module, \ from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding ColoLinear, ColoEmbedding
......
...@@ -4,7 +4,7 @@ import pytest ...@@ -4,7 +4,7 @@ import pytest
import torch.nn as nn import torch.nn as nn
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor, ProcessGroup from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor import distspec from colossalai.tensor import ShardSpec
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
...@@ -37,13 +37,13 @@ class Conv1D(nn.Module): ...@@ -37,13 +37,13 @@ class Conv1D(nn.Module):
def init_1d_row(weight, bias, pg: ProcessGroup): def init_1d_row(weight, bias, pg: ProcessGroup):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
def init_1d_col(weight, bias, pg: ProcessGroup): def init_1d_col(weight, bias, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec) bias.set_tensor_spec(*spec)
......
...@@ -4,10 +4,9 @@ import torch.distributed as dist ...@@ -4,10 +4,9 @@ import torch.distributed as dist
import pytest import pytest
import colossalai import colossalai
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.tensor import DistSpecManager, distspec, ProcessGroup from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
from functools import partial from functools import partial
...@@ -18,10 +17,10 @@ def run(): ...@@ -18,10 +17,10 @@ def run():
depth = int(math.sqrt(size)) depth = int(math.sqrt(size))
assert depth == math.sqrt(size) assert depth == math.sqrt(size)
x = torch.rand(8, 8).cuda() x = torch.rand(8, 8).cuda()
old_dist_spec = distspec.replicate() old_dist_spec = ReplicaSpec()
row_spec = distspec.shard([0], [size]) row_spec = ShardSpec([0], [size])
col_spec = distspec.shard([-1], [size]) col_spec = ShardSpec([-1], [size])
mat_spec = distspec.shard([0, 1], [depth, depth]) mat_spec = ShardSpec([0, 1], [depth, depth])
row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group) row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group)
assert torch.equal(x.chunk(size, 0)[rank], row_shard) assert torch.equal(x.chunk(size, 0)[rank], row_shard)
assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group)) assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group))
...@@ -40,8 +39,8 @@ def check_mem(): ...@@ -40,8 +39,8 @@ def check_mem():
x = torch.rand(32, 32).cuda() x = torch.rand(32, 32).cuda()
orig_mem = x.numel() * x.element_size() orig_mem = x.numel() * x.element_size()
assert torch.cuda.memory_allocated() == orig_mem assert torch.cuda.memory_allocated() == orig_mem
old_dist_spec = distspec.replicate() old_dist_spec = ReplicaSpec()
row_spec = distspec.shard([0], [size]) row_spec = ShardSpec([0], [size])
x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg) x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg)
assert x.size(0) == 32 // size and x.size(1) == 32 assert x.size(0) == 32 // size and x.size(1) == 32
assert torch.cuda.memory_allocated() == orig_mem // size assert torch.cuda.memory_allocated() == orig_mem // size
......
import torch import torch
from colossalai.tensor import distspec, ColoParameter from colossalai.tensor import ShardSpec, ColoParameter
from torch.nn import functional as F from torch.nn import functional as F
from functools import partial from functools import partial
...@@ -14,7 +14,7 @@ from _utils import tensor_equal, tensor_shard_equal ...@@ -14,7 +14,7 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_col(weight, pg: ProcessGroup): def init_1d_col(weight, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
......
import torch import torch
from colossalai.tensor import ColoTensor, distspec from colossalai.tensor import ColoTensor, ShardSpec
from torch.nn import functional as F from torch.nn import functional as F
from functools import partial from functools import partial
...@@ -14,13 +14,13 @@ from _utils import tensor_equal, tensor_shard_equal ...@@ -14,13 +14,13 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight, pg: ProcessGroup): def init_1d_row(weight, pg: ProcessGroup):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
def init_1d_col(weight, pg: ProcessGroup): def init_1d_col(weight, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
......
...@@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use ...@@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
...@@ -20,7 +20,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs ...@@ -20,7 +20,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_spec(model, pg: ProcessGroup): def init_1d_row_spec(model, pg: ProcessGroup):
tensor_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n: if 'weight' in n and 'ln' not in n:
...@@ -28,7 +28,7 @@ def init_1d_row_spec(model, pg: ProcessGroup): ...@@ -28,7 +28,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
def init_1d_col_spec(model, pg: ProcessGroup): def init_1d_col_spec(model, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n): if 'ln' not in n and ('weight' in n or 'bias' in n):
......
import torch import torch
from colossalai.tensor import ColoTensor, distspec from colossalai.tensor import ColoTensor, ShardSpec
from functools import partial from functools import partial
...@@ -15,13 +15,13 @@ from _utils import tensor_equal, tensor_shard_equal ...@@ -15,13 +15,13 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight, bias, pg: ProcessGroup): def init_1d_row(weight, bias, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
def init_1d_col(weight, bias, pg: ProcessGroup): def init_1d_col(weight, bias, pg: ProcessGroup):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec) bias.set_tensor_spec(*spec)
......
...@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec ...@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.tensor import distspec, ComputeSpec, ComputePattern from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern
def check_cross_entropy(): def check_cross_entropy():
...@@ -22,7 +22,7 @@ def check_cross_entropy(): ...@@ -22,7 +22,7 @@ def check_cross_entropy():
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size) pg = ProcessGroup(tp_degree=world_size)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()])) input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
output = F.cross_entropy(input_t, target) output = F.cross_entropy(input_t, target)
......
...@@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use ...@@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, ColoTensorSpec, ComputePattern, \ from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.nn.optimizer import ColoOptimizer from colossalai.nn.optimizer import ColoOptimizer
...@@ -19,28 +19,28 @@ from tests.components_to_test.registry import non_distributed_component_funcs ...@@ -19,28 +19,28 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup): def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_process_group(pg) weight.set_process_group(pg)
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
def init_1d_col_linear(weight, pg): def init_1d_col_linear(weight, pg):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_process_group(pg) weight.set_process_group(pg)
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
def init_1d_row_embedding(weight, pg): def init_1d_row_embedding(weight, pg):
spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_process_group(pg) weight.set_process_group(pg)
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
def init_1d_col_embedding(weight, pg): def init_1d_col_embedding(weight, pg):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_process_group(pg) weight.set_process_group(pg)
weight.set_tensor_spec(*spec) weight.set_tensor_spec(*spec)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment