# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""
import warnings
from typing import Tuple, Sequence
from functools import partial
import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .quantize import (
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
    TensorUsage,
)

from .sharding import get_sequence_parallel_dim

DENSE_BATCH_FIRST_WARNING_ISSUED = False


def _issue_batch_first_warning(msg):
    global DENSE_BATCH_FIRST_WARNING_ISSUED
    if not DENSE_BATCH_FIRST_WARNING_ISSUED:
        warnings.warn(msg, UserWarning)
        DENSE_BATCH_FIRST_WARNING_ISSUED = True


def dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    bias: jnp.ndarray = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    batch_first: bool = True,
    sequence_parallel_output: bool = False,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """Perform dense layer transformation with optional quantization.

    This function implements matrix multiplication with optional bias addition,
    supporting quantization and custom contracting dimensions. It's optimized
    for transformer architectures and supports automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix for the dense layer transformation
        bias: Optional bias tensor to add after the transformation
        contracting_dims: Tuple of sequences specifying which dimensions to contract
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
        sequence_parallel_output: Produce an output that is sharded in the first non-batched
                                  dimension. Only supported for the TE custom GEMM with
                                  row-parallel kernel axes.
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    # Remove when tex.quantize() can handle quantizer=None
    if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot():
        x = with_sharding_constraint_by_logical_axes(x, input_axes)
        output = tex.gemm(x, kernel, contracting_dims=contracting_dims)
        if bias is not None:
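            # Broadcast the bias over the leading output dims, e.g. a (features,) bias against a
            # (batch, seq, features) output is reshaped to (1, 1, features) before the add.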
            bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
            output += jnp.reshape(bias, bias_new_shape)
    else:
        output = _dense(
            x,
            kernel,
            bias,
            contracting_dims,
            input_axes,
            kernel_axes,
            batch_first,
            sequence_parallel_output,
            quantizer_set,
        )
    return output
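

# Minimal usage sketch of `dense()` (illustrative only; the shapes, dtypes, and import path below
# are assumptions, not part of this module):
#
#     import jax.numpy as jnp
#     from transformer_engine.jax.dense import dense
#
#     x = jnp.zeros((8, 128, 512), dtype=jnp.bfloat16)     # (batch, seq, hidden)
#     kernel = jnp.zeros((512, 1024), dtype=jnp.bfloat16)  # (hidden, out_features)
#     bias = jnp.zeros((1024,), dtype=jnp.bfloat16)
#     # Contract x's hidden dim with the kernel's first dim; the bias is added afterwards.
#     y = dense(x, kernel, bias, contracting_dims=((2,), (0,)))
#     assert y.shape == (8, 128, 1024)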


@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7))
def _dense(
    x,
    kernel,
    bias,
    contracting_dims,
    input_axes,
    kernel_axes,
    batch_first,
    sequence_parallel_output,
    quantizer_set,
):
    """Internal implementation of dense layer transformation with custom VJP.

    This function implements the core dense layer transformation logic with support
    for custom vector-Jacobian product (VJP) for automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
        sequence_parallel_output: Produce an output that is sharded in the first non-batched
                                  dimension. Only supported for the TE custom GEMM with
                                  row-parallel kernel axes.
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    output, _ = _dense_fwd_rule(
        x,
        kernel,
        bias,
        contracting_dims,
        input_axes,
        kernel_axes,
        batch_first,
        sequence_parallel_output,
        quantizer_set,
    )
    return output


def _dense_fwd_rule(
    x,
    kernel,
    bias,
    contracting_dims,
    input_axes,
    kernel_axes,
    batch_first,
    sequence_parallel_output,
    quantizer_set,
):
    """Forward pass rule for dense layer transformation.

    Returns:
        Tuple of (output, context) for backward pass
    """
    x_contracting_dims, k_contracting_dims = map(
        tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
    )

    # Check supported input layout
    x_is_transposed = x.ndim - 1 not in x_contracting_dims
    k_is_transposed = kernel.ndim - 1 in k_contracting_dims
    assert (
        not x_is_transposed and not k_is_transposed
    ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."

    # Determine X batch dimension
    # - If `batch_first=True` -> (batch, leading..., contracting...)
    # - Otherwise             -> (leading..., batch, contracting...)
    # NOTE: Always assume a single batch dimension
    x_bdim = None
    num_cdims = len(x_contracting_dims)
    if x.ndim >= num_cdims + 2:
        # Assume X is batched if it has at least +2 dimensions more than the number of contracting
        # dimensions.
        if not batch_first:
            _issue_batch_first_warning(
                "TE/JAX `dense()` layer implementation does not officially support sequence-first "
                "inputs and may produce incorrect results when `batch_first=False`. Use "
                "sequence-first inputs at your own discretion.",
            )
        x_bdim = 0 if batch_first else x.ndim - num_cdims - 1

    flatten_axis_x = -len(x_contracting_dims)
    flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)

    casted_x = tex.quantize(
        x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True
    )
    casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

    casted_kernel = tex.quantize(
        kernel,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.kernel,
        noop_scaled_tensor=True,
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # GEMM NN
    use_bias = bias is not None
    output = tex.gemm(
        casted_x.get_tensor(usage=TensorUsage.LHS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        batched_dims=((x_bdim,), ()),
        bias=bias if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
        sequence_parallel_output=sequence_parallel_output and not tex.gemm_uses_jax_dot(),
    )

    if use_bias and tex.gemm_uses_jax_dot():
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

    ctx = (
        casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
        x.shape,
        kernel.shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
        x_bdim,
    )
    return output, ctx


def _dense_bwd_rule(
    contracting_dims, input_axes, kernel_axes, batch_first, sequence_parallel_output, ctx, grad
):  # pylint: disable=unused-argument
    """Backward pass rule for dense layer transformation.

    Returns:
        Tuple of gradients with respect to inputs
    """
    (
        casted_x_lhs,
        casted_kernel_rhs,
        x_shape,
        kernel_shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
        x_bdim,
    ) = ctx

    fwd_x_contracting_dims, fwd_k_contracting_dims = map(
        tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
    )

    casted_grad, dbias = tex.quantize_dbias(
        grad,
        is_dbias=use_bias,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.dgrad,
        noop_scaled_tensor=True,
    )

    # GEMM NT
    # grad contracts over its trailing dims (the kernel's non-contracting dims); the offset below
    # is derived from the difference between grad.ndim and kernel.ndim.
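    # E.g. x (B, S, H) @ kernel (H, F) with contracting_dims=((2,), (0,)) gives grad (B, S, F);
    # the computations below then yield g_contracting_dim=(2,) and k_contracting_dim=(1,), so
    # dgrad comes out with x's shape (B, S, H). (Illustrative shapes.)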
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
    )

    # Get sequence-parallel dimension of the FWD input (if it exists)
    sequence_dim = get_sequence_parallel_dim(input_axes, fwd_x_contracting_dims, (x_bdim,))
    dgrad = tex.gemm(
        casted_grad.get_tensor(usage=TensorUsage.LHS),
        casted_kernel_rhs,
        contracting_dims=(g_contracting_dim, k_contracting_dim),
        batched_dims=((x_bdim,), ()),
        sequence_parallel_output=(
            sequence_dim is not None
            and not sequence_parallel_output
            and not tex.gemm_uses_jax_dot()
        ),
        sequence_dim=(
            None if sequence_parallel_output or tex.gemm_uses_jax_dot() else sequence_dim
        ),
    )
    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

    # GEMM TN
    # x_non_contracting_dims
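    # E.g. with x (B, S, H) and grad (B, S, F), both contract over dims (0, 1), so wgrad comes
    # out with the kernel's shape (H, F). (Illustrative shapes.)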
    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(fwd_x_contracting_dims))
    )

    wgrad = tex.gemm(
        casted_x_lhs,
        casted_grad.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dim, g_contracting_dim),
        batched_dims=((x_bdim,), (x_bdim,)),
    )
    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    return dgrad, wgrad, dbias, quantizer_set


_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)


def grouped_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
    bias: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.array = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """
    Perform grouped dense (linear) layer transformation with optional quantization.

    Args:
        x: Input tensor of shape (M, K)
        kernel: Weight matrix of shape (G, K, N)
        group_sizes: 1D array of shape (G,) specifying the size of each group
        contracting_dims: Tuple of sequences specifying which dimensions to contract
                          (currently only supports ((1,), (1,)))
        bias: Bias tensor of shape (G, N)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output

    Returns:
        A jnp.ndarray containing the result of the grouped linear operation
    """
    output = _grouped_dense(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
    )
    return output
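

# Minimal usage sketch of `grouped_dense()` (illustrative only; the shapes, dtypes, and import
# path below are assumptions, not part of this module):
#
#     import jax.numpy as jnp
#     from transformer_engine.jax.dense import grouped_dense
#
#     G, M, K, N = 4, 256, 128, 512
#     x = jnp.zeros((M, K), dtype=jnp.bfloat16)           # all groups' rows stacked along M
#     kernel = jnp.zeros((G, K, N), dtype=jnp.bfloat16)   # one (K, N) weight matrix per group
#     group_sizes = jnp.full((G,), M // G, dtype=jnp.int32)  # rows of x per group; sums to M
#     y = grouped_dense(x, kernel, group_sizes)
#     assert y.shape == (M, N)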


@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7))
def _grouped_dense(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
):
    output, _ = _grouped_dense_fwd_rule(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
    )
    return output


def _grouped_dense_fwd_rule(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
):
    use_bias = bias is not None
    is_noop_quantizer_set = quantizer_set == noop_quantizer_set

    if is_noop_quantizer_set:
        grouped_gemm_x = x
        grouped_gemm_kernel = kernel
        ctx_x = x
        ctx_kernel = kernel
        flatten_axis_k = None
    else:
        x_contracting_dims, k_contracting_dims = contracting_dims
        flatten_axis_x = -len(x_contracting_dims)
        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis

        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
            "and k_contracting_dims=(1,) for now, "
            f"got {x_contracting_dims=} and {k_contracting_dims=}"
        )

        casted_x = tex.grouped_quantize(
            x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
        )
        casted_kernel = tex.grouped_quantize(
            kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k
        )
        contracting_dims = (x_contracting_dims, k_contracting_dims)

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
        # rowwise_casted_x.original_shape == (M, K)
        # colwise_casted_kernel.original_shape == (G, N, K)
        grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
        grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)
        ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
        ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS)

    output = tex.grouped_gemm(
        grouped_gemm_x,
        grouped_gemm_kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
    )

    ctx = (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x.shape,
        kernel.shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


def _grouped_dense_bwd_rule(
    contracting_dims, precision, preferred_element_type, group_offset, ctx, grad
):
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims

    (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    if is_noop_quantizer_set:
        # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
        # g_contracting_dim = (1, )
        # k_contracting_dim = (2, )
        g_contracting_dim = tuple(
            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
        )
        k_contracting_dim = tuple(
            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = grad
        dgrad_kernel_T = ctx_kernel

        # g_contracting_dim = (0, )
        # x_contracting_dim = (0, )
        g_contracting_dim = x_contracting_dim = tuple(
            range(0, len(x_shape) - len(fwd_x_contracting_dims))
        )
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = grad
    else:
        casted_grad = tex.grouped_quantize(
            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
        )

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
        # extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (1,)
        k_contracting_dim = (2,)
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
        dgrad_kernel_T = ctx_kernel

        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
        # after the extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (0,)
        x_contracting_dim = (0,)
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)

    dgrad = tex.grouped_gemm(
        dgrad_grad,
        dgrad_kernel_T,
        group_sizes,
        dgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    wgrad = tex.grouped_gemm(
        wgrad_x_T,
        wgrad_grad,
        group_sizes,
        wgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    group_sizes_grad = None
    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None

    return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set


_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)