# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX softmax modules"""
from enum import Enum
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex


class SoftmaxType(Enum):
    """Supported fused-softmax variants.

    Each value selects the matching fused kernel pair in ``cpp_extensions``:
      - SCALED: plain scaled softmax.
      - SCALED_MASKED: scaled softmax with an explicit boolean mask.
      - SCALED_UPPER_TRIANG_MASKED: scaled softmax with an upper-triangular
        (causal-style) mask; no explicit mask argument is needed.
    """

    SCALED = "scaled"
    SCALED_MASKED = "scaled_masked"
    SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"


def softmax(
    logits: jnp.ndarray,
    mask: Optional[jnp.ndarray] = None,
    scale_factor: Optional[float] = 1.0,
    softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED,
):
    """Softmax wrapper.

    Thin public entry point around the custom-VJP ``_softmax``.

    Parameters
    ----------
    logits:
        Input tensor to apply softmax over.
    mask:
        Optional mask tensor; required when ``softmax_type`` is
        ``SoftmaxType.SCALED_MASKED`` (enforced in the forward rule),
        ignored otherwise.
    scale_factor:
        Scale applied to ``logits`` inside the fused kernel.
    softmax_type:
        Which fused-softmax variant to dispatch to.

    Returns
    -------
    The softmax output tensor.
    """
    output = _softmax(logits, mask, scale_factor, softmax_type)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def _softmax(logits, mask, scale_factor, softmax_type):
    """Differentiable softmax with a custom VJP.

    ``scale_factor`` and ``softmax_type`` (argnums 2, 3) are
    non-differentiable static arguments. The primal computation simply
    reuses the forward rule and discards its residuals.
    """
    output, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_type)
    return output


def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
    """Forward rule: dispatch to the fused kernel for ``softmax_type``.

    Returns ``(output, residuals)`` where residuals are
    ``(output, logits, mask)`` as consumed by ``_softmax_bwd_rule``.
    """
    if softmax_type is SoftmaxType.SCALED_MASKED:
        # The masked kernel requires an explicit mask tensor.
        assert mask is not None
        output = tex.scaled_masked_softmax_fwd(logits, mask, scale_factor)
    elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        # Causal masking is implicit in the kernel; no mask argument.
        output = tex.scaled_upper_triang_masked_softmax_fwd(logits, scale_factor)
    else:
        output = tex.scaled_softmax_fwd(logits, scale_factor)

    return output, (output, logits, mask)


def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
    """Backward rule: compute the gradient w.r.t. ``logits``.

    ``scale_factor`` and ``softmax_type`` come first because they are the
    non-diff argnums of the ``custom_vjp``. ``ctx`` holds the residuals
    saved by the forward rule; ``dz`` is the incoming cotangent.

    Returns a tuple with one entry per differentiable input of
    ``_softmax``: ``(dgrad, None)`` — the mask gets no gradient.
    """
    (softmax_output, logits, mask) = ctx

    if softmax_type is SoftmaxType.SCALED_MASKED:
        dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor)
    elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor)
    else:
        dgrad = tex.scaled_softmax_bwd(dz, softmax_output, logits, scale_factor)

    return (dgrad, None)


# Register the forward/backward rules with the custom_vjp primitive.
_softmax.defvjp(_softmax_fwd_rule, _softmax_bwd_rule)