# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""GEMM API for experimental middleware between Transformer Engine and Kitchen."""

from typing import Iterable, Optional

import torch

from transformer_engine.pytorch.experimental.quantization import (
    MMParams,
    GEMMType,
)
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.tensor.utils import is_experimental


def experimental_gemm(
    A: QuantizedTensorStorage,
    B: QuantizedTensorStorage,
    workspace: torch.Tensor,  # pylint: disable=unused-argument
    out_dtype: Optional[torch.dtype] = None,
    quantization_params: Optional[Quantizer] = None,  # pylint: disable=unused-argument
    gelu: bool = False,  # pylint: disable=unused-argument
    gelu_in: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    accumulate: bool = False,  # pylint: disable=unused-argument
    layout: str = "TN",
    out: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    bias: Optional[torch.Tensor] = None,
    use_split_accumulator: bool = False,
    grad: bool = False,
) -> Iterable[Optional[torch.Tensor]]:
    """Dispatch a GEMM to the quantizer's ``qgemm`` method.

    Both ``A`` and ``B`` must be experimental quantized tensors (checked via
    ``is_experimental``). The GEMM type is selected from ``grad`` and
    ``layout``: forward (FPROP) when ``grad`` is False, DGRAD for ``"NN"``
    and WGRAD for ``"NT"`` when ``grad`` is True; any other layout falls
    back to FPROP.

    Parameters
    ----------
    A, B : QuantizedTensorStorage
        Operands. NOTE(review): they are swapped internally — presumably to
        convert from general_gemm's (weight, input) argument order to the
        quantizer's (input, weight) order; confirm against callers.
    workspace : torch.Tensor
        Unused; kept for signature compatibility with ``general_gemm``.
    out_dtype : torch.dtype, optional
        Output dtype. Defaults to the dtype of the (post-swap) ``A`` operand.
    quantization_params, gelu, gelu_in, accumulate, out :
        Unused; kept for signature compatibility with ``general_gemm``.
    layout : str
        GEMM layout string; only consulted when ``grad`` is True.
    bias : torch.Tensor, optional
        Bias tensor, forwarded only for the FPROP path.
    use_split_accumulator : bool
        Forwarded to the quantizer via ``MMParams``.
    grad : bool
        True when computing a backward-pass GEMM.

    Returns
    -------
    tuple
        ``(result, None, None, None)`` — same shape of return value as
        ``general_gemm`` so call sites can unpack uniformly.

    Raises
    ------
    ValueError
        If neither operand carries a ``_quantizer``.
    """
    assert is_experimental(A) and is_experimental(B), "A and B must be experimental tensors"

    # Swap operands so the quantizer sees (input-like, weight-like) order.
    A, B = B, A

    # Determine GEMM type based on grad flag and layout
    if not grad:
        gemm_type = GEMMType.FPROP
    else:
        if layout == "NN":
            gemm_type = GEMMType.DGRAD
        elif layout == "NT":
            gemm_type = GEMMType.WGRAD
        else:
            # Default to FPROP for other layouts
            gemm_type = GEMMType.FPROP

    # Extract quantizer from QuantizedTensor to get qgemm logic
    # TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B._quantizer?
    quantizer = None
    if hasattr(A, "_quantizer") and A._quantizer is not None:
        quantizer = A._quantizer
    elif hasattr(B, "_quantizer") and B._quantizer is not None:
        quantizer = B._quantizer
    else:
        raise ValueError("No quantizer found in QuantizedTensor objects")

    # Create MMParams
    m_params = MMParams(
        out_dtype=out_dtype,
        use_split_accumulator=use_split_accumulator,
    )
    out_dtype = A.dtype if m_params.out_dtype is None else m_params.out_dtype

    if gemm_type == GEMMType.FPROP:
        # Forward pass uses the row-wise quantized data of both operands.
        qx, sx = A.data, A.scale
        qw, sw = B.data, B.scale
        assert qx is not None
        assert sx is not None
        assert qw is not None
        assert sw is not None
        assert A.original_shape is not None

        # Call quantizer's qgemm method
        result = quantizer.qgemm(
            qx,
            qw,
            m_params,
            out_dtype,
            sx,
            sw,
            bias,
            gemm_type=GEMMType.FPROP,
            qresult_x=A,
            qresult_w=B,
        )
        if len(A.original_shape) > 2:
            # The input was flattened to 2D for the GEMM; restore all leading
            # dimensions (batch, sequence, ...) — handles any rank, not just 3D.
            result = result.view(*A.original_shape[:-1], result.shape[-1])
    elif gemm_type == GEMMType.DGRAD:
        # dgrad = dY @ W^T: uses row-wise dY and transposed weight data.
        qdy, sdy = A.data, A.scale
        qw_t, sw_t = B.data_t, B.scale_t
        assert qdy is not None
        assert sdy is not None
        assert qw_t is not None
        assert sw_t is not None

        result = quantizer.qgemm(
            qdy,
            qw_t,
            m_params,
            out_dtype,
            sdy,
            sw_t,
            None,  # no bias in the gradient GEMMs
            gemm_type=GEMMType.DGRAD,
            qresult_x=A,
            qresult_w=B,
        )
    elif gemm_type == GEMMType.WGRAD:
        # wgrad = dY^T @ X: uses transposed data of both operands.
        qdy_t, sdy_t = A.data_t, A.scale_t
        qx_t, sx_t = B.data_t, B.scale_t
        assert qdy_t is not None
        assert sdy_t is not None
        assert qx_t is not None
        assert sx_t is not None

        result = quantizer.qgemm(
            qdy_t,
            qx_t,
            m_params,
            out_dtype,
            sdy_t,
            sx_t,
            None,  # no bias in the gradient GEMMs
            gemm_type=GEMMType.WGRAD,
            qresult_x=A,
            qresult_w=B,
        )

    # Return in the same format as general_gemm
    return result, None, None, None