# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""

from typing import Tuple, Sequence
from functools import partial
import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .quantize import (
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
    TensorUsage,
)
24
25
26
27
28
29
30


def dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    bias: jnp.ndarray = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """Perform dense layer transformation with optional quantization.

    This function implements matrix multiplication with optional bias addition,
    supporting quantization and custom contracting dimensions. It's optimized
    for transformer architectures and supports automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix for the dense layer transformation
        bias: Optional bias tensor to add after the transformation
        contracting_dims: Tuple of sequences specifying which dimensions to contract
        input_axes: Logical axes passed to the sharding constraint applied to the input
        kernel_axes: Logical axes for the kernel sharding constraint. Only forwarded to
            the quantized path; the non-quantized path below does not use it.
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    # Remove when tex.quantize() can handle quantizer=None
    if quantizer_set == noop_quantizer_set:
        x = with_sharding_constraint_by_logical_axes(x, input_axes)
        output = tex.gemm(x, kernel, contracting_dims)
        if bias is not None:
            # Left-pad the bias shape with 1s so it has the same rank as the GEMM output.
            bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
            output += jnp.reshape(bias, bias_new_shape)
    else:
        output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
    return output


63
64
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
    """Internal implementation of dense layer transformation with custom VJP.

    This function implements the core dense layer transformation logic with support
    for custom vector-Jacobian product (VJP) for automatic differentiation.
    Positional arguments 3, 4 and 5 (contracting_dims, input_axes, kernel_axes) are
    declared non-differentiable via ``nondiff_argnums``.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    # The primal computation is the forward rule with its residual context dropped.
    output, _ = _dense_fwd_rule(
        x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
    )
    return output


88
def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
    """Forward pass rule for dense layer transformation.

    Quantizes the input and the kernel, applies the logical-axes sharding constraints,
    runs the GEMM and adds the bias when one is given.

    Returns:
        Tuple of (output, context) for backward pass. The context holds, in order: the
        LHS_TRANS view of the quantized input, the RHS_TRANS view of the quantized
        kernel, ``x.shape``, ``kernel.shape``, ``use_bias``, ``quantizer_set`` and
        ``flatten_axis_k``.
    """
    x_contracting_dims, k_contracting_dims = contracting_dims

    # Negative flatten axes, derived from the number of contracting dims on each operand.
    flatten_axis_x = -len(x_contracting_dims)
    flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)

    casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
    casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

    casted_kernel = tex.quantize(
        kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # GEMM NN
    output = tex.gemm(
        casted_x.get_tensor(usage=TensorUsage.LHS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS),
        (x_contracting_dims, k_contracting_dims),
    )

    use_bias = bias is not None
    if use_bias:
        # Left-pad the bias shape with 1s so it has the same rank as the GEMM output.
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

    ctx = (
        casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
        x.shape,
        kernel.shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


131
132
133
def _dense_bwd_rule(
    contracting_dims, input_axes, kernel_axes, ctx, grad
):  # pylint: disable=unused-argument
    """Backward pass rule for dense layer transformation.

    Args:
        contracting_dims: Contracting dimensions used by the forward GEMM
        input_axes: Logical axes used for the sharding constraint on dgrad
        kernel_axes: Logical axes used for the sharding constraint on wgrad
        ctx: Context tuple produced by the forward rule
        grad: Gradient with respect to the forward output

    Returns:
        Tuple of gradients with respect to inputs: (dgrad, wgrad, dbias, quantizer_set).
        ``quantizer_set`` is returned unchanged in the slot of the quantizer_set argument.
    """
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims

    # NOTE: despite the names, casted_x_lhs / casted_kernel_rhs are the LHS_TRANS /
    # RHS_TRANS tensors that the forward rule stored in the context.
    (
        casted_x_lhs,
        casted_kernel_rhs,
        x_shape,
        kernel_shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    casted_grad, dbias = tex.quantize_dbias(
        grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
    )

    # GEMM NT
    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
    )
    dgrad = tex.gemm(
        casted_grad.get_tensor(usage=TensorUsage.LHS),
        casted_kernel_rhs,
        (g_contracting_dim, k_contracting_dim),
    )
    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

    # GEMM TN
    # x_non_contracting_dims
    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(fwd_x_contracting_dims))
    )

    wgrad = tex.gemm(
        casted_x_lhs,
        casted_grad.get_tensor(usage=TensorUsage.RHS),
        (x_contracting_dim, g_contracting_dim),
    )
    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    return dgrad, wgrad, dbias, quantizer_set


# Register the forward and backward rules as the custom VJP of _dense.
_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)


def grouped_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
    bias: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.array = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """
    Perform grouped dense (linear) layer transformation with optional quantization.

    This is a thin public wrapper that forwards every argument to the custom-VJP
    implementation ``_grouped_dense``.

    Args:
        x: Input tensor of shape (M, K)
        kernel: Weight matrix of shape (G, K, N)
        group_sizes: 1D array of shape (G,) specifying the size of each group
        contracting_dims: Tuple of sequences specifying which dimensions to contract
                          (currently only supports ((1,), (1,)))
        bias: Bias tensor of shape (G, N)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output

    Returns:
        A jnp.ndarray containing the result of the grouped linear operation
    """
    output = _grouped_dense(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
    )
    return output
231
232


233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7))
def _grouped_dense(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
):
    """Internal implementation of grouped dense layer transformation with custom VJP.

    Positional arguments 3, 5, 6 and 7 (contracting_dims, precision,
    preferred_element_type, group_offset) are declared non-differentiable via
    ``nondiff_argnums``.

    Returns:
        Output of the forward rule, with its residual context dropped
    """
    output, _ = _grouped_dense_fwd_rule(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
    )
    return output
257
258
259


def _grouped_dense_fwd_rule(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
):
    """Forward pass rule for grouped dense layer transformation.

    With the no-op quantizer set the input and kernel are passed to the grouped GEMM
    as-is. Otherwise both are quantized with ``tex.grouped_quantize`` first.

    Returns:
        Tuple of (output, context) for backward pass. The context holds, in order:
        ``group_sizes``, the input saved for wgrad, the kernel saved for dgrad,
        ``x.shape``, ``kernel.shape``, ``use_bias``, ``is_noop_quantizer_set``,
        ``quantizer_set`` and ``flatten_axis_k`` (``None`` on the no-op path).

    Raises:
        AssertionError: On the quantized path, if ``x`` is not 2D, ``kernel`` is not 3D,
            or the contracting dims are not ``((1,), (1,))``.
    """
    use_bias = bias is not None
    is_noop_quantizer_set = quantizer_set == noop_quantizer_set

    if is_noop_quantizer_set:
        grouped_gemm_x = x
        grouped_gemm_kernel = kernel
        ctx_x = x
        ctx_kernel = kernel
        flatten_axis_k = None
    else:
        x_contracting_dims, k_contracting_dims = contracting_dims
        flatten_axis_x = -len(x_contracting_dims)
        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis

        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
            "and k_contracting_dims=(1,) for now, "
            f"got {x_contracting_dims=} and {k_contracting_dims=}"
        )

        casted_x = tex.grouped_quantize(
            x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
        )
        casted_kernel = tex.grouped_quantize(
            kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k
        )
        contracting_dims = (x_contracting_dims, k_contracting_dims)

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
        # rowwise_casted_x.original_shape == (M, K)
        # colwise_casted_kernel.original_shape == (G, N, K)
        grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
        grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)
        ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
        ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS)

    output = tex.grouped_gemm(
        grouped_gemm_x,
        grouped_gemm_kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
    )

    ctx = (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x.shape,
        kernel.shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx
333
334


335
336
337
338
339
def _grouped_dense_bwd_rule(
    contracting_dims, precision, preferred_element_type, group_offset, ctx, grad
):
    """Backward pass rule for grouped dense layer transformation.

    Args:
        contracting_dims: Contracting dimensions used by the forward grouped GEMM
        precision: Precision forwarded to both backward grouped GEMMs
        preferred_element_type: Preferred element type forwarded to both backward GEMMs
        group_offset: Group offset forwarded to both backward grouped GEMMs
        ctx: Context tuple produced by the forward rule
        grad: Gradient with respect to the forward output

    Returns:
        Tuple (dgrad, wgrad, group_sizes_grad, dbias, quantizer_set).
        ``group_sizes_grad`` is always ``None``, ``dbias`` is ``None`` when the forward
        pass had no bias, and ``quantizer_set`` is returned unchanged.
    """
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims

    (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    if is_noop_quantizer_set:
        # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
        # g_contracting_dim = (1, )
        # k_contracting_dim = (2, )
        g_contracting_dim = tuple(
            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
        )
        k_contracting_dim = tuple(
            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = grad
        dgrad_kernel_T = ctx_kernel

        # g_contracting_dim = (0, )
        # x_contracting_dim = (0, )
        g_contracting_dim = x_contracting_dim = tuple(
            range(0, len(x_shape) - len(fwd_x_contracting_dims))
        )
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = grad
    else:
        casted_grad = tex.grouped_quantize(
            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
        )

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
        # extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (1,)
        k_contracting_dim = (2,)
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
        dgrad_kernel_T = ctx_kernel

        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
        # after the extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (0,)
        x_contracting_dim = (0,)
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)

    dgrad = tex.grouped_gemm(
        dgrad_grad,
        dgrad_kernel_T,
        group_sizes,
        dgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    wgrad = tex.grouped_gemm(
        wgrad_x_T,
        wgrad_grad,
        group_sizes,
        wgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    group_sizes_grad = None
    # dbias is computed from the unquantized incoming gradient.
    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None

    return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set
422
423
424


# Register the forward and backward rules as the custom VJP of _grouped_dense.
_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)