"vscode:/vscode.git/clone" did not exist on "aa06107cbc1cc7378c665809c1608c53070447ea"
Unverified commit 1e2c68d6, authored by Phuong Nguyen, committed by GitHub

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)



* add amax input to DBiasQuantizePrimitive and FFI

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make sure amax is init with zero

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a282136c
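
The change in a nutshell: amax becomes an explicit, zero-initialized operand of the quantize primitive, and the FFI writes the updated value back through a donated buffer instead of allocating and zeroing it inside the kernel. A minimal pure-JAX model of the new calling convention (a sketch, not the actual TE kernel; the FP8 cast and dbias reduction stand in for the fused CUDA implementation):

import jax.numpy as jnp

def dbias_quantize_reference(x, scale, amax):
    # stand-in for the fused kernel: quantize, reduce dbias, and fold the
    # running max-abs of x into the caller-provided amax
    q = (x.astype(jnp.float32) * scale).astype(jnp.float8_e4m3fn)
    dbias = x.sum(axis=0)
    updated_amax = jnp.maximum(amax, jnp.abs(x).max().astype(jnp.float32))
    return q, dbias, updated_amax

x = jnp.ones((16, 32), jnp.bfloat16)
q, dbias, updated_amax = dbias_quantize_reference(
    x, jnp.ones((1,), jnp.float32), jnp.zeros((1,), jnp.float32)
)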
@@ -57,14 +57,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
     name = "te_dbias_quantize_ffi"
     multiple_results = True
     impl_static_args = (
-        2,
         3,
         4,
         5,
         6,
         7,
         8,
-    ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
+        9,
+    ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer, amax_aval
     inner_primitive = None
     outer_primitive = None
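
For orientation (inferred from this hunk, not new behavior): the traced operands now occupy positions 0-2, so the seven static kwargs shift from indices 2-8 to 3-9.

# positions after the change: x=0, scale=1, amax=2, then the static kwargs
OPERANDS = ("x", "scale", "amax")
STATICS = ("out_dtype", "scaling_mode", "q_layout", "flatten_axis",
           "scale_dtype", "is_dbias", "is_outer")
assert tuple(range(len(OPERANDS), len(OPERANDS) + len(STATICS))) == (3, 4, 5, 6, 7, 8, 9)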
@@ -72,6 +72,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
     def abstract(
         x_aval,
         scale_aval,
+        amax_aval,
         *,
         out_dtype,
         scaling_mode,
@@ -95,7 +96,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         rowwise_out_shape = (1,)
         rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
-        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
+        updated_amax_aval = amax_aval
         rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
             scaling_mode
@@ -168,6 +169,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         ctx,
         x,
         scale,
+        amax,
         *,
         out_dtype,
         scaling_mode,
@@ -181,13 +183,17 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         te_dbias_quantize_p lowering rules
         """
         del out_dtype, scale_dtype, is_outer
-        x_aval, scale_aval = ctx.avals_in
+        x_aval, scale_aval, amax_aval = ctx.avals_in
         assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
-        assert scale_aval.dtype == jnp.float32
-        return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)(
+        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
+        return ffi.ffi_lowering(
+            BaseDBiasQuantizePrimitive.name,
+            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
+        )(
             ctx,
             x,
             scale,
+            amax,
             scaling_mode=scaling_mode.value,
             q_layout=q_layout,
             flatten_axis=flatten_axis,
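
Why {2: 4}: operand 2 is the new amax input and, judging from the batching and Shardy rules later in this diff, updated_amax is the fifth result. A quick self-check of that mapping:

# result order inferred from out_bdims and the sharding rule below
RESULTS = ("out", "colwise_out", "scale_inv", "colwise_scale_inv",
           "updated_amax", "dbias")
OPERANDS = ("x", "scale", "amax")
assert {OPERANDS.index("amax"): RESULTS.index("updated_amax")} == {2: 4}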
@@ -198,6 +204,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
     def impl(
         x,
         scale,
+        amax,
         out_dtype,
         scaling_mode,
         q_layout,
@@ -222,6 +229,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
             x,
             scale,
+            amax,
             out_dtype=out_dtype,
             scaling_mode=scaling_mode,
             q_layout=q_layout,
@@ -268,15 +276,15 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         del is_outer
         check_valid_batch_dims(batch_dims)
         assert BaseDBiasQuantizePrimitive.outer_primitive is not None
-        x, scale = batched_args
-        x_bdim, scale_bdim = batch_dims
-        amax_bdim = scale_bdim
+        x, scale, amax = batched_args
+        x_bdim, scale_bdim, amax_bdim = batch_dims
         out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
         return (
             BaseDBiasQuantizePrimitive.outer_primitive.bind(
                 x,
                 scale,
+                amax,
                 out_dtype=out_dtype,
                 scaling_mode=scaling_mode,
                 q_layout=q_layout,
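
The batcher previously borrowed scale's batch dim for amax; each operand now reports its own. A toy consistency check, with hypothetical vmap dims:

x_bdim, scale_bdim, amax_bdim = 0, None, None  # hypothetical dims
out_bdims = (x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim)
assert len(out_bdims) == 6  # out, colwise_out, scale_inv, colwise_scale_inv,
                            # updated_amax, dbias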
@@ -303,7 +311,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         del (out_dtype, result_infos, scale_dtype, is_outer)  # Unused.
         x_spec = get_padded_spec(arg_infos[0])
         scale_spec = get_padded_spec(arg_infos[1])
+        amax_spec = get_padded_spec(arg_infos[2])
         out_sharding = NamedSharding(
             mesh,
             PartitionSpec(*x_spec),
@@ -329,10 +337,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
             desc="BaseDBiasQuantizePrimitive.dbias_sharding",
         )
-        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
-        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
-            scale_inv_spec = amax_spec = scale_spec
-        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
+        scale_inv_spec = colwise_scale_inv_spec = (None,)
+        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
             scale_inv_spec = x_spec
         if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
@@ -341,14 +347,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         scale_inv_sharding = NamedSharding(
             mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
         )
-        amax_sharding = NamedSharding(
-            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
-        )
         colwise_scale_inv_sharding = NamedSharding(
             mesh,
             PartitionSpec(*colwise_scale_inv_spec),
             desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
         )
+        amax_sharding = NamedSharding(
+            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
+        )
         return (
             out_sharding,
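
Since amax is a (1,)-shaped scalar buffer, its sharding amounts to replication on any mesh; amax_sharding is now built from the operand's own spec rather than from scale's. A runnable sketch on a single-device mesh (hypothetical axis name "tp"; plain jax.sharding.NamedSharding, without TE's desc= wrapper):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ("tp",))
amax_spec = (None,)  # replicated, matching the (None,) default above
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec))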
@@ -375,7 +381,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         del result_infos, is_outer
         x_spec = get_padded_spec(arg_infos[0])
         scale_spec = get_padded_spec(arg_infos[1])
+        amax_spec = get_padded_spec(arg_infos[2])
         out_sharding = NamedSharding(
             mesh,
             PartitionSpec(*x_spec),
@@ -401,10 +407,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
             desc="BaseDBiasQuantizePrimitive.dbias_sharding",
         )
-        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
-        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
-            scale_inv_spec = amax_spec = scale_spec
-        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
+        scale_inv_spec = colwise_scale_inv_spec = (None,)
+        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
             scale_inv_spec = x_spec
         if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
@@ -432,7 +436,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
             dbias_sharding,
         )

-        def sharded_impl(x, scale):
+        def sharded_impl(x, scale, amax):
            (
                local_x,
                local_colwise_x,
@@ -443,6 +447,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
             ) = BaseDBiasQuantizePrimitive.impl(
                 x,
                 scale,
+                amax,
                 out_dtype=out_dtype,
                 scaling_mode=scaling_mode,
                 q_layout=q_layout,
@@ -510,7 +515,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         amax = (prefix + "amax",)
         return SdyShardingRule(
-            (x_axes, ("…1",)),
+            (x_axes, ("…1",), amax),
             (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
         )
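
On the Shardy rule fix: each operand and result maps to a tuple of factor names, and listing amax as an operand factor ties the input's sharding to the updated_amax result. A notation-only sketch of that shape (hypothetical factor names, abridged result list, not the TE rule itself):

prefix = "DBiasQuantize_"
amax = (prefix + "amax",)
operand_mappings = (("batch", "hidden"), ("…1",), amax)  # x, scale, amax
result_mappings = (("batch", "hidden"), amax)            # out, updated_amax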
@@ -638,6 +643,9 @@ def _quantize_dbias_impl(
     elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
         scale = quantizer.scale

+    # Make sure amax is init with zero
+    amax = jnp.zeros((1,), jnp.float32)
+
     # It is faster to use 1x quantization for tensor scaling
     is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
     force_1x_quantization = (
@@ -659,6 +667,7 @@ def _quantize_dbias_impl(
     ) = PrimitiveClass.outer_primitive.bind(
         x,
         scale,
+        amax,
         out_dtype=quantizer.q_dtype,
         scaling_mode=quantizer.scaling_mode.value,
         q_layout=q_layout.value,
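
The zero-init matters because the kernel folds the incoming buffer into the running max rather than overwriting it; the nvte_memset that used to do this inside the FFI is removed in the C++ hunk below. A small demonstration of the invariant, in plain jax.numpy:

import jax.numpy as jnp

x = jnp.array([0.5, -2.0, 1.0])
fresh = jnp.zeros((1,), jnp.float32)   # correct: updated amax == 2.0
stale = jnp.full((1,), 7.0)            # stale value would leak: == 7.0
print(jnp.maximum(fresh, jnp.abs(x).max()))
print(jnp.maximum(stale, jnp.abs(x).max()))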
@@ -72,9 +72,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
 }

 Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
-                            Result_Type output_buf, Result_Type output_trans_buf,
-                            Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
-                            Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
+                            Buffer_Type amax_buf, Result_Type output_buf,
+                            Result_Type output_trans_buf, Result_Type scale_inv_buf,
+                            Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
+                            Result_Type dbias_buf, Result_Type workspace_buf,
                             JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
                             bool is_dbias, int64_t flatten_axis) {
   auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
@@ -119,11 +120,10 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
   if (is_fp8_dtype(out_dtype)) {
     if (is_tensor_scaling) {
       float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
-      float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
+      float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
       NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
       NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
       output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
-      nvte_memset(amax, 0, sizeof(float), stream);
       output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
       output_tensor.set_rowwise_scale_inv(
           scale_inv_buf->untyped_data(),
@@ -183,6 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
                                   .Ctx<FFI_Stream_Type>()  // stream
                                   .Arg<Buffer_Type>()      // input
                                   .Arg<Buffer_Type>()      // scale
+                                  .Arg<Buffer_Type>()      // amax
                                   .Ret<Buffer_Type>()      // output
                                   .Ret<Buffer_Type>()      // colwise output
                                   .Ret<Buffer_Type>()      // scale_inv
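
For readers wiring a similar custom call outside TE, the same donation pattern is available through JAX's public FFI API. A hedged sketch (the target name "my_quantize" is hypothetical and not registered here, so only constructing the call is shown):

import jax
import jax.numpy as jnp

out_types = (jax.ShapeDtypeStruct((1,), jnp.float32),)  # updated_amax only
call = jax.ffi.ffi_call(
    "my_quantize",
    out_types,
    input_output_aliases={0: 0},  # donate operand 0 (amax) to result 0
)
# invoking `call` requires registering the "my_quantize" FFI target first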