Unverified Commit 1e2c68d6 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)



* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make sure amax is init with zero
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a282136c
...@@ -57,14 +57,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -57,14 +57,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
name = "te_dbias_quantize_ffi" name = "te_dbias_quantize_ffi"
multiple_results = True multiple_results = True
impl_static_args = ( impl_static_args = (
2,
3, 3,
4, 4,
5, 5,
6, 6,
7, 7,
8, 8,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer 9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer, amax_aval
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -72,6 +72,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -72,6 +72,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
def abstract( def abstract(
x_aval, x_aval,
scale_aval, scale_aval,
amax_aval,
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
...@@ -95,7 +96,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -95,7 +96,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
rowwise_out_shape = (1,) rowwise_out_shape = (1,)
rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) updated_amax_aval = amax_aval
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
...@@ -168,6 +169,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -168,6 +169,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
ctx, ctx,
x, x,
scale, scale,
amax,
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
...@@ -181,13 +183,17 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -181,13 +183,17 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
te_dbias_quantize_p lowering rules te_dbias_quantize_p lowering rules
""" """
del out_dtype, scale_dtype, is_outer del out_dtype, scale_dtype, is_outer
x_aval, scale_aval = ctx.avals_in x_aval, scale_aval, amax_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == amax_aval.dtype == jnp.float32
return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)( return ffi.ffi_lowering(
BaseDBiasQuantizePrimitive.name,
operand_output_aliases={2: 4}, # donate amax buffer to updated_amax
)(
ctx, ctx,
x, x,
scale, scale,
amax,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
q_layout=q_layout, q_layout=q_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
...@@ -198,6 +204,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -198,6 +204,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
def impl( def impl(
x, x,
scale, scale,
amax,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_layout, q_layout,
...@@ -222,6 +229,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -222,6 +229,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
) = BaseDBiasQuantizePrimitive.inner_primitive.bind( ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
x, x,
scale, scale,
amax,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_layout=q_layout, q_layout=q_layout,
...@@ -268,15 +276,15 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -268,15 +276,15 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
del is_outer del is_outer
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert BaseDBiasQuantizePrimitive.outer_primitive is not None assert BaseDBiasQuantizePrimitive.outer_primitive is not None
x, scale = batched_args x, scale, amax = batched_args
x_bdim, scale_bdim = batch_dims x_bdim, scale_bdim, amax_bdim = batch_dims
amax_bdim = scale_bdim
out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
return ( return (
BaseDBiasQuantizePrimitive.outer_primitive.bind( BaseDBiasQuantizePrimitive.outer_primitive.bind(
x, x,
scale, scale,
amax,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_layout=q_layout, q_layout=q_layout,
...@@ -303,7 +311,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -303,7 +311,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
del (out_dtype, result_infos, scale_dtype, is_outer) # Unused. del (out_dtype, result_infos, scale_dtype, is_outer) # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) amax_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec), PartitionSpec(*x_spec),
...@@ -329,10 +337,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -329,10 +337,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.dbias_sharding", desc="BaseDBiasQuantizePrimitive.dbias_sharding",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
...@@ -341,14 +347,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -341,14 +347,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
) )
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
)
colwise_scale_inv_sharding = NamedSharding( colwise_scale_inv_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*colwise_scale_inv_spec), PartitionSpec(*colwise_scale_inv_spec),
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
) )
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
)
return ( return (
out_sharding, out_sharding,
...@@ -375,7 +381,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -375,7 +381,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
del result_infos, is_outer del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) amax_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec), PartitionSpec(*x_spec),
...@@ -401,10 +407,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -401,10 +407,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.dbias_sharding", desc="BaseDBiasQuantizePrimitive.dbias_sharding",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
...@@ -432,7 +436,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -432,7 +436,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias_sharding, dbias_sharding,
) )
def sharded_impl(x, scale): def sharded_impl(x, scale, amax):
( (
local_x, local_x,
local_colwise_x, local_colwise_x,
...@@ -443,6 +447,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -443,6 +447,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
) = BaseDBiasQuantizePrimitive.impl( ) = BaseDBiasQuantizePrimitive.impl(
x, x,
scale, scale,
amax,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_layout=q_layout, q_layout=q_layout,
...@@ -510,7 +515,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -510,7 +515,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
amax = (prefix + "amax",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",)), (x_axes, ("…1",), amax),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
) )
...@@ -638,6 +643,9 @@ def _quantize_dbias_impl( ...@@ -638,6 +643,9 @@ def _quantize_dbias_impl(
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale scale = quantizer.scale
# Make sure amax is init with zero
amax = jnp.zeros((1,), jnp.float32)
# It is faster to use 1x quantization for tensor scaling # It is faster to use 1x quantization for tensor scaling
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
force_1x_quantization = ( force_1x_quantization = (
...@@ -659,6 +667,7 @@ def _quantize_dbias_impl( ...@@ -659,6 +667,7 @@ def _quantize_dbias_impl(
) = PrimitiveClass.outer_primitive.bind( ) = PrimitiveClass.outer_primitive.bind(
x, x,
scale, scale,
amax,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value, q_layout=q_layout.value,
......
...@@ -72,9 +72,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -72,9 +72,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
} }
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf, Buffer_Type amax_buf, Result_Type output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type output_trans_buf, Result_Type scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
bool is_dbias, int64_t flatten_axis) { bool is_dbias, int64_t flatten_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
...@@ -119,11 +120,10 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -119,11 +120,10 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (is_tensor_scaling) { if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv( output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(), scale_inv_buf->untyped_data(),
...@@ -183,6 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, ...@@ -183,6 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment