"...git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "d9de51339e8ff7a4d8609c6253dc41a6fdc88a50"
linear.py 4.99 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for linear layer."""

from __future__ import annotations
from collections.abc import Callable
from typing import Optional

import torch

from transformer_engine.pytorch.ops.basic import (
    AllReduce,
    BasicLinear,
    Bias,
    ReduceScatter,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.ops.op import FusedOperation


class Linear(FusedOperation):
    """Apply linear transformation: :math:`y = x A^T + b`

    This is a drop-in replacement for `torch.nn.Linear`.

    The operation is assembled from basic ops. Without tensor
    parallelism or with column tensor parallelism it is a
    `BasicLinear` optionally followed by a `Bias`. With row tensor
    parallelism it is a local `BasicLinear`, an optional `Bias`, and
    then a `ReduceScatter` (with sequence parallelism) or an
    `AllReduce` (without) over the tensor-parallel group.

    Parameters
    ----------
    in_features: int
        Inner dimension of input tensor
    out_features: int
        Inner dimension of output tensor
    bias: bool, default = `True`
        Apply additive bias
    device: torch.device, default = default CUDA device
        Tensor device
    dtype: torch.dtype, default = default dtype
        Tensor datatype
    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
        Mode for tensor parallelism
    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
        Process group for tensor parallelism
    sequence_parallel: bool, default = `False`
        Whether to apply sequence parallelism together with tensor
        parallelism, i.e. distributing input or output tensors along
        outer dimension (sequence or batch dim) when not distributing
        along inner dimension (embedding dim)
    rng_state_tracker_function: callable, default = `None`
        Function that returns CudaRNGStatesTracker, which is used for
        model-parallel weight initialization
    accumulate_into_main_grad: bool, default = `False`
        Whether to directly accumulate weight gradients into the
        weight's `main_grad` attribute instead of relying on PyTorch
        autograd. The weight's `main_grad` must be set externally and
        there is no guarantee that `grad` will be set or be
        meaningful.

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        bias: bool = True,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
        tensor_parallel_mode: Optional[str] = None,
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        sequence_parallel: bool = False,
        rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
        accumulate_into_main_grad: bool = False,
    ) -> None:

        # Tensor parallel configuration
        # Canonicalized values replace the raw arguments from here on
        (
            tensor_parallel_mode,
            tensor_parallel_group,
            tensor_parallel_size,
            sequence_parallel,
            local_in_features,
            local_out_features,
        ) = BasicLinear._canonicalize_tensor_parallelism(
            mode=tensor_parallel_mode,
            process_group=tensor_parallel_group,
            sequence_parallel=sequence_parallel,
            in_features=in_features,
            out_features=out_features,
        )

        # Construct basic ops
        ops = []
        linear_kwargs = {
            "in_features": in_features,
            "out_features": out_features,
            "device": device,
            "dtype": dtype,
            "tensor_parallel_mode": tensor_parallel_mode,
            "tensor_parallel_group": tensor_parallel_group,
            "sequence_parallel": sequence_parallel,
            "rng_state_tracker_function": rng_state_tracker_function,
            "accumulate_into_main_grad": accumulate_into_main_grad,
        }
        bias_kwargs = {
            "size": out_features,
            "device": device,
            "dtype": dtype,
            "tensor_parallel": (tensor_parallel_mode is not None),
            "tensor_parallel_group": tensor_parallel_group,
        }
        if tensor_parallel_mode == "row":
            # Row TP: GEMM + bias + reduction
            # The GEMM op is built with local dimensions and with its
            # own tensor/sequence parallelism disabled, since the
            # reduction is appended below as an explicit op
            linear_kwargs["in_features"] = local_in_features
            linear_kwargs["out_features"] = local_out_features
            linear_kwargs["tensor_parallel_mode"] = None
            linear_kwargs["tensor_parallel_group"] = None
            linear_kwargs["sequence_parallel"] = False
            # NOTE(review): presumably Bias divides `size` by the
            # tensor-parallel size when `tensor_parallel` is set, so
            # scaling here keeps the local bias at `out_features`
            # entries -- confirm against Bias
            bias_kwargs["size"] *= tensor_parallel_size
            ops.append(BasicLinear(**linear_kwargs))
            if bias:
                ops.append(Bias(**bias_kwargs))
            if sequence_parallel:
                ops.append(ReduceScatter(tensor_parallel_group))
            else:
                ops.append(AllReduce(tensor_parallel_group))
        else:
            # Column TP or no TP: (gather + GEMM) + bias
            ops.append(BasicLinear(**linear_kwargs))
            if bias:
                ops.append(Bias(**bias_kwargs))

        # Initialize base class
        super().__init__(ops)

        # Register parameters
        # The linear op is always first and, when present, the bias op
        # is always second, in both branches above
        self.register_parameter("weight", self.basic_ops[0].weight)
        self.register_parameter("bias", self.basic_ops[1].bias if bias else None)