addmm.py 4.06 KB
Newer Older
1
import torch
2
3

from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
4
from colossalai.tensor.op_wrapper import colo_op_impl
5
6

from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input
7
8


def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                     alpha: Number) -> ColoTensor:
    """Row-parallel 1D addmm.

    mat1:S[1] x mat2:S[0] yields a partial product on every rank (Output:P);
    the result is ``beta * input + alpha * All-Reduce(Output)``.
    """
    pg = mat2.get_process_group()
    # Re-shard mat1 along its last dim so it lines up with the row-sharded mat2.
    sharded_mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]), pg)

    # Each rank computes a partial product (Output:P) ...
    partial = torch.mm(sharded_mat1, mat2)
    # ... which is summed across ranks into the full product.
    reduced = reduce_input(partial, pg)
    # The bias-like input must carry no compute spec for the row-parallel path.
    assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
    combined = beta * input_tensor + alpha * reduced
    return ColoTensor.from_torch_tensor(combined, spec=ColoTensorSpec(input_tensor.get_process_group()))


def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                     alpha: Number) -> ColoTensor:
    """Column-parallel 1D addmm.

    mat1:B x mat2:S[1] + input:S[1] yields Output:S[1]; the output is
    gathered back to a replica when the compute spec asks for it.
    """
    # Capture the output policy before any redistribution happens.
    out_policy = mat2.compute_spec
    # Replicate mat1 on every rank; reduce_grad hooks the backward all-reduce.
    replicated = mat1.redistribute(ReplicaSpec())
    replicated = reduce_grad(replicated, replicated.get_process_group())

    sharded_out = torch.addmm(input_tensor, replicated, mat2, beta=beta, alpha=alpha)
    spec = ColoTensorSpec(input_tensor.get_process_group(),
                          ShardSpec([-1], [mat2.get_tp_world_size()]),
                          ComputeSpec(ComputePattern.TP1D))
    result = ColoTensor.from_torch_tensor(sharded_out, spec=spec)

    # Gather to a full replica only when the spec requests replicated output.
    return result.to_replicate() if out_policy.output_replicate else result


def colo_addmm_1d(mode: str, input_tensor: 'ColoTensor', mat1: 'ColoTensor', mat2: 'ColoTensor', beta: 'Number',
                  alpha: 'Number') -> 'ColoTensor':
    """Dispatch a 1D tensor-parallel addmm to the row- or column-parallel kernel.

    Args:
        mode: ``'row'`` for the row-parallel kernel (``colo_addmm_1Drow``),
            ``'col'`` for the column-parallel kernel (``colo_addmm_1Dcol``).
        input_tensor: bias-like tensor added to the product.
        mat1, mat2: the matrices to multiply.
        beta, alpha: scaling factors, as in ``torch.addmm``.

    Returns:
        The ColoTensor produced by the selected kernel.

    Raises:
        ValueError: if ``mode`` is neither ``'row'`` nor ``'col'``.
    """
    # Validate with an explicit raise (not `assert`) so the check survives `python -O`.
    if mode not in ('row', 'col'):
        raise ValueError(f"mode must be 'row' or 'col', got {mode!r}")
    funcs = {'row': colo_addmm_1Drow, 'col': colo_addmm_1Dcol}
    return funcs[mode](input_tensor, mat1, mat2, beta, alpha)


@colo_op_impl(torch.addmm)
def colo_addmm(input_tensor: GeneralTensor,
               mat1: ColoTensor,
               mat2: ColoTensor,
               beta: Number = 1,
               alpha: Number = 1,
               **kwargs) -> ColoTensor:
    """Handles ``__torch_function__`` dispatch for ``torch.addmm``.

    Computes ``beta * input_tensor + alpha * (mat1 @ mat2)``, routing to a
    tensor-parallel kernel when ``mat2`` carries a 1D compute spec, or to
    native ``torch.addmm`` when no model parallelism is applied.

    Raises:
        NotImplementedError: for shard/compute-spec combinations that have
            no parallel kernel yet.
    """
    # mat2 drives the dispatch, so it must already be a ColoTensor.
    assert isinstance(mat2, ColoTensor)
    pg = mat2.get_process_group()
    # Promote plain tensors so every operand lives on mat2's process group.
    input_tensor = convert_to_colo_tensor(input_tensor, pg)
    mat1 = convert_to_colo_tensor(mat1, pg)

    # Add communication logic before and after the matmul call.
    if not mat2.has_compute_spec():    # No model parallelism applied
        assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op'
        assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op'
        ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor,
                                                                     mat1,
                                                                     mat2,
                                                                     beta=beta,
                                                                     alpha=alpha,
                                                                     **kwargs),
                                                  spec=ColoTensorSpec(pg))
    elif mat2.has_compute_pattern(ComputePattern.TP1D):    # Single model parallelism applied
        if mat2.is_shard_1drow() and input_tensor.is_replicate():
            mode = 'row'
        elif mat2.is_shard_1dcol() and (input_tensor.is_shard_1dcol() or input_tensor.is_shard_1drow()):
            mode = 'col'
        else:
            raise NotImplementedError
        ret_tensor = colo_addmm_1d(mode, input_tensor, mat1, mat2, beta, alpha)
    else:
        raise NotImplementedError

    return ret_tensor