# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
from typing import Tuple
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine_jax
from transformer_engine_jax import DType as TEDType

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    is_ffi_enabled,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = ["cast_fp8"]


def _jax_quantize(x, scale, q_dtype):
    """
    Quantize with scale
    """
    compute_dtype = scale.dtype
    dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype)
    scaled_x = x.astype(compute_dtype) * scale
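    # Clip to the finite range of q_dtype so out-of-range values saturate at the
    # largest representable magnitude instead of overflowing during the cast.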
    clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
    return clipped_scaled_x.astype(q_dtype)


def _jax_cast_fp8(inputs, scale, amax, out_dtype):
    """
    JAX native fp8 casting implementation
    """
    casted_output = _jax_quantize(inputs, scale, q_dtype=out_dtype)
    updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
    return casted_output, updated_amax


class CastFP8Primitive(BasePrimitive):
    """
    FP8 cast (quantize) primitive: scales and casts a higher-precision tensor to
    FP8 and updates the running amax.
    """

    name = "te_quantize"
    multiple_results = True
    impl_static_args = (4,)
    inner_primitive = None
    outer_primitive = None
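    # `impl_static_args = (4,)` marks `out_dtype` (the fifth positional argument of
    # `impl`) as a static, compile-time argument rather than a traced operand.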

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_cast abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_cast lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
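
        # Two lowering paths: the XLA FFI custom call on newer JAX, or the legacy
        # custom call that packs shape/dtype information into an opaque descriptor.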
        if is_ffi_enabled():
            name = "te_quantize_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
                ctx, x, amax, scale, scale_inv
            )
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape

            out_types = [
                ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [x, amax, scale, scale_inv]
            operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_common_descriptor(
                ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype)
            )

            out = custom_caller(
                CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
            )

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        te_cast implementation
        """
        assert CastFP8Primitive.inner_primitive is not None
        casted_x, updated_amax = CastFP8Primitive.inner_primitive.bind(
            x, amax, scale, scale_inv, out_dtype=out_dtype
        )
        return casted_x, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        check_valid_batch_dims(batch_dims)
        assert CastFP8Primitive.outer_primitive is not None

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims
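        # Only the batch dims of x and amax are propagated to the outputs; the
        # batch dims reported for scale and scale_inv are ignored.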

        out_bdims = x_bdim, amax_bdim
        return (
            CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        del out_dtype, result_infos
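        # Each output keeps the sharding of its corresponding operand: the casted
        # tensor follows x, the updated amax follows the input amax.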
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_updated_amax = CastFP8Primitive.impl(
                x, amax, scale, scale_inv, out_dtype=out_dtype
            )
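            # Each shard has only seen its local slice of x, so reduce the local
            # amax across every mesh axis except pipeline parallelism to obtain
            # the global amax.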
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

            return local_cx, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastFP8Primitive)


def cast_fp8(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    if not CastFP8Primitive.enabled():
        return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype)
    return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
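

# A minimal usage sketch (illustrative only; the shapes, scale initialization, and
# the choice of `jnp.float8_e4m3fn` below are assumptions, not part of this API):
#
#     x = jax.random.normal(jax.random.PRNGKey(0), (128, 256), dtype=jnp.bfloat16)
#     amax = jnp.zeros((1,), dtype=jnp.float32)    # running max of |x|
#     scale = jnp.ones((1,), dtype=jnp.float32)    # quantization scale
#     scale_inv = 1.0 / scale                      # dequantization scale
#     x_fp8, new_amax = cast_fp8(x, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn)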