# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX softmax modules"""
from enum import Enum
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp

from .cpp_extensions import scaled_softmax_fwd
from .cpp_extensions import scaled_softmax_bwd
from .cpp_extensions import scaled_masked_softmax_fwd
from .cpp_extensions import scaled_masked_softmax_bwd
from .cpp_extensions import scaled_upper_triang_masked_softmax_fwd
from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd
from .cpp_extensions import ScaledSoftmaxFwdPrimitive
from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive
from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive


class SoftmaxType(Enum):
    """Fused-softmax kernel variants supported by this module.

    The string values name the underlying fused kernels:
    plain scaled softmax, scaled softmax with an explicit mask, and
    scaled softmax with an implicit upper-triangular (causal) mask.
    """
    SCALED = "scaled"
    SCALED_MASKED = "scaled_masked"
    SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"


def is_softmax_kernel_available(softmax_type: SoftmaxType, batch: int, heads: int, q_seqlen: int,
                                k_seqlen: int, dtype: jnp.dtype):
    """Return whether a fused kernel exists for the given softmax variant and shape.

    Dispatches the availability query to the forward primitive matching
    ``softmax_type``; raises ``NotImplementedError`` for unknown variants.
    """
    dispatch = {
        SoftmaxType.SCALED: ScaledSoftmaxFwdPrimitive,
        SoftmaxType.SCALED_MASKED: ScaledMaskedSoftmaxFwdPrimitive,
        SoftmaxType.SCALED_UPPER_TRIANG_MASKED: ScaledUpperTriangMaskedSoftmaxFwdPrimitive,
    }
    primitive = dispatch.get(softmax_type)
    if primitive is None:
        raise NotImplementedError
    return primitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, dtype)


def softmax(logits: jnp.ndarray,
            mask: Optional[jnp.ndarray] = None,
            scale_factor: float = 1.0,
            softmax_type: SoftmaxType = SoftmaxType.SCALED):
    """Fused softmax wrapper with a custom VJP.

    Parameters
    ----------
    logits : jnp.ndarray
        Input attention scores.
    mask : Optional[jnp.ndarray]
        Additive/boolean mask; required by the SCALED_MASKED variant and
        ignored by the others (see ``_softmax_fwd_rule``).
    scale_factor : float
        Scale applied to ``logits`` before the softmax. Not Optional — a
        concrete float is always expected, so the annotation says ``float``.
    softmax_type : SoftmaxType
        Which fused kernel variant to run.

    Returns
    -------
    jnp.ndarray
        The softmax probabilities.
    """
    output = _softmax(logits, mask, scale_factor, softmax_type)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def _softmax(logits, mask, scale_factor, softmax_type):
    """Primal function for the custom-VJP softmax.

    ``scale_factor`` and ``softmax_type`` are non-differentiable static
    arguments (``nondiff_argnums=(2, 3)``). The primal simply runs the
    forward rule and discards the residuals.
    """
    primal_out, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_type)
    return primal_out

def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
    """Forward rule: run the fused kernel for ``softmax_type``.

    Returns the softmax output together with a one-tuple of residuals
    (the output itself) needed by the backward rule.
    """
    # The enum identity checks are mutually exclusive, so branch order is
    # irrelevant; anything that is not one of the masked variants falls
    # through to the plain scaled kernel.
    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        probs = scaled_upper_triang_masked_softmax_fwd(logits, scale_factor)
    elif softmax_type is SoftmaxType.SCALED_MASKED:
        assert mask is not None
        probs = scaled_masked_softmax_fwd(logits, mask, scale_factor)
    else:
        probs = scaled_softmax_fwd(logits, scale_factor)

    return probs, (probs,)

def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
    """Backward rule: propagate ``dz`` through the fused softmax.

    ``ctx`` is the residual tuple saved by the forward rule (the softmax
    output). Returns gradients for the differentiable arguments only:
    ``(d_logits, None)`` — the mask gets no gradient.
    """
    (probs,) = ctx

    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        dgrad = scaled_upper_triang_masked_softmax_bwd(dz, probs, scale_factor)
    elif softmax_type is SoftmaxType.SCALED_MASKED:
        dgrad = scaled_masked_softmax_bwd(dz, probs, scale_factor)
    else:
        dgrad = scaled_softmax_bwd(dz, probs, scale_factor)

    return (dgrad, None)


89
_softmax.defvjp(_softmax_fwd_rule, _softmax_bwd_rule)