# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.

This module provides optimized activation functions with quantization support.
"""

from typing import Sequence, Union, Callable, Optional
from functools import partial

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex

from .quantize.tensor import NoScaleTensor
from .quantize.quantizer import Quantizer


def activation(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> jnp.ndarray:
    """Apply activation functions to input tensor with optional quantization.

    This function applies a sequence of activation functions to the input tensor.
    It supports string-based activation types (e.g., 'relu', 'gelu', ('gelu', 'linear')).

    Args:
        x: Input tensor to apply activations to
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer for quantizing the output

    Returns:
        Activated output tensor
    """
    # The last dimension must split evenly across the activations — gated
    # variants such as ('gelu', 'linear') consume paired chunks of it.
    assert x.shape[-1] % len(activation_type) == 0
    output = _activation(x, activation_type, quantizer)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
    """Core activation primitive registered with JAX's custom VJP machinery.

    The primal evaluation simply delegates to the forward rule and drops the
    saved residuals, returning only the activated tensor. `activation_type`
    is marked non-differentiable via ``nondiff_argnums``.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer

    Returns:
        Activated tensor
    """
    result, _residuals = _activation_fwd_rule(x, activation_type, quantizer)
    return result


def _activation_fwd_rule(x, activation_type, quantizer):
    """Forward pass rule for activation function.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer

    Returns:
        Tuple of (output, context) for backward pass
    """
    fwd_output = tex.act_lu(x, activation_type, quantizer)
    # This is a no-op for higher-precision tensors
    fwd_output = fwd_output.dequantize()
    # Save the primal input and quantizer as residuals for the backward rule.
    return fwd_output, (x, quantizer)


def _activation_bwd_rule(activation_type, ctx, g):
    """Backward pass rule for activation function.

    Args:
        activation_type: Sequence of activation functions
        ctx: Context from forward pass
        g: Gradient from upstream

    Returns:
        Gradient with respect to input
    """
    (x, _) = ctx
    assert x.dtype == g.dtype
    dx = tex.dact_lu(g, x, activation_type)
    # No quantization is used in this VJP backward, so the output should
    # always be a NoScaleTensor
    assert isinstance(dx, NoScaleTensor)
    dx = dx.data
    # One gradient per differentiable argument; activation_type is
    # non-differentiable (nondiff_argnums) so only (dx, d_quantizer=None).
    return (dx, None)


# Wire the forward/backward rules into the custom-VJP primitive defined above.
_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule)