# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX softmax modules"""
from enum import Enum
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex


class SoftmaxType(Enum):
    """Dispatch key selecting which fused softmax kernel variant to run."""
    SCALED = "scaled"  # plain scaled softmax; no mask applied
    SCALED_MASKED = "scaled_masked"  # requires an explicit mask argument
    SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"  # built-in upper-triangular mask; no mask argument


def softmax(logits: jnp.ndarray,
            mask: Optional[jnp.ndarray] = None,
            scale_factor: float = 1.0,
            softmax_type: SoftmaxType = SoftmaxType.SCALED) -> jnp.ndarray:
    """Fused softmax wrapper with a custom VJP.

    Dispatches to the fused kernel selected by ``softmax_type``.

    Args:
        logits: input tensor to apply softmax over.
        mask: required when ``softmax_type`` is ``SoftmaxType.SCALED_MASKED``;
            ignored by the other variants.
        scale_factor: scaling applied by the kernel. ``None`` was never a valid
            value here, so the annotation is a plain ``float`` (the ``1.0``
            default is unchanged).
        softmax_type: which fused kernel variant to run; likewise never ``None``.

    Returns:
        The softmax output, same shape as ``logits``.
    """
    return _softmax(logits, mask, scale_factor, softmax_type)
@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def _softmax(logits, mask, scale_factor, softmax_type):
    """custom_vjp primitive for the fused softmax.

    ``scale_factor`` and ``softmax_type`` are marked non-differentiable;
    the primal result is the forward rule's output with residuals dropped.
    """
    primal_out, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_type)
    return primal_out


def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
    """Forward rule: run the fused kernel matching ``softmax_type``.

    Returns the kernel output together with a one-element residual tuple
    (the output itself) consumed by the backward rule.
    """
    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        output = tex.scaled_upper_triang_masked_softmax_fwd(logits, scale_factor)
    elif softmax_type is SoftmaxType.SCALED_MASKED:
        # The masked variant is the only one that consumes the mask.
        assert mask is not None
        output = tex.scaled_masked_softmax_fwd(logits, mask, scale_factor)
    else:
        output = tex.scaled_softmax_fwd(logits, scale_factor)

    return output, (output,)
def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
    """Backward rule: gradient w.r.t. ``logits`` from the saved forward output.

    ``ctx`` is the residual tuple produced by the forward rule; the second
    cotangent slot (for ``mask``) is always ``None``.
    """
    (softmax_output,) = ctx

    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor)
    elif softmax_type is SoftmaxType.SCALED_MASKED:
        dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, scale_factor)
    else:
        dgrad = tex.scaled_softmax_bwd(dz, softmax_output, scale_factor)

    return (dgrad, None)


65
_softmax.defvjp(_softmax_fwd_rule, _softmax_bwd_rule)