# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""

from typing import Tuple, Sequence
from functools import partial
import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .quantize import (
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
)


def dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    bias: jnp.ndarray = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """Perform dense layer transformation with optional quantization.

    This function implements matrix multiplication with optional bias addition,
    supporting quantization and custom contracting dimensions. It's optimized
    for transformer architectures and supports automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix for the dense layer transformation
        bias: Optional bias tensor to add after the transformation
        contracting_dims: Tuple of sequences specifying which dimensions to contract
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
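
    Example (a minimal sketch; the shapes are illustrative, and the default
    no-op quantizer set keeps the computation unquantized):

        x = jnp.zeros((32, 128), jnp.bfloat16)        # (batch, hidden_in)
        kernel = jnp.zeros((128, 256), jnp.bfloat16)  # (hidden_in, hidden_out)
        bias = jnp.zeros((256,), jnp.bfloat16)
        out = dense(x, kernel, bias)                  # out.shape == (32, 256)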
    """
    # Remove when tex.quantize() can handle quantizer=None
    if quantizer_set == noop_quantizer_set:
        x = with_sharding_constraint_by_logical_axes(x, input_axes)
        output = tex.gemm(x, kernel, contracting_dims)
        if bias is not None:
            bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
            output += jnp.reshape(bias, bias_new_shape)
    else:
        output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
    """Internal implementation of dense layer transformation with custom VJP.

    This function implements the core dense layer transformation logic with support
    for custom vector-Jacobian product (VJP) for automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    output, _ = _dense_fwd_rule(
        x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
    )
    return output


def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
    """Forward pass rule for dense layer transformation.

    Returns:
        Tuple of (output, context) for backward pass
    """
    x_contracting_dims, k_contracting_dims = contracting_dims

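    # flatten_axis marks the boundary (as a negative index) between contracting
    # and non-contracting dims: x contracts over its trailing dims, while the
    # kernel contracts over its leading dims.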
    flatten_axis_x = -len(x_contracting_dims)
    flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)

    casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
    casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

    casted_kernel = tex.quantize(
        kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # GEMM NN: contract the rowwise-quantized input with the colwise-quantized kernel
    output = tex.gemm(
        casted_x.get_rowwise_tensor(),
        casted_kernel.get_colwise_tensor(),
        (x_contracting_dims, k_contracting_dims),
    )

    use_bias = bias is not None
    if use_bias:
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

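    # Stash the opposite-layout tensors for the backward GEMMs; they only
    # exist when the quantizer produced both rowwise and colwise copies (2x).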
    ctx = (
        casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
        casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
        x.shape,
        kernel.shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


def _dense_bwd_rule(
    contracting_dims, input_axes, kernel_axes, ctx, grad
):  # pylint: disable=unused-argument
    """Backward pass rule for dense layer transformation.

    Returns:
        Tuple of gradients with respect to inputs
    """
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims

    (
        colwise_casted_x,
        rowwise_casted_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    casted_grad, dbias = tex.quantize_dbias(
        grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
    )

    # GEMM NT
    # grad's trailing dims correspond to the kernel's non-contracting (output)
    # dims; the offset accounts for the rank difference between grad and kernel.
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
    )
    # The kernel's non-contracting dims from the forward pass
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
    )
    dgrad = tex.gemm(
        casted_grad.get_rowwise_tensor(),
        rowwise_casted_kernel,
        (g_contracting_dim, k_contracting_dim),
    )
    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

    # GEMM TN
    # Contract over x's non-contracting (leading) dims, which match grad's
    # leading dims one-for-one.
    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(fwd_x_contracting_dims))
    )

    wgrad = tex.gemm(
        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
    )
    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    return dgrad, wgrad, dbias, quantizer_set


_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
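
# A minimal differentiation sketch (illustrative shapes). With a real
# quantizer_set, the custom VJP above supplies dgrad/wgrad via the NT/TN GEMMs;
# with the default no-op set, JAX differentiates the plain GEMM path instead.
#
#     def loss_fn(x, kernel, bias):
#         return jnp.sum(dense(x, kernel, bias))
#
#     dx, dkernel, dbias = jax.grad(loss_fn, argnums=(0, 1, 2))(x, kernel, bias)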


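# The grouped dense variant below is kept disabled, wrapped in a string
# literal; it sketches the multi-GEMM (e.g. mixture-of-experts) path.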
"""
def grouped_dense(
    x_list,
    kernel_list,
    bias_list,
    contracting_dims_list,
    quantizer_set_list=None,
):
    # Perform a grouped dense layer transformation with optional quantization.

    output_list = _grouped_dense(
        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
    )
    return output_list


@partial(jax.custom_vjp, nondiff_argnums=(3,))
def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
    output_list, _ = _grouped_dense_fwd_rule(
        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
    )
    return output_list


def _grouped_dense_fwd_rule(
    x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
):
    use_bias = bias_list is not None
    x_rowwise_list = []
    x_colwise_list = []
    kernel_colwise_list = []
    kernel_rowwise_list = []
    x_shape_list = []
    kernel_shape_list = []
    if quantizer_set_list is None:
        x_rowwise_list = x_list
        x_colwise_list = x_list
        kernel_colwise_list = kernel_list
        kernel_rowwise_list = kernel_list
        x_shape_list = [x.shape for x in x_list]
        kernel_shape_list = [kernel.shape for kernel in kernel_list]
    else:
        for i in range(len(x_list)):  # pylint: disable=consider-using-enumerate
            q_x = tex.quantize(x_list[i], quantizer_set_list[i].x)
            q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel)
            x_rowwise_list.append(q_x.get_rowwise_tensor())
            x_colwise_list.append(q_x.get_colwise_tensor())
            kernel_colwise_list.append(q_kernel.get_colwise_tensor())
            kernel_rowwise_list.append(q_kernel.get_rowwise_tensor())
            x_shape_list.append(x_rowwise_list[-1].data.shape)
            kernel_shape_list.append(kernel_rowwise_list[-1].data.shape)

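    # One independent GEMM per group, launched together by grouped_gemm.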
    output_list = tex.grouped_gemm(
        x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list
    )

    ctx = (
        x_colwise_list,
        kernel_rowwise_list,
        x_shape_list,
        kernel_shape_list,
        use_bias,
        quantizer_set_list,
    )
    return output_list, ctx


def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
    (
        colwise_x_list,
        rowwise_kernel_list,
        x_shape_list,
        kernel_shape_list,
        use_bias,
        quantizer_set_list,
    ) = ctx

    group_size = len(grad_list)
    dbias_list = []
    grad_rowwise_list = []
    grad_colwise_list = []
    dgrad_contracting_dims_list = []
    wgrad_contracting_dims_list = []
    for i in range(group_size):
        grad = grad_list[i]
        x_shape = x_shape_list[i]
        kernel_shape = kernel_shape_list[i]
        fwd_contracting_dims = contracting_dims_list[i]

        if quantizer_set_list is None:
            casted_grad = grad
            dbias = tex.quantization._jax_dbias(grad) if use_bias else None
            grad_rowwise_list.append(grad)
            grad_colwise_list.append(grad)
        else:
            quantizer_set = quantizer_set_list[i]
            casted_grad, dbias = tex.quantize_dbias(
                grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
            )
            grad_rowwise_list.append(casted_grad.get_rowwise_tensor())
            grad_colwise_list.append(casted_grad.get_colwise_tensor())
        dbias_list.append(dbias)

        # GEMM NT
        fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims
        g_contracting_dim = tuple(
            range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
        )
        k_contracting_dim = tuple(
            dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_contracting_dims_list.append(dgrad_contracting_dims)

        # GEMM TN
        g_contracting_dim = x_contracting_dim = tuple(
            range(0, len(x_shape) - len(fwd_x_contracting_dims))
        )
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_contracting_dims_list.append(wgrad_contracting_dims)

    dgrad_list = tex.grouped_gemm(
        grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list
    )
    wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list)

    return dgrad_list, wgrad_list, dbias_list, quantizer_set_list


_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
"""