Unverified commit b5f9e37c authored by Hongxin Liu, committed by GitHub

[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
parent 32e7f994
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.legacy.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.registry import GRADIENT_HANDLER
 from ._base_gradient_handler import BaseGradientHandler
...
...
@@ -7,11 +7,11 @@ from typing import Callable, List, Tuple, Union
 import torch.cuda

 import colossalai.legacy.communication as comm
-from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.amp.naive_amp import NaiveAMPModel
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
+from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.logging import get_dist_logger
-from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device

 from ._base_schedule import BaseSchedule
...
@@ -157,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
         return self._move_to_device(micro_batch_data)

     def pre_processing(self, engine):
-        from colossalai.zero.legacy import ShardedModelV2
+        from colossalai.legacy.zero import ShardedModelV2

         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
...
...
@@ -6,8 +6,8 @@ from typing import Iterable, Tuple
 import torch.cuda

 import colossalai.legacy.communication.p2p_v2 as comm
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context.parallel_mode import ParallelMode
+from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.engine import Engine
 from colossalai.utils.cuda import get_current_device
...
+from ._ops import *
 from .layer import *
 from .loss import *
 from .metric import *
# _ops/__init__.py

from ._utils import *
from .addmm import colo_addmm
from .batch_norm import colo_batch_norm
from .element_wise import *
from .embedding import colo_embedding
from .embedding_bag import colo_embedding_bag
from .layernorm import colo_layernorm
from .linear import colo_linear
from .loss import colo_cross_entropy
from .view import colo_view
# _ops/_utils.py

...
@@ -3,9 +3,10 @@ from typing import List, Optional, Union
 import torch
 import torch.distributed as dist

-from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.global_variables import tensor_parallel_env as env
 from colossalai.legacy.nn.layer.utils import divide
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup
+from colossalai.tensor import ColoTensor

 GeneralTensor = Union[ColoTensor, torch.Tensor]
 Number = Union[int, float]
...
# _ops/addmm.py

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), mat2.get_process_group())
# Output:P
partial_output = torch.mm(mat1, mat2)
# Reduce(Output)
output = reduce_input(partial_output, mat2.get_process_group())
# input
assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor + alpha * output
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(input_tensor.get_process_group()))
return output
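
# Worked shapes for the row path above (an illustrative note, assuming tp world
# size P): mat1 (m, k) is re-sharded to a local (m, k/P), mat2 locally holds
# (k/P, n), so each rank's torch.mm yields a partial (m, n) and the all-reduce
# inside reduce_input sums the partials into the full result.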
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
compute_spec = mat2.compute_spec
mat1 = mat1.redistribute(ReplicaSpec())
mat1 = reduce_grad(mat1, mat1.get_process_group())
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if compute_spec.output_replicate:
return output.to_replicate()
else:
return output
def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
assert mode in ('row', 'col')
funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol}
return funcs[mode](input_tensor, mat1, mat2, beta, alpha)
@colo_op_impl(torch.addmm)
def colo_addmm(input_tensor: GeneralTensor,
               mat1: ColoTensor,
               mat2: ColoTensor,
               beta: Number = 1,
               alpha: Number = 1,
               **kwargs) -> ColoTensor:
    """Handles ``__torch_function__`` dispatch for ``torch.addmm``.
    Computes ``beta * input + alpha * (mat1 @ mat2)``, inserting the
    tensor-parallel communication required by the sharding of ``mat2``.
    """
    # At least one of the tensors should be a ColoTensor
assert isinstance(mat2, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, mat2.get_process_group())
mat1 = convert_to_colo_tensor(mat1, mat2.get_process_group())
# Add communication logic before and after linear call.
ret_tensor = None
if not mat2.has_compute_spec(): # No Model Parallel Applied
assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor,
mat1,
mat2,
beta=beta,
alpha=alpha,
                                                                  **kwargs),
spec=ColoTensorSpec(mat2.get_process_group()))
elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.is_shard_1drow() and input_tensor.is_replicate():
mode = 'row'
elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()):
mode = 'col'
else:
raise NotImplementedError
ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha)
else:
raise NotImplementedError
return ret_tensor
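
# Minimal dispatch sketch (illustrative only, not part of this module; assumes a
# distributed environment launched via colossalai with a 1D tensor-parallel
# process group):
#
#   import torch
#   import torch.distributed as dist
#   from colossalai.tensor import (ColoTensor, ColoTensorSpec, ComputePattern,
#                                  ComputeSpec, ProcessGroup, ShardSpec)
#
#   pg = ProcessGroup(tp_degree=dist.get_world_size())
#   spec = ColoTensorSpec(pg, ShardSpec([-1], [pg.tp_world_size()]),
#                         ComputeSpec(ComputePattern.TP1D))
#   mat2 = ColoTensor.from_torch_tensor(torch.randn(16, 16), spec=spec)  # column-sharded
#   # torch.addmm on ColoTensor arguments is intercepted by @colo_op_impl and routed here
#   out = torch.addmm(torch.randn(16), torch.randn(4, 16), mat2)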
# _ops/batch_norm.py

from typing import Optional
import torch.nn.functional as F
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor
@colo_op_impl(F.batch_norm)
def colo_batch_norm(
input: GeneralTensor,
running_mean: Optional[GeneralTensor],
running_var: Optional[GeneralTensor],
weight: Optional[GeneralTensor] = None,
bias: Optional[GeneralTensor] = None,
training: bool = False,
momentum: float = 0.1,
eps: float = 1e-5,
):
assert isinstance(weight, ColoTensor)
    # the running statistics are concrete tensors at this point; detach them so
    # autograd does not track the in-place updates of the running stats
    running_mean = running_mean.detach()
    running_var = running_var.detach()
input = convert_to_colo_tensor(input, weight.get_process_group())
bias = convert_to_colo_tensor(bias, weight.get_process_group())
input = input.redistribute(ReplicaSpec())
bias = bias.redistribute(ReplicaSpec())
output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group()))
return output
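
# Note: this op always evaluates batch norm on replicated activations; a
# tensor-parallel sharded input is first gathered by redistribute(ReplicaSpec()),
# so the dispatch restores correctness at the cost of communication rather than
# sharding the normalization itself.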
# _ops/element_wise.py

import torch
import torch.nn.functional as F
from torch import Tensor
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor
def register_elementwise_op(op):
@colo_op_impl(op)
def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
"""
Handles ``__torch_function__`` dispatch for the elementwise op such
as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
This method computes on either a normal tensor or a sharded tensor.
"""
if 'inplace' in kwargs:
# TODO(jiaruifang) inplace will cause bugs
input_tensor = input_tensor.clone()
return op(input_tensor, *args, **kwargs)
        else:
            output = op(input_tensor, *args, **kwargs)
            if isinstance(input_tensor, ColoTensor):
                if isinstance(output, str):
                    return output
                if not isinstance(output, torch.Tensor):
                    raise NotImplementedError
                return ColoTensor.from_torch_tensor(output,
                                                    spec=ColoTensorSpec(input_tensor.get_process_group(),
                                                                        dist_attr=input_tensor.dist_spec))
            # plain torch.Tensor input: return the raw output unchanged
            return output
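
# Illustrative effect of the wrapper above (any registered op behaves alike,
# assuming `pg` is an initialized ProcessGroup): calling a plain torch function
# on a ColoTensor returns a ColoTensor that keeps the process group and dist
# spec instead of silently decaying to a torch.Tensor.
#
#   >>> t = ColoTensor.from_torch_tensor(torch.randn(4), spec=ColoTensorSpec(pg))
#   >>> isinstance(torch.cos(t), ColoTensor)
#   True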
# @colo_op_impl(torch.relu_)
# def elementwise_op(input_tensor):
# torch.relu_(input_tensor.data)
# return input_tensor
# @colo_op_impl(Tensor.add_)
# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
# input_tensor = input_tensor.data.add_(*args, **kwargs)
# return input_tensor
# Tensor op
register_elementwise_op(Tensor.abs)
register_elementwise_op(Tensor.absolute)
register_elementwise_op(Tensor.acos)
register_elementwise_op(Tensor.arccos)
register_elementwise_op(Tensor.angle)
register_elementwise_op(Tensor.asin)
register_elementwise_op(Tensor.arcsin)
register_elementwise_op(Tensor.atan)
register_elementwise_op(Tensor.arctan)
register_elementwise_op(Tensor.all)
register_elementwise_op(Tensor.any)
register_elementwise_op(Tensor.bernoulli)
register_elementwise_op(Tensor.bfloat16)
register_elementwise_op(Tensor.bitwise_not)
register_elementwise_op(Tensor.bool)
register_elementwise_op(Tensor.byte)
register_elementwise_op(Tensor.ceil)
register_elementwise_op(Tensor.char)
register_elementwise_op(Tensor.clamp)
register_elementwise_op(Tensor.clamp_max)
register_elementwise_op(Tensor.clamp_min)
register_elementwise_op(Tensor.clip)
register_elementwise_op(Tensor.clone)
register_elementwise_op(Tensor.contiguous)
register_elementwise_op(Tensor.copysign)
register_elementwise_op(Tensor.cos)
register_elementwise_op(Tensor.cosh)
register_elementwise_op(Tensor.acosh)
register_elementwise_op(Tensor.arccosh)
register_elementwise_op(Tensor.cpu)
register_elementwise_op(Tensor.cuda)
register_elementwise_op(Tensor.deg2rad)
register_elementwise_op(Tensor.detach)
register_elementwise_op(Tensor.digamma)
register_elementwise_op(Tensor.double)
register_elementwise_op(Tensor.erf)
register_elementwise_op(Tensor.erfc)
register_elementwise_op(Tensor.erfinv)
register_elementwise_op(Tensor.exp)
register_elementwise_op(Tensor.expm1)
register_elementwise_op(Tensor.fix)
register_elementwise_op(Tensor.trunc)
register_elementwise_op(Tensor.float)
register_elementwise_op(Tensor.float_power)
register_elementwise_op(Tensor.floor)
register_elementwise_op(Tensor.frac)
register_elementwise_op(Tensor.half)
register_elementwise_op(Tensor.hardshrink)
register_elementwise_op(Tensor.heaviside)
register_elementwise_op(Tensor.i0)
register_elementwise_op(Tensor.int)
register_elementwise_op(Tensor.isfinite)
register_elementwise_op(Tensor.isinf)
register_elementwise_op(Tensor.isposinf)
register_elementwise_op(Tensor.isneginf)
register_elementwise_op(Tensor.isnan)
register_elementwise_op(Tensor.lgamma)
register_elementwise_op(Tensor.log)
register_elementwise_op(Tensor.log10)
register_elementwise_op(Tensor.log1p)
register_elementwise_op(Tensor.log2)
register_elementwise_op(Tensor.logical_not)
register_elementwise_op(Tensor.logit)
register_elementwise_op(Tensor.long)
register_elementwise_op(Tensor.nan_to_num)
register_elementwise_op(Tensor.neg)
register_elementwise_op(Tensor.negative)
register_elementwise_op(Tensor.positive)
register_elementwise_op(Tensor.pow)
register_elementwise_op(Tensor.rad2deg)
register_elementwise_op(Tensor.reciprocal)
register_elementwise_op(Tensor.round)
register_elementwise_op(Tensor.rsqrt)
register_elementwise_op(Tensor.short)
register_elementwise_op(Tensor.sigmoid)
register_elementwise_op(Tensor.sign)
register_elementwise_op(Tensor.signbit)
register_elementwise_op(Tensor.sgn)
register_elementwise_op(Tensor.sin)
register_elementwise_op(Tensor.sinc)
register_elementwise_op(Tensor.sinh)
register_elementwise_op(Tensor.asinh)
register_elementwise_op(Tensor.arcsinh)
register_elementwise_op(Tensor.sqrt)
register_elementwise_op(Tensor.square)
register_elementwise_op(Tensor.to)
register_elementwise_op(Tensor.tan)
register_elementwise_op(Tensor.tanh)
register_elementwise_op(Tensor.atanh)
register_elementwise_op(Tensor.arctanh)
register_elementwise_op(Tensor.type)
register_elementwise_op(Tensor.type_as)
# torch OP
register_elementwise_op(torch.abs)
register_elementwise_op(torch.absolute)
register_elementwise_op(torch.acos)
register_elementwise_op(torch.arccos)
register_elementwise_op(torch.angle)
register_elementwise_op(torch.asin)
register_elementwise_op(torch.arcsin)
register_elementwise_op(torch.atan)
register_elementwise_op(torch.arctan)
register_elementwise_op(torch.all)
register_elementwise_op(torch.any)
register_elementwise_op(torch.bernoulli)
register_elementwise_op(torch.bitwise_not)
register_elementwise_op(torch.ceil)
register_elementwise_op(torch.clamp)
register_elementwise_op(torch.clamp_max)
register_elementwise_op(torch.clamp_min)
register_elementwise_op(torch.clip)
register_elementwise_op(torch.clone)
register_elementwise_op(torch.copysign)
register_elementwise_op(torch.cos)
register_elementwise_op(torch.cosh)
register_elementwise_op(torch.acosh)
register_elementwise_op(torch.arccosh)
register_elementwise_op(torch.deg2rad)
register_elementwise_op(torch.digamma)
register_elementwise_op(torch.erf)
register_elementwise_op(torch.erfc)
register_elementwise_op(torch.erfinv)
register_elementwise_op(torch.exp)
register_elementwise_op(torch.expm1)
register_elementwise_op(torch.fix)
register_elementwise_op(torch.trunc)
register_elementwise_op(torch.float_power)
register_elementwise_op(torch.floor)
register_elementwise_op(torch.frac)
register_elementwise_op(torch.hardshrink)
register_elementwise_op(torch.heaviside)
register_elementwise_op(torch.i0)
register_elementwise_op(torch.isfinite)
register_elementwise_op(torch.isinf)
register_elementwise_op(torch.isposinf)
register_elementwise_op(torch.isneginf)
register_elementwise_op(torch.isnan)
register_elementwise_op(torch.lgamma)
register_elementwise_op(torch.log)
register_elementwise_op(torch.log10)
register_elementwise_op(torch.log1p)
register_elementwise_op(torch.log2)
register_elementwise_op(torch.logical_not)
register_elementwise_op(torch.logit)
register_elementwise_op(torch.nan_to_num)
register_elementwise_op(torch.neg)
register_elementwise_op(torch.negative)
register_elementwise_op(torch.positive)
register_elementwise_op(torch.pow)
register_elementwise_op(torch.rad2deg)
register_elementwise_op(torch.reciprocal)
register_elementwise_op(torch.round)
register_elementwise_op(torch.rsqrt)
register_elementwise_op(torch.sigmoid)
register_elementwise_op(torch.sign)
register_elementwise_op(torch.signbit)
register_elementwise_op(torch.sgn)
register_elementwise_op(torch.sin)
register_elementwise_op(torch.sinc)
register_elementwise_op(torch.sinh)
register_elementwise_op(torch.asinh)
register_elementwise_op(torch.arcsinh)
register_elementwise_op(torch.sqrt)
register_elementwise_op(torch.square)
register_elementwise_op(torch.tan)
register_elementwise_op(torch.tanh)
register_elementwise_op(torch.atanh)
register_elementwise_op(torch.arctanh)
register_elementwise_op(torch.zeros_like)
# nn.functional OP
register_elementwise_op(F.threshold)
register_elementwise_op(F.relu)
register_elementwise_op(F.hardtanh)
register_elementwise_op(F.hardswish)
register_elementwise_op(F.relu6)
register_elementwise_op(F.elu)
register_elementwise_op(F.selu)
register_elementwise_op(F.celu)
register_elementwise_op(F.leaky_relu)
register_elementwise_op(F.prelu)
register_elementwise_op(F.rrelu)
register_elementwise_op(F.gelu)
register_elementwise_op(F.logsigmoid)
register_elementwise_op(F.hardshrink)
register_elementwise_op(F.tanhshrink)
register_elementwise_op(F.softsign)
register_elementwise_op(F.softplus)
register_elementwise_op(F.softmin)
register_elementwise_op(F.softmax)
register_elementwise_op(F.softshrink)
register_elementwise_op(F.gumbel_softmax)
register_elementwise_op(F.log_softmax)
register_elementwise_op(F.tanh)
register_elementwise_op(F.sigmoid)
register_elementwise_op(F.hardsigmoid)
register_elementwise_op(F.silu)
register_elementwise_op(F.mish)
# TODO(ver217): dropout handles seed
register_elementwise_op(F.dropout)
register_elementwise_op(F.alpha_dropout)
register_elementwise_op(F.feature_alpha_dropout)
# _ops/embedding.py

from typing import Optional
import torch.nn.functional as F
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
def colo_embedding_1Dcol(input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
    # embedding_1Dcol splits the weight (lookup table) into shape (num_embeddings, embedding_dim/P)
    # gather the split lookup table
input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
compute_spec = weight.compute_spec
if compute_spec.output_replicate:
return output.to_replicate()
else:
return output
def colo_embedding_1Drow(input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
    # embedding_1Drow splits the weight (lookup table) into shape [num_embeddings/P, embedding_dim]
    # get the index range of the current segment and mask ids owned by other segments with 0
    # obtain the complete input tensor through all-gather
input_tensor = input_tensor.redistribute(ReplicaSpec())
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = weight.get_process_group().tp_local_rank()
num_embeddings_per_partition = weight.size_local(0)
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition
# build the mask.
input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
# mask the input.
# TODO(jzy) masked_input may be an activation managed by ColoTensor.
masked_input = input_tensor - vocab_start_index
masked_input[input_mask] = 0
partial_output = F.embedding(masked_input,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
# Mask the output embedding.
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, weight.get_process_group())
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
return output
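
# Worked example of the masking logic above (illustrative; runnable on a single
# process since no process group is involved):
#
#   >>> import torch
#   >>> input_tensor = torch.tensor([1, 5, 7])     # global vocab ids
#   >>> vocab_start_index, vocab_end_index = 4, 8  # rank 1 of 2, num_embeddings=8
#   >>> input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
#   >>> masked_input = input_tensor - vocab_start_index
#   >>> masked_input[input_mask] = 0
#   >>> masked_input
#   tensor([0, 1, 3])
#
# Rows fetched for out-of-range ids are zeroed after the lookup, and the
# all-reduce in reduce_input fills them in from the owning ranks.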
def colo_embedding_1d(mode: str,
input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
assert mode in ('row', 'col')
funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol}
return funcs[mode](input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
@colo_op_impl(F.embedding)
def colo_embedding(input_tensor: GeneralTensor,
weight: GeneralTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
This method looks up an embedding table.
"""
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse),
spec=ColoTensorSpec(weight.get_process_group()))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1drow():
mode = 'row'
elif weight.is_shard_1dcol():
mode = 'col'
else:
raise NotImplementedError
return colo_embedding_1d(mode,
input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
else:
raise NotImplementedError
# _ops/embedding_bag.py

from typing import Optional
import torch.nn.functional as F
from torch import Tensor
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor
def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
weight: ColoTensor,
offsets: Optional[Tensor] = None,
max_norm: Optional[float] = None,
norm_type: float = 2,
scale_grad_by_freq: bool = False,
mode: str = "mean",
sparse: bool = False,
per_sample_weights: Optional[Tensor] = None,
include_last_offset: bool = False,
padding_idx: Optional[int] = None) -> ColoTensor:
    # embedding_bag_1Dcol splits the weight (lookup table) into shape (num_embeddings, embedding_dim/P)
    # gather the split lookup table
pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(ReplicaSpec())
output_parallel = F.embedding_bag(input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx)
output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
if weight.compute_spec.output_replicate:
return output.to_replicate()
else:
return output
def colo_embedding_bag_1d(tp_mode: str,
input_tensor: ColoTensor,
weight: ColoTensor,
offsets: Optional[Tensor] = None,
max_norm: Optional[float] = None,
norm_type: float = 2,
scale_grad_by_freq: bool = False,
mode: str = "mean",
sparse: bool = False,
per_sample_weights: Optional[Tensor] = None,
include_last_offset: bool = False,
padding_idx: Optional[int] = None) -> ColoTensor:
assert tp_mode in ('col',)
funcs = {'col': colo_embedding_bag_1Dcol}
return funcs[tp_mode](input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx)
@colo_op_impl(F.embedding_bag)
def colo_embedding_bag(input_tensor: GeneralTensor,
weight: GeneralTensor,
offsets: Optional[Tensor] = None,
max_norm: Optional[float] = None,
norm_type: float = 2,
scale_grad_by_freq: bool = False,
mode: str = "mean",
sparse: bool = False,
per_sample_weights: Optional[Tensor] = None,
include_last_offset: bool = False,
padding_idx: Optional[int] = None):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
This method looks up an embedding table.
"""
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
# Handle different parallel actions.
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx),
spec=ColoTensorSpec(weight.get_process_group()))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1dcol():
tp_mode = 'col'
else:
raise NotImplementedError
return colo_embedding_bag_1d(tp_mode,
input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx)
else:
raise NotImplementedError
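
# Note: only the column-sharded layout is implemented for embedding_bag; a
# row-sharded weight falls through to NotImplementedError above. Illustratively,
# a weight carrying ShardSpec([-1], [tp_size]) with ComputePattern.TP1D is what
# routes to colo_embedding_bag_1Dcol.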
# _ops/layernorm.py

from typing import List, Optional
import torch.nn.functional as F
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor
@colo_op_impl(F.layer_norm)
def colo_layernorm(
input_tensor: GeneralTensor,
normalized_shape: List[int],
weight: Optional[GeneralTensor] = None,
bias: Optional[GeneralTensor] = None,
eps: float = 1e-5,
):
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
bias = convert_to_colo_tensor(bias, weight.get_process_group())
input_tensor = input_tensor.redistribute(ReplicaSpec())
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(tensor=output,
spec=ColoTensorSpec(pg=input_tensor.get_process_group(),
dist_attr=input_tensor.dist_spec))
return output
# _ops/linear.py

from copy import deepcopy
from typing import Optional
import torch.nn.functional as F
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor.sharding_spec import ShardingSpec
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_grad, reduce_input
def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
pg = weight.get_process_group()
input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]), pg)
# Output:P
partial_output = F.linear(input_tensor, weight)
# Reduce(Output)
output = reduce_input(partial_output, pg)
# Bias
if bias is not None:
assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
output = output + bias
output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
return output
def colo_linear_1dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
compute_spec = weight.compute_spec
input_tensor = input_tensor.redistribute(ReplicaSpec())
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
output_parallel = F.linear(input_parallel, weight, bias)
output = ColoTensor.from_torch_tensor(output_parallel,
spec=ColoTensorSpec(weight.get_process_group(),
ShardSpec([-1], [weight.get_tp_world_size()]),
ComputeSpec(ComputePattern.TP1D)))
if compute_spec.output_replicate:
return output.to_replicate()
else:
return output
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
assert mode in ('row', 'col')
funcs = {'row': colo_linear_1drow, 'col': colo_linear_1dcol}
return funcs[mode](input_tensor, weight, bias)
# @register_colo_graph(input_pos=[1], param_pos=[2, 3])
def colo_linear_imp(input_tensor: GeneralTensor,
weight: GeneralTensor,
bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
assert isinstance(weight, ColoTensor)
pg = weight.get_process_group()
assert pg
input_tensor = convert_to_colo_tensor(input_tensor, pg)
bias = convert_to_colo_tensor(bias, pg)
# input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
# Add communication logic before and after linear call.
ret_tensor = None
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native Linear op'
assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op'
ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg))
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()):
mode = 'row'
elif weight.is_shard_1drow() and (bias is None or bias.is_shard_1drow() or bias.is_shard_1dcol()):
mode = 'col'
else:
raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight}, bias {bias}")
ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
else:
raise NotImplementedError
return ret_tensor
def _new_colo_linear_imp(input_tensor: GeneralTensor,
weight: GeneralTensor,
bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
"""
A tentative function to compute the distributed linear layer with the latest sharding spec.
This function is subject to future change as the current sharding API is not stable.
"""
# get mesh info
input_sharding_seq = input_tensor.sharding_spec.sharding_sequence
weight_sharding_seq = weight.sharding_spec.sharding_sequence
if bias is not None:
bias_sharding_seq = bias.sharding_spec.sharding_sequence
device_mesh = weight.sharding_spec.device_mesh
pg_axis0 = weight.pg_axis0
pg_axis1 = weight.pg_axis1
    # the last dim of the input must have the same spec as the in_features dim of
    # the weight; the weight is stored transposed, so that is its second dimension
assert input_sharding_seq[-1] == weight_sharding_seq[1]
if bias is not None:
assert bias_sharding_seq[0] == weight_sharding_seq[0]
    # compute the output sharding sequence;
    # the weight is transposed, so out_features is its first dimension
output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1]
output_shard_seq = deepcopy(output_shard_seq)
# TODO: add reduce grad logic
# handle column and row parallel linear
# by reusing the implementation above
out = F.linear(input_tensor, weight)
# run all reduce if necessary
last_dim_spec = input_sharding_seq[-1]
if last_dim_spec.is_replica:
pass
elif last_dim_spec.shard_list is not None:
for dim in last_dim_spec.shard_list:
if dim == 0:
reduce_input(out, pg_axis0)
elif dim == 1:
reduce_input(out, pg_axis1)
else:
raise RuntimeError("Found invalid sharding axis {dim}, only 0 or 1 is expected")
# add bias
if bias is not None:
out += bias
# convert shard seq to partition dict
output_partition_dict = {}
for index, dim_spec in enumerate(output_shard_seq):
if not dim_spec.is_replica:
if index not in output_partition_dict:
output_partition_dict[index] = []
output_partition_dict[index].extend(dim_spec.shard_list)
entire_shape = out.shape
output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict)
ret_tensor = ColoTensor.from_torch_tensor(out)
setattr(ret_tensor, 'sharding_spec', output_sharding_spec)
return ret_tensor
def _has_sharding_spec(tensor):
"""
A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is
set as the attribute `sharding_spec` on a tensor.
"""
return hasattr(tensor, 'sharding_spec')
@colo_op_impl(F.linear)
def colo_linear(input: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
if _has_sharding_spec(weight):
return _new_colo_linear_imp(input, weight, bias)
else:
return colo_linear_imp(input, weight, bias)
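
# Dispatch sketch for the legacy path (illustrative, assuming `pg` is an
# initialized 1D tensor-parallel process group): with a weight of shape
# (out_features, in_features), sharding the last dim (in_features) selects the
# row-parallel branch, while sharding dim 0 (out_features) selects the
# column-parallel branch.
#
#   >>> spec = ColoTensorSpec(pg, ShardSpec([-1], [pg.tp_world_size()]),
#   ...                       ComputeSpec(ComputePattern.TP1D))
#   >>> w = ColoTensor.from_torch_tensor(torch.randn(8, 16), spec=spec)
#   >>> y = F.linear(torch.randn(2, 16), w)  # routed to colo_linear -> colo_linear_1drow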
# _ops/loss.py

from typing import Optional
import torch
import torch.nn.functional as F
from colossalai.legacy.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor
@colo_op_impl(F.cross_entropy)
def colo_cross_entropy(input_tensor: GeneralTensor,
target: GeneralTensor,
weight: Optional[GeneralTensor] = None,
size_average: Optional[bool] = None,
ignore_index: int = -100,
reduce: Optional[bool] = None,
reduction: str = "mean",
label_smoothing: float = 0.0):
assert isinstance(weight, ColoTensor) or isinstance(target, ColoTensor) or isinstance(input_tensor, ColoTensor)
    pg = input_tensor.get_process_group() if isinstance(input_tensor, ColoTensor) else target.get_process_group()
weight = convert_to_colo_tensor(weight, pg)
target = convert_to_colo_tensor(target, pg)
input_tensor = convert_to_colo_tensor(input_tensor, pg)
if input_tensor.is_replicate(): # Input is gathered
        assert target.is_replicate() and (weight is None or weight.is_replicate()), \
            "Both the target and weight tensors should be complete (replicated)"
output = F.cross_entropy(input_tensor,
target,
weight=weight,
size_average=size_average,
ignore_index=ignore_index,
reduce=reduce,
reduction=reduction,
label_smoothing=label_smoothing)
return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
elif input_tensor.has_compute_spec(): # Single Model Parallel Applied
if input_tensor.is_shard_1dcol():
            assert weight is None, "The TP cross entropy loss function does not support a weight tensor yet"
            assert target.is_replicate(), "The target tensor should be complete (replicated) in the TP cross entropy loss function"
output = VocabParallelCrossEntropyLoss1D()(input_tensor,
target,
process_group=input_tensor.process_group.tp_process_group())
return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
else:
raise NotImplementedError
else:
raise NotImplementedError
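
# Dispatch note (illustrative): when the logits come from a column-parallel
# classifier (input_tensor.is_shard_1dcol(), i.e. sharded along the vocab dim),
# the loss is computed by VocabParallelCrossEntropyLoss1D without first
# gathering the full logits, avoiding a replicated (batch, vocab) tensor on
# every rank.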
# _ops/view.py

import operator
from functools import reduce
from typing import Optional, Union
import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
from colossalai.tensor.op_wrapper import colo_op_impl
def _all_int(my_iter):
return all(isinstance(i, int) for i in my_iter)
def _get_valid_shape(shape):
if isinstance(shape, list):
if _all_int(shape):
return tuple(shape)
else:
raise RuntimeError("expects type(int) but finds an other type")
elif isinstance(shape, tuple):
if _all_int(shape):
return shape
else:
return _get_valid_shape(shape[0])
else:
raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
def _shape_infer(org_sp, tgt_sp):
cnt = 0
pos = 0
for idx, dim in enumerate(tgt_sp):
if dim < -1:
raise RuntimeError("invalid shape dimension {}".format(dim))
elif dim == -1:
cnt += 1
pos = idx
if cnt > 1:
raise RuntimeError("only one dimension can be inferred")
org_prod = reduce(operator.mul, org_sp, 1)
tgt_prod = reduce(operator.mul, tgt_sp, 1)
if cnt == 0:
if org_prod != tgt_prod:
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
else:
return tgt_sp
elif org_prod % tgt_prod != 0:
raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
infer_dim = -(org_prod // tgt_prod)
return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]
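
# Quick sanity checks for the helpers above (illustrative; runnable on a single
# process since no tensors are involved):
#
#   >>> _get_valid_shape([3, 4])
#   (3, 4)
#   >>> _shape_infer((4, 6), (8, -1))
#   (8, 3)
#   >>> _shape_infer((4, 6), (24,))
#   (24,)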
@colo_op_impl(torch.Tensor.view)
def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
"""Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
Changes the shape of the current tensor.
"""
assert isinstance(self, ColoTensor)
# apply original `view` function for replicated colo tensors
if self.is_replicate():
return self.view(*shape)
cur_sp = self.size()
org_sp = self.size_global()
# parse the passed arguments
tgt_sp = _get_valid_shape(shape)
# get the correct shape from inference
inf_sp = _shape_infer(org_sp, tgt_sp)
if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
new_shape = (cur_sp[0],) + tgt_sp[1:]
res = self.view(*new_shape)
elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
new_shape = tgt_sp[:-1] + (cur_sp[-1],)
res = self.view(*new_shape)
else:
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
spec=ColoTensorSpec(self.get_process_group()))
return ColoTensor.from_torch_tensor(tensor=res,
spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))
@colo_op_impl(torch.Tensor.size)
def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
size = self.size_global()
if dim is None:
return size
else:
return size[dim]
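
# Illustrative behavior (assuming `t` is a 1D-row-sharded ColoTensor with global
# shape (8, 4) on a tensor-parallel group of size 2, each rank holding a (4, 4)
# shard):
#   t.size()       -> torch.Size([8, 4])   # colo_size reports the global shape
#   t.view(8, -1)  -> reshapes locally; the sharded dim 0 stays untouched
#   t.view(-1)     -> gathers via redistribute(ReplicaSpec()) and then reshapes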
...
@@ -5,8 +5,8 @@ from contextlib import contextmanager
 import torch.nn as nn

-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.legacy.context import ParallelMode
+from colossalai.legacy.core import global_context as gpc

 class ParallelLayer(nn.Module):
...
 import torch.nn as nn

-from colossalai.context import ParallelMode, seed
+from colossalai.legacy.context import ParallelMode, seed

 from ..parallel_1d import *
 from ..utils import get_tensor_parallel_mode
...
 import torch
 import torch.distributed as dist

-from colossalai.core import global_context as gpc
+from colossalai.legacy.core import global_context as gpc

 try:
     import fused_mix_prec_layer_norm_cuda
...