# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.

This module provides optimized activation functions with quantization support.
"""

from typing import Sequence, Union, Callable, Optional
from functools import partial

import jax
import jax.numpy as jnp
from . import cpp_extensions as tex

from .quantize.tensor import NoScaleTensor
from .quantize.quantizer import Quantizer


def activation(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[tex.activation.ActivationParams] = None,
) -> jnp.ndarray:
    """Apply activation functions to input tensor with optional quantization.

    This function applies a sequence of activation functions to the input tensor.
    It supports both single activations given as strings (e.g., 'relu', 'gelu')
    and gated combinations given as sequences (e.g., ('gelu', 'linear')).

    Args:
        x: Input tensor to apply activations to
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer for quantizing the output
        act_params: Optional activation parameters. Currently used
            just for ClampedSwiGLU.

    Returns:
        Activated output tensor
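
    Example:
        A minimal, illustrative usage sketch (not a definitive recipe): it
        assumes a bf16 input, no quantizer, and an arbitrary shape::

            import jax.numpy as jnp
            x = jnp.ones((32, 128), dtype=jnp.bfloat16)
            y = activation(x, ("gelu",))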
    """
    assert x.shape[-1] % len(activation_type) == 0
    output = _activation(x, activation_type, quantizer, act_params)
    return output


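# activation_type (argnum 1) and act_params (argnum 3) are static configuration,
# so they are marked non-differentiable for the custom VJP below.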
@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
def _activation(x, activation_type, quantizer, act_params):
    """Internal implementation of activation with custom VJP.

    This function implements the core activation logic with support for
    custom vector-Jacobian product (VJP) for automatic differentiation.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer
        act_params: Optional activation parameters. Currently used
            just for ClampedSwiGLU.

    Returns:
        Activated tensor
    """
    _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params)
    return _output


def _activation_fwd_rule(x, activation_type, quantizer, act_params):
    """Forward pass rule for activation function.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer
        act_params: Optional activation parameters. Currently used
            just for ClampedSwiGLU.

    Returns:
        Tuple of (output, context) for backward pass
    """
    fwd_output = tex.act_lu(x, activation_type, quantizer, act_params)
    # This is a no-op for higher-precision tensors
    fwd_output = fwd_output.dequantize()
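    # Save the input and quantizer as residuals for the backward pass.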
    return fwd_output, (x, quantizer)


def _activation_bwd_rule(activation_type, act_params, ctx, g):
    """Backward pass rule for activation function.

    Args:
        activation_type: Sequence of activation functions
        act_params: Optional activation parameters. Currently used
            just for ClampedSwiGLU.
        ctx: Context from forward pass
        g: Gradient from upstream

    Returns:
        Gradient with respect to input
    """
    (x, _) = ctx
    assert x.dtype == g.dtype
    dx = tex.dact_lu(g, x, activation_type, act_params=act_params)
    # No quantization is used in this VJP backward, so the output should
    # always be a NoScaleTensor
    assert isinstance(dx, NoScaleTensor)
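    # Unwrap the underlying array from the NoScaleTensor wrapper.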
    dx = dx.data
    return (dx, None)


_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule)