# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
from typing import Tuple

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    is_ffi_enabled,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP


__all__ = ["cast_fp8"]


def _jax_quantize(x, scale, q_dtype):
    """
    Quantize with scale
    """
    compute_dtype = scale.dtype
    dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype)
    scaled_x = x.astype(compute_dtype) * scale
    clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
    return clipped_scaled_x.astype(q_dtype)


def _jax_cast_fp8(inputs, scale, amax, out_dtype):
    """
    JAX native fp8 casting implementation
    """
    casted_output = _jax_quantize(inputs, scale, q_dtype=out_dtype)
    updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
    return casted_output, updated_amax
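
# Illustrative sketch (not library documentation): exercising the pure-JAX fallback
# above directly. The shapes, values, and FP8 dtype are assumptions.
#
#   x = jnp.full((16, 16), 2.0, dtype=jnp.bfloat16)
#   scale = jnp.float32(1.0)      # scaling factor applied before the cast
#   amax = jnp.float32(0.0)       # running abs-max to be updated
#   y, new_amax = _jax_cast_fp8(x, scale, amax, out_dtype=jnp.float8_e4m3fn)
#   # y.dtype == jnp.float8_e4m3fn and new_amax == 2.0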


class CastFP8Primitive(BasePrimitive):
    """
    Cast Primitive
    """

    name = "te_quantize"
    multiple_results = True
    impl_static_args = (4,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_cast abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_cast lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
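
        # amax (operand index 1) is aliased to the second output, so the kernel updates
        # the amax buffer in place. When the XLA FFI is available, the "te_quantize_ffi"
        # handler is used; otherwise the legacy custom-call descriptor path below is taken.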
        if is_ffi_enabled():
            name = "te_quantize_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
                ctx, x, amax, scale, scale_inv
            )
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape

            out_types = [
                ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [x, amax, scale, scale_inv]
            operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_common_descriptor(
                ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype)
            )

            out = custom_caller(
                CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
            )

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        te_cast implementation
        """
        assert CastFP8Primitive.inner_primitive is not None
        casted_x, updated_amax = CastFP8Primitive.inner_primitive.bind(
            x, amax, scale, scale_inv, out_dtype=out_dtype
        )
        return casted_x, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        check_valid_batch_dims(batch_dims)
        assert CastFP8Primitive.outer_primitive is not None

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return (
            CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        del out_dtype, result_infos
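        # The casted FP8 output inherits the input's sharding spec; amax keeps its own.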
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_updated_amax = CastFP8Primitive.impl(
                x, amax, scale, scale_inv, out_dtype=out_dtype
            )
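            # Take the max of the local amax across every mesh axis except the
            # pipeline-parallel axis so all replicas see the same global amax.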
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

            return local_cx, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastFP8Primitive)


def cast_fp8(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    if not CastFP8Primitive.enabled():
        return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype)
    return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
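

# Illustrative usage sketch (an assumption for documentation purposes, not part of the
# library): calling the wrapper with hypothetical FP8 metadata buffers. In Transformer
# Engine these buffers normally come from the FP8 meta state managed elsewhere; the
# scalar shapes and the FP8 dtype below are assumptions.
#
#   x = jnp.ones((128, 128), dtype=jnp.bfloat16)
#   amax = jnp.zeros((), dtype=jnp.float32)
#   scale = jnp.ones((), dtype=jnp.float32)
#   scale_inv = jnp.ones((), dtype=jnp.float32)
#   casted_x, updated_amax = cast_fp8(
#       x, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn
#   )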