Unverified commit 127b6d3a, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Activation/Normalization to output amax for later quantization in CurrentScaling (#2238)



* reuse amax for current scaling
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 9f3e79bf
...@@ -148,7 +148,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -148,7 +148,6 @@ class ActLuPrimitive(BasePrimitive):
name = "te_act_lu_ffi" name = "te_act_lu_ffi"
multiple_results = True multiple_results = True
impl_static_args = ( impl_static_args = (
2,
3, 3,
4, 4,
5, 5,
...@@ -156,7 +155,11 @@ class ActLuPrimitive(BasePrimitive): ...@@ -156,7 +155,11 @@ class ActLuPrimitive(BasePrimitive):
7, 7,
8, 8,
9, 9,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params 10,
11,
12,
13,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -164,6 +167,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -164,6 +167,7 @@ class ActLuPrimitive(BasePrimitive):
def abstract( def abstract(
x_aval, x_aval,
scale_aval, scale_aval,
amax_aval,
*, *,
out_dtype, out_dtype,
act_enum, act_enum,
...@@ -171,16 +175,23 @@ class ActLuPrimitive(BasePrimitive): ...@@ -171,16 +175,23 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
): ):
""" """
te_act_lu_p abstract te_act_lu_p abstract
""" """
del act_enum, act_params del act_enum, act_params, amax_scope, transpose_batch_sequence
assert (
not output_amax_when_no_scaling or scaling_mode == ScalingMode.NO_SCALING.value
), f"scaling_mode = {scaling_mode}"
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
assert amax_aval is None or amax_aval.dtype == jnp.float32
assert x_aval.shape[-2] == act_len, ( assert x_aval.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape" "activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x_aval.shape} and act_len {act_len}" f" {x_aval.shape} and act_len {act_len}"
...@@ -215,6 +226,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -215,6 +226,7 @@ class ActLuPrimitive(BasePrimitive):
ctx, ctx,
x, x,
scale, scale,
amax,
*, *,
out_dtype, out_dtype,
act_enum, act_enum,
...@@ -222,24 +234,34 @@ class ActLuPrimitive(BasePrimitive): ...@@ -222,24 +234,34 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
): ):
""" """
te_gated_act_lu_p lowering rules te_gated_act_lu_p lowering rules
""" """
del out_dtype, scale_dtype, act_len, is_outer del out_dtype, scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence
x_aval, scale_aval = ctx.avals_in x_aval, scale_aval, amax_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
out = ffi.ffi_lowering(ActLuPrimitive.name)( assert amax_aval.dtype == jnp.float32
out = ffi.ffi_lowering(
ActLuPrimitive.name,
operand_output_aliases={2: 4}, # donate amax buffer to updated_amax
)(
ctx, ctx,
x, x,
scale, scale,
amax,
act_enum=act_enum, act_enum=act_enum,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
is_2x=is_2x, is_2x=is_2x,
act_params=act_params.to_ffi_lowering_dict(), act_params=act_params.to_ffi_lowering_dict(),
output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
return out return out
...@@ -247,14 +269,18 @@ class ActLuPrimitive(BasePrimitive): ...@@ -247,14 +269,18 @@ class ActLuPrimitive(BasePrimitive):
def impl( def impl(
x, x,
scale, scale,
amax,
out_dtype, out_dtype,
act_enum, act_enum,
act_len, act_len,
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
): ):
""" """
to describe implementation to describe implementation
...@@ -266,14 +292,18 @@ class ActLuPrimitive(BasePrimitive): ...@@ -266,14 +292,18 @@ class ActLuPrimitive(BasePrimitive):
ActLuPrimitive.inner_primitive.bind( ActLuPrimitive.inner_primitive.bind(
x, x,
scale, scale,
amax,
out_dtype=out_dtype, out_dtype=out_dtype,
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
is_outer=False,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=False,
) )
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
...@@ -301,17 +331,19 @@ class ActLuPrimitive(BasePrimitive): ...@@ -301,17 +331,19 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
): ):
""" """
to describe batch rules for vmap to describe batch rules for vmap
""" """
del act_len, is_outer
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert ActLuPrimitive.outer_primitive is not None assert ActLuPrimitive.outer_primitive is not None
x, scale = batched_args x, scale, amax = batched_args
x_bdim, scale_bdim = batch_dims x_bdim, scale_bdim, _ = batch_dims
amax_bdim = scale_bdim amax_bdim = scale_bdim
out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
...@@ -319,12 +351,18 @@ class ActLuPrimitive(BasePrimitive): ...@@ -319,12 +351,18 @@ class ActLuPrimitive(BasePrimitive):
ActLuPrimitive.outer_primitive.bind( ActLuPrimitive.outer_primitive.bind(
x, x,
scale, scale,
amax,
out_dtype=out_dtype, out_dtype=out_dtype,
act_enum=act_enum, act_enum=act_enum,
act_len=act_len,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=is_outer,
), ),
out_bdims, out_bdims,
) )
...@@ -337,8 +375,11 @@ class ActLuPrimitive(BasePrimitive): ...@@ -337,8 +375,11 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -349,8 +390,11 @@ class ActLuPrimitive(BasePrimitive): ...@@ -349,8 +390,11 @@ class ActLuPrimitive(BasePrimitive):
act_enum, act_enum,
scale_dtype, scale_dtype,
act_len, act_len,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
) # Unused. ) # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[1])
...@@ -402,13 +446,16 @@ class ActLuPrimitive(BasePrimitive): ...@@ -402,13 +446,16 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del result_infos, is_outer # Unused. del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[1])
...@@ -452,26 +499,40 @@ class ActLuPrimitive(BasePrimitive): ...@@ -452,26 +499,40 @@ class ActLuPrimitive(BasePrimitive):
amax_sharding, amax_sharding,
) )
def sharded_impl(x, scale): def sharded_impl(x, scale, amax):
local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = ( (
ActLuPrimitive.impl( local_x,
x, local_colwise_x,
scale, local_scale_inv,
out_dtype=out_dtype, local_colwise_scale_inv,
act_enum=act_enum, local_updated_amax,
act_len=act_len, ) = ActLuPrimitive.impl(
scaling_mode=scaling_mode, x,
is_2x=is_2x, scale,
scale_dtype=scale_dtype, amax,
is_outer=True, out_dtype=out_dtype,
act_params=act_params, act_enum=act_enum,
) act_len=act_len,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
) )
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(
local_updated_amax, mesh
)
elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
local_updated_amax, out_spec, transpose_batch_sequence, mesh
)
else: else:
global_updated_amax = local_amax global_updated_amax = local_updated_amax
return ( return (
local_x, local_x,
...@@ -491,13 +552,28 @@ class ActLuPrimitive(BasePrimitive): ...@@ -491,13 +552,28 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh, mesh,
value_types, value_types,
result_types, result_types,
): ):
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params del (
out_dtype,
act_enum,
act_len,
scale_dtype,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
result_types,
)
prefix = "ActLu_" prefix = "ActLu_"
input_shape = value_types[0].shape input_shape = value_types[0].shape
output_shape = input_shape[:-2] + input_shape[-1:] output_shape = input_shape[:-2] + input_shape[-1:]
...@@ -526,6 +602,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -526,6 +602,7 @@ class ActLuPrimitive(BasePrimitive):
( (
x_axes, x_axes,
("…1",), ("…1",),
amax,
), ),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax),
**scale_rules.factor_sizes, **scale_rules.factor_sizes,
...@@ -543,8 +620,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -543,8 +620,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
name = "te_dact_dbias_quantize_ffi" name = "te_dact_dbias_quantize_ffi"
multiple_results = True multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -553,6 +630,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -553,6 +630,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz_aval, dz_aval,
x_aval, x_aval,
scale_aval, scale_aval,
amax_aval,
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
...@@ -561,13 +639,16 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -561,13 +639,16 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
): ):
""" """
te_dact_dbias_quantize_p abstract te_dact_dbias_quantize_p abstract
""" """
del act_enum, act_params del act_enum, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling
dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_dtype assert x_aval.dtype == dz_dtype
...@@ -576,6 +657,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -576,6 +657,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
f" {x_aval.shape} and act_len {act_len}" f" {x_aval.shape} and act_len {act_len}"
) )
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
assert amax_aval.dtype == jnp.float32
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, ( assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not supported for fused dact and quantization. Please do" "Current tensor scaling is not supported for fused dact and quantization. Please do"
...@@ -655,6 +737,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -655,6 +737,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz, dz,
x, x,
scale, scale,
amax,
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
...@@ -663,27 +746,42 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -663,27 +746,42 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
): ):
""" """
te_dact_dbias_quantize_p lowering rules te_dact_dbias_quantize_p lowering rules
""" """
del out_dtype, scale_dtype, act_len, is_outer del (
dz_aval, x_aval, scale_aval = ctx.avals_in out_dtype,
scale_dtype,
act_len,
is_outer,
amax_scope,
transpose_batch_sequence,
)
dz_aval, x_aval, scale_aval, amax_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype assert x_aval.dtype == dz_aval.dtype
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == amax_aval.dtype == jnp.float32
return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)( return ffi.ffi_lowering(
BaseDActLuDBiasQuantizePrimitive.name,
operand_output_aliases={3: 4}, # donate amax buffer to updated_amax
)(
ctx, ctx,
dz, dz,
x, x,
scale, scale,
amax,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
is_2x=is_2x, is_2x=is_2x,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=int(act_enum), act_enum=int(act_enum),
act_params=act_params.to_ffi_lowering_dict(), act_params=act_params.to_ffi_lowering_dict(),
output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
@staticmethod @staticmethod
...@@ -691,6 +789,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -691,6 +789,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz, dz,
x, x,
scale, scale,
amax,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
is_2x, is_2x,
...@@ -698,8 +797,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -698,8 +797,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
): ):
""" """
te_dact_dbias_quantize_p impl te_dact_dbias_quantize_p impl
...@@ -711,6 +813,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -711,6 +813,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz, dz,
x, x,
scale, scale,
amax,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
...@@ -718,8 +821,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -718,8 +821,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
is_outer=False,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=False,
) )
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
...@@ -747,17 +853,19 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -747,17 +853,19 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
): ):
""" """
to describe batch rules for vmap to describe batch rules for vmap
""" """
del is_outer
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
dz, x, scale = batched_args dz, x, scale, amax = batched_args
_, x_bdim, scale_bdim = batch_dims _, x_bdim, scale_bdim, _ = batch_dims
out_bdims = ( out_bdims = (
x_bdim, # rowwise output x_bdim, # rowwise output
...@@ -772,6 +880,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -772,6 +880,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz, dz,
x, x,
scale, scale,
amax,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
...@@ -780,6 +889,10 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -780,6 +889,10 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=is_outer,
), ),
out_bdims, out_bdims,
) )
...@@ -793,14 +906,18 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -793,14 +906,18 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del out_dtype, result_infos, act_enum, act_params del out_dtype, result_infos, act_enum, act_params, output_amax_when_no_scaling
del scale_dtype, act_len, is_outer del scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2]) scale_spec = get_padded_spec(arg_infos[2])
...@@ -869,8 +986,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -869,8 +986,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -937,12 +1057,13 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -937,12 +1057,13 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dbias_sharding, dbias_sharding,
) )
def sharded_impl(dz, x, scale): def sharded_impl(dz, x, scale, amax):
(out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = ( (out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias) = (
BaseDActLuDBiasQuantizePrimitive.impl( BaseDActLuDBiasQuantizePrimitive.impl(
dz, dz,
x, x,
scale, scale,
amax,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
...@@ -950,8 +1071,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -950,8 +1071,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
is_outer=True,
act_params=act_params, act_params=act_params,
output_amax_when_no_scaling=output_amax_when_no_scaling,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
is_outer=True,
) )
) )
if is_dbias: if is_dbias:
...@@ -960,9 +1084,15 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -960,9 +1084,15 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
global_dbias = local_dbias global_dbias = local_dbias
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(
local_updated_amax, mesh
)
elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
local_updated_amax, x_spec, transpose_batch_sequence, mesh
)
else: else:
global_updated_amax = local_amax global_updated_amax = local_updated_amax
return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias
...@@ -977,14 +1107,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -977,14 +1107,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
is_outer,
act_params, act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh, mesh,
value_types, value_types,
result_types, result_types,
): ):
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params del (
out_dtype,
scale_dtype,
act_enum,
act_len,
act_params,
is_outer,
output_amax_when_no_scaling,
mesh,
result_types,
amax_scope,
transpose_batch_sequence,
)
prefix = "DActLuDBias_" prefix = "DActLuDBias_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
...@@ -1006,7 +1152,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -1006,7 +1152,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
amax = (prefix + "amax",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(dz_axes, x_axes, ("…2",)), (dz_axes, x_axes, ("…2",), amax),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes, **scale_rules.factor_sizes,
) )
...@@ -1092,6 +1238,8 @@ def act_lu( ...@@ -1092,6 +1238,8 @@ def act_lu(
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None, act_params: Optional[ActivationParams] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization. """Activation with optional quantization.
...@@ -1108,6 +1256,8 @@ def act_lu( ...@@ -1108,6 +1256,8 @@ def act_lu(
If quantizer is provided: If quantizer is provided:
A ScaledTensor containing the quantized activated input. A ScaledTensor containing the quantized activated input.
""" """
# TODO(Phuong): remove the output_amax_when_no_scaling exposure by introducing _act_lu_impl()
# Do the same with dact_dbias_quantize() and layernorm_fwd()
act_type_id = ActivationEnum[activation_type].value act_type_id = ActivationEnum[activation_type].value
act_len = len(activation_type) act_len = len(activation_type)
assert x.shape[-2] == act_len, ( assert x.shape[-2] == act_len, (
...@@ -1123,30 +1273,44 @@ def act_lu( ...@@ -1123,30 +1273,44 @@ def act_lu(
return _jax_act_lu(x, activation_type, quantizer, act_params) return _jax_act_lu(x, activation_type, quantizer, act_params)
# TE/common does not support 2x quantization for DelayedScaling yet # TE/common does not support 2x quantization for DelayedScaling yet
war_output = try_apply_delayed_scaling_2x_war( war_output = try_apply_delayed_scaling_2x_war(
f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params f=act_lu,
x=x,
activation_type=activation_type,
quantizer=quantizer,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
if war_output is not None: if war_output is not None:
return war_output return war_output
scale = jnp.empty((1,), jnp.float32) scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-2], x.shape[-1]) output_shape = (*x.shape[:-2], x.shape[-1])
amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,)
if quantizer is None: if quantizer is None:
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( out, _, _, _, updated_amax = ActLuPrimitive.outer_primitive.bind(
x, x,
scale, scale,
amax,
out_dtype=x.dtype, out_dtype=x.dtype,
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
is_outer=True,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
) )
out = out.reshape(output_shape) out = out.reshape(output_shape)
# TODO(Phuong): ScaledTensorFactory to create NoScaledTensor
out = NoScaleTensor( out = NoScaleTensor(
data=out, data=out,
amax=None, amax=updated_amax if output_amax_when_no_scaling else None,
) )
return out return out
...@@ -1157,6 +1321,9 @@ def act_lu( ...@@ -1157,6 +1321,9 @@ def act_lu(
activation_type=activation_type, activation_type=activation_type,
quantizer=None, quantizer=None,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
) )
out, _ = _quantize_dbias_impl( out, _ = _quantize_dbias_impl(
out, out,
...@@ -1164,6 +1331,7 @@ def act_lu( ...@@ -1164,6 +1331,7 @@ def act_lu(
quantizer=quantizer, quantizer=quantizer,
dq_dtype=x.dtype, dq_dtype=x.dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
) )
return out return out
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
...@@ -1178,14 +1346,18 @@ def act_lu( ...@@ -1178,14 +1346,18 @@ def act_lu(
) = ActLuPrimitive.outer_primitive.bind( ) = ActLuPrimitive.outer_primitive.bind(
x, x,
scale, scale,
amax,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
is_outer=True,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
...@@ -1209,6 +1381,9 @@ def quantize_dact_dbias( ...@@ -1209,6 +1381,9 @@ def quantize_dact_dbias(
is_dbias: bool = True, is_dbias: bool = True,
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None, act_params: Optional[ActivationParams] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]: ) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization. """Compute gradients of activation and bias with optional quantization.
...@@ -1232,7 +1407,8 @@ def quantize_dact_dbias( ...@@ -1232,7 +1407,8 @@ def quantize_dact_dbias(
f" {x.shape} and act_len {act_len}" f" {x.shape} and act_len {act_len}"
) )
scale = jnp.empty((), jnp.float32) scale = jnp.empty((1,), jnp.float32)
amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,)
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type]
PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
if not PrimitiveClass.enabled() or ( if not PrimitiveClass.enabled() or (
...@@ -1240,10 +1416,11 @@ def quantize_dact_dbias( ...@@ -1240,10 +1416,11 @@ def quantize_dact_dbias(
): ):
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params)
if quantizer is None: if quantizer is None:
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind(
dz, dz,
x, x,
scale, scale,
amax,
# outputs float32 for dbias accumulation # outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype), out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset # default value for no scaling, TE/common ignore this value when scale is unset
...@@ -1253,8 +1430,11 @@ def quantize_dact_dbias( ...@@ -1253,8 +1430,11 @@ def quantize_dact_dbias(
is_dbias=False, is_dbias=False,
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
is_outer=True,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
) )
output = output.astype(x.dtype) output = output.astype(x.dtype)
dbias = None dbias = None
...@@ -1263,7 +1443,7 @@ def quantize_dact_dbias( ...@@ -1263,7 +1443,7 @@ def quantize_dact_dbias(
output = NoScaleTensor( output = NoScaleTensor(
data=output, data=output,
amax=None, amax=updated_amax if output_amax_when_no_scaling else None,
) )
return output, dbias return output, dbias
...@@ -1275,9 +1455,18 @@ def quantize_dact_dbias( ...@@ -1275,9 +1455,18 @@ def quantize_dact_dbias(
activation_type, activation_type,
quantizer=None, quantizer=None,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
return _quantize_dbias_impl( return _quantize_dbias_impl(
out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 out.data,
quantizer,
is_dbias=True,
dq_dtype=x.dtype,
flatten_axis=-2,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
) )
is_gated = act_len == 2 is_gated = act_len == 2
...@@ -1292,6 +1481,9 @@ def quantize_dact_dbias( ...@@ -1292,6 +1481,9 @@ def quantize_dact_dbias(
quantizer=quantizer, quantizer=quantizer,
flatten_axis=-2, flatten_axis=-2,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
if war_output is not None: if war_output is not None:
return war_output return war_output
...@@ -1304,9 +1496,18 @@ def quantize_dact_dbias( ...@@ -1304,9 +1496,18 @@ def quantize_dact_dbias(
activation_type=activation_type, activation_type=activation_type,
quantizer=None, quantizer=None,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
) )
out, dbias = _quantize_dbias_impl( out, dbias = _quantize_dbias_impl(
out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 out,
is_dbias=is_dbias,
quantizer=quantizer,
dq_dtype=x.dtype,
flatten_axis=-2,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
) )
return out, dbias return out, dbias
...@@ -1320,9 +1521,17 @@ def quantize_dact_dbias( ...@@ -1320,9 +1521,17 @@ def quantize_dact_dbias(
x.astype(jnp.float32), x.astype(jnp.float32),
activation_type=activation_type, activation_type=activation_type,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
) )
out, dbias = _quantize_dbias_impl( out, dbias = _quantize_dbias_impl(
dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 dgated,
quantizer,
is_dbias=True,
dq_dtype=x.dtype,
flatten_axis=-2,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
) )
return out, dbias return out, dbias
...@@ -1337,6 +1546,7 @@ def quantize_dact_dbias( ...@@ -1337,6 +1546,7 @@ def quantize_dact_dbias(
dz, dz,
x, x,
scale, scale,
amax,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), is_2x=quantizer.is_2x2x(),
...@@ -1344,8 +1554,11 @@ def quantize_dact_dbias( ...@@ -1344,8 +1554,11 @@ def quantize_dact_dbias(
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
is_outer=True,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
) )
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
...@@ -1375,6 +1588,9 @@ def dact_lu( ...@@ -1375,6 +1588,9 @@ def dact_lu(
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None, act_params: Optional[ActivationParams] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> Union[jnp.ndarray, ScaledTensor]:
""" """
Backward pass for activation with optional quantization. Backward pass for activation with optional quantization.
...@@ -1396,5 +1612,8 @@ def dact_lu( ...@@ -1396,5 +1612,8 @@ def dact_lu(
is_dbias=False, is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
act_params=act_params, act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
return output return output
...@@ -92,7 +92,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -92,7 +92,7 @@ class NormFwdPrimitive(BasePrimitive):
name = "te_norm_forward_ffi" name = "te_norm_forward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11) impl_static_args = (5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -100,6 +100,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -100,6 +100,7 @@ class NormFwdPrimitive(BasePrimitive):
def abstract( def abstract(
x_aval, x_aval,
scale_aval, scale_aval,
amax_aval,
gamma_aval, gamma_aval,
beta_aval, beta_aval,
*, *,
...@@ -110,15 +111,27 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -110,15 +111,27 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer, is_outer,
): ):
""" """
LayerNorm fwd inner primitive abstract LayerNorm fwd inner primitive abstract
""" """
del amax_scope, transpose_batch_sequence
assert not output_amax_when_no_scaling or (
scaling_mode == ScalingMode.NO_SCALING.value
and not is_norm_fwd_cudnn_enabled(scaling_mode)
), (
f"scaling_mode = {scaling_mode},"
f" use_cudnn_norm_fwd={is_norm_fwd_cudnn_enabled(scaling_mode)}"
)
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
assert amax_aval is None or amax_aval.dtype == jnp.float32
assert ( assert (
scaling_mode != ScalingMode.MXFP8_1D_SCALING.value scaling_mode != ScalingMode.MXFP8_1D_SCALING.value
...@@ -220,6 +233,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -220,6 +233,7 @@ class NormFwdPrimitive(BasePrimitive):
ctx, ctx,
x, x,
scale, scale,
amax,
gamma, gamma,
beta, beta,
*, *,
...@@ -230,16 +244,20 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -230,16 +244,20 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer, is_outer,
): ):
""" """
LayerNorm fwd lowering rules LayerNorm fwd lowering rules
""" """
del out_dtype, scale_dtype, is_outer del out_dtype, scale_dtype, is_outer, amax_scope, transpose_batch_sequence
x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in x_aval, scale_aval, amax_aval, gamma_aval, beta_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
assert amax_aval is None or amax_aval.dtype == jnp.float32
g_type = ir.RankedTensorType(gamma.type) g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape g_shape = g_type.shape
...@@ -251,10 +269,14 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -251,10 +269,14 @@ class NormFwdPrimitive(BasePrimitive):
assert g_shape == b_shape assert g_shape == b_shape
sm_margin = get_forward_sm_margin() sm_margin = get_forward_sm_margin()
return ffi.ffi_lowering(NormFwdPrimitive.name)( return ffi.ffi_lowering(
NormFwdPrimitive.name,
operand_output_aliases={2: 4}, # amax <-> updated_amax
)(
ctx, ctx,
x, x,
scale, scale,
amax,
gamma, gamma,
beta, beta,
norm_type=norm_type.value, norm_type=norm_type.value,
...@@ -263,12 +285,14 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -263,12 +285,14 @@ class NormFwdPrimitive(BasePrimitive):
sm_margin=sm_margin, sm_margin=sm_margin,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
is_2x=is_2x, is_2x=is_2x,
output_amax_when_no_scaling=output_amax_when_no_scaling,
) )
@staticmethod @staticmethod
def impl( def impl(
x, x,
scale, scale,
amax,
gamma, gamma,
beta, beta,
norm_type, norm_type,
...@@ -278,6 +302,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -278,6 +302,9 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer, is_outer,
): ):
""" """
...@@ -297,6 +324,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -297,6 +324,7 @@ class NormFwdPrimitive(BasePrimitive):
) = NormFwdPrimitive.inner_primitive.bind( ) = NormFwdPrimitive.inner_primitive.bind(
x, x,
scale, scale,
amax,
gamma, gamma,
beta, beta,
norm_type=norm_type, norm_type=norm_type,
...@@ -306,6 +334,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -306,6 +334,9 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=False, is_outer=False,
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
...@@ -341,16 +372,18 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -341,16 +372,18 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer, is_outer,
): ):
""" """
to describe batch rules for vmap to describe batch rules for vmap
""" """
del is_outer
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert NormFwdPrimitive.outer_primitive is not None assert NormFwdPrimitive.outer_primitive is not None
x, scale, gamma, beta = batched_args x, scale, amax, gamma, beta = batched_args
x_bdim, scale_bdim, _, _ = batch_dims x_bdim, scale_bdim, _, _, _ = batch_dims
out_bdims = ( out_bdims = (
x_bdim, # rowwise output x_bdim, # rowwise output
...@@ -363,8 +396,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -363,8 +396,9 @@ class NormFwdPrimitive(BasePrimitive):
) )
return ( return (
NormFwdPrimitive.outer_primitive.bind( NormFwdPrimitive.outer_primitive.bind(
scale,
x, x,
scale,
amax,
gamma, gamma,
beta, beta,
norm_type=norm_type, norm_type=norm_type,
...@@ -374,6 +408,10 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -374,6 +408,10 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=is_outer,
), ),
out_bdims, out_bdims,
) )
...@@ -387,15 +425,19 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -387,15 +425,19 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer, is_outer,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del zero_centered_gamma, epsilon, out_dtype, result_infos del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, is_outer del scale_dtype, is_outer, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[1])
amax_spec = get_padded_spec(arg_infos[2])
out_spec = (*x_spec[:-1], None) out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
...@@ -415,9 +457,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -415,9 +457,9 @@ class NormFwdPrimitive(BasePrimitive):
mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
scale_inv_spec = amax_spec = (None,) scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec scale_inv_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec scale_inv_spec = out_spec
...@@ -445,6 +487,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -445,6 +487,9 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer, is_outer,
mesh, mesh,
arg_infos, arg_infos,
...@@ -453,8 +498,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -453,8 +498,9 @@ class NormFwdPrimitive(BasePrimitive):
del result_infos, is_outer del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[1])
g_spec = get_padded_spec(arg_infos[2]) amax_spec = get_padded_spec(arg_infos[2])
b_spec = get_padded_spec(arg_infos[3]) g_spec = get_padded_spec(arg_infos[3])
b_spec = get_padded_spec(arg_infos[4])
out_spec = (*x_spec[:-1], None) out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None: if x_spec[-1] is not None:
...@@ -485,9 +531,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -485,9 +531,9 @@ class NormFwdPrimitive(BasePrimitive):
mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
scale_inv_spec = amax_spec = (None,) scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec scale_inv_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec scale_inv_spec = out_spec
...@@ -499,10 +545,10 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -499,10 +545,10 @@ class NormFwdPrimitive(BasePrimitive):
arg_shardings = list(arg_i.sharding for arg_i in arg_infos) arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
# Enforce no sharding of hidden dim for x, gamma and beta # Enforce no sharding of hidden dim for x, gamma and beta
arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.x") arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.x")
arg_shardings[2] = NamedSharding( arg_shardings[3] = NamedSharding(
mesh, PartitionSpec(*g_spec[:-1], None), desc="NormFwdPrimitive.gamma" mesh, PartitionSpec(*g_spec[:-1], None), desc="NormFwdPrimitive.gamma"
) )
arg_shardings[3] = NamedSharding( arg_shardings[4] = NamedSharding(
mesh, PartitionSpec(*b_spec[:-1], None), desc="NormFwdPrimitive.beta" mesh, PartitionSpec(*b_spec[:-1], None), desc="NormFwdPrimitive.beta"
) )
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
...@@ -516,19 +562,20 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -516,19 +562,20 @@ class NormFwdPrimitive(BasePrimitive):
rsigma_sharding, rsigma_sharding,
) )
def sharded_impl(x, scale, gamma, beta): def sharded_impl(x, scale, amax, gamma, beta):
# expect tp and dp giving same shape, or tp being same shape as global # expect tp and dp giving same shape, or tp being same shape as global
( (
local_x, local_x,
local_colwise_x, local_colwise_x,
local_scale_inv, local_scale_inv,
local_colwise_scale_inv, local_colwise_scale_inv,
local_amax, local_updated_amax,
local_mu, local_mu,
local_rsigma, local_rsigma,
) = NormFwdPrimitive.impl( ) = NormFwdPrimitive.impl(
x, x,
scale, scale,
amax,
gamma, gamma,
beta, beta,
norm_type=norm_type, norm_type=norm_type,
...@@ -538,12 +585,21 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -538,12 +585,21 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True, is_outer=True,
) )
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(
local_updated_amax, mesh
)
elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
local_updated_amax, x_spec, transpose_batch_sequence, mesh
)
else: else:
global_updated_amax = local_amax global_updated_amax = local_updated_amax
return ( return (
local_x, local_x,
...@@ -566,6 +622,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -566,6 +622,9 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer, is_outer,
mesh, mesh,
value_types, value_types,
...@@ -576,6 +635,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -576,6 +635,9 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scale_dtype, scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer, is_outer,
mesh, mesh,
result_types, result_types,
...@@ -594,7 +656,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -594,7 +656,7 @@ class NormFwdPrimitive(BasePrimitive):
amax = (prefix + "amax",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",), ("…2",), ("…3",)), (x_axes, ("…1",), amax, ("…2",), ("…3",)),
( (
out, out,
colwise_out, colwise_out,
...@@ -882,6 +944,8 @@ def layernorm_fwd( ...@@ -882,6 +944,8 @@ def layernorm_fwd(
epsilon: float, epsilon: float,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]: ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]:
"""Layer normalization forward pass with optional quantization. """Layer normalization forward pass with optional quantization.
...@@ -896,6 +960,7 @@ def layernorm_fwd( ...@@ -896,6 +960,7 @@ def layernorm_fwd(
epsilon: Small constant for numerical stability. epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.
transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -918,10 +983,12 @@ def layernorm_fwd( ...@@ -918,10 +983,12 @@ def layernorm_fwd(
if isinstance(quantizer, DelayedScaleQuantizer) if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32) else jnp.ones((1,), dtype=jnp.float32)
) )
amax = jnp.zeros((1,), dtype=jnp.float32)
if quantizer is None: if quantizer is None:
output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( output, _, _, _, updated_amax, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
x, x,
scale, scale,
amax,
gamma, gamma,
beta, beta,
norm_type=NVTE_Norm_Type.LayerNorm, norm_type=NVTE_Norm_Type.LayerNorm,
...@@ -931,18 +998,37 @@ def layernorm_fwd( ...@@ -931,18 +998,37 @@ def layernorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=False,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True, is_outer=True,
) )
return NoScaleTensor(data=output, amax=None), mu, rsigma # cuDNN does not support amax output for non quantized output
updated_amax = (
updated_amax
if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING)
else None
)
return NoScaleTensor(data=output, amax=updated_amax), mu, rsigma
if ( if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
): ):
out, mu, rsigma = layernorm_fwd( out, mu, rsigma = layernorm_fwd(
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None x,
gamma,
beta,
zero_centered_gamma,
epsilon,
quantizer=None,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False,
)
out, _ = _quantize_dbias_impl(
out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence
) )
out, _ = _quantize_dbias_impl(out, quantizer)
return out, mu, rsigma return out, mu, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
...@@ -954,6 +1040,9 @@ def layernorm_fwd( ...@@ -954,6 +1040,9 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
quantizer=None, quantizer=None,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
) )
out, _ = _quantize_dbias_impl( out, _ = _quantize_dbias_impl(
out, out,
...@@ -961,6 +1050,7 @@ def layernorm_fwd( ...@@ -961,6 +1050,7 @@ def layernorm_fwd(
quantizer=quantizer, quantizer=quantizer,
dq_dtype=x.dtype, dq_dtype=x.dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
) )
return out, mu, rsigma return out, mu, rsigma
...@@ -979,6 +1069,7 @@ def layernorm_fwd( ...@@ -979,6 +1069,7 @@ def layernorm_fwd(
) = NormFwdPrimitive.outer_primitive.bind( ) = NormFwdPrimitive.outer_primitive.bind(
x, x,
scale, scale,
amax,
gamma, gamma,
beta, beta,
norm_type=NVTE_Norm_Type.LayerNorm, norm_type=NVTE_Norm_Type.LayerNorm,
...@@ -988,6 +1079,9 @@ def layernorm_fwd( ...@@ -988,6 +1079,9 @@ def layernorm_fwd(
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True, is_outer=True,
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
...@@ -1091,7 +1185,9 @@ def rmsnorm_fwd( ...@@ -1091,7 +1185,9 @@ def rmsnorm_fwd(
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.TPSP,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]: ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]:
"""Root mean square normalization forward pass with optional quantization. """Root mean square normalization forward pass with optional quantization.
...@@ -1104,6 +1200,7 @@ def rmsnorm_fwd( ...@@ -1104,6 +1200,7 @@ def rmsnorm_fwd(
epsilon: Small constant for numerical stability. epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.
transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -1127,12 +1224,14 @@ def rmsnorm_fwd( ...@@ -1127,12 +1224,14 @@ def rmsnorm_fwd(
if isinstance(quantizer, DelayedScaleQuantizer) if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32) else jnp.ones((1,), dtype=jnp.float32)
) )
amax = jnp.zeros((1,), dtype=jnp.float32)
beta = jnp.ones((1,), dtype=jnp.float32) beta = jnp.ones((1,), dtype=jnp.float32)
if quantizer is None: if quantizer is None:
output, _, _, _, _, _, rsigma = NormFwdPrimitive.outer_primitive.bind( output, _, _, _, updated_amax, _, rsigma = NormFwdPrimitive.outer_primitive.bind(
x, x,
scale, scale,
amax,
gamma, gamma,
beta, beta,
norm_type=NVTE_Norm_Type.RMSNorm, norm_type=NVTE_Norm_Type.RMSNorm,
...@@ -1142,16 +1241,39 @@ def rmsnorm_fwd( ...@@ -1142,16 +1241,39 @@ def rmsnorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True, is_outer=True,
) )
return NoScaleTensor(data=output, amax=None), rsigma # cuDNN does not support amax output for non quantized output
updated_amax = (
updated_amax
if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING)
else None
)
return NoScaleTensor(data=output, amax=updated_amax), rsigma
if ( if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
): ):
out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None) out, rsigma = rmsnorm_fwd(
out, _ = _quantize_dbias_impl(out.data, quantizer) x,
gamma,
zero_centered_gamma,
epsilon,
quantizer=None,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False,
)
out, _ = _quantize_dbias_impl(
out.data,
quantizer,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return out, rsigma return out, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
...@@ -1162,13 +1284,17 @@ def rmsnorm_fwd( ...@@ -1162,13 +1284,17 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
quantizer=None, quantizer=None,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
) )
out, _ = _quantize_dbias_impl( out, _ = _quantize_dbias_impl(
out.data, out,
is_dbias=False, is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=x.dtype, dq_dtype=x.dtype,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
) )
return out, rsigma return out, rsigma
...@@ -1187,6 +1313,7 @@ def rmsnorm_fwd( ...@@ -1187,6 +1313,7 @@ def rmsnorm_fwd(
) = NormFwdPrimitive.outer_primitive.bind( ) = NormFwdPrimitive.outer_primitive.bind(
x, x,
scale, scale,
amax,
gamma, gamma,
beta, beta,
norm_type=NVTE_Norm_Type.RMSNorm, norm_type=NVTE_Norm_Type.RMSNorm,
...@@ -1196,6 +1323,9 @@ def rmsnorm_fwd( ...@@ -1196,6 +1323,9 @@ def rmsnorm_fwd(
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True, is_outer=True,
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
...@@ -1294,6 +1424,7 @@ def normalization_fwd( ...@@ -1294,6 +1424,7 @@ def normalization_fwd(
norm_type: str, norm_type: str,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
): ):
"""Common wrapper for normalization forward pass. """Common wrapper for normalization forward pass.
...@@ -1311,6 +1442,7 @@ def normalization_fwd( ...@@ -1311,6 +1442,7 @@ def normalization_fwd(
- 'rmsnorm': Root mean square normalization - 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.
transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -1336,6 +1468,7 @@ def normalization_fwd( ...@@ -1336,6 +1468,7 @@ def normalization_fwd(
epsilon, epsilon,
quantizer, quantizer,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
) )
elif norm_type == "rmsnorm": elif norm_type == "rmsnorm":
assert ( assert (
...@@ -1348,6 +1481,7 @@ def normalization_fwd( ...@@ -1348,6 +1481,7 @@ def normalization_fwd(
epsilon, epsilon,
quantizer, quantizer,
amax_scope=amax_scope, amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
) )
mu = None mu = None
else: else:
......
...@@ -543,6 +543,18 @@ class AmaxScope(Enum): ...@@ -543,6 +543,18 @@ class AmaxScope(Enum):
TPSP = 2 TPSP = 2
FSDP = 3 FSDP = 3
def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh):
    """Max-reduce ``amax`` over the mesh resource selected by this scope.

    Args:
        amax: Locally computed amax value to reduce.
        data_spec: Partition spec of the tensor the amax was computed from;
            only its sequence-dimension entry is inspected, and only for
            the TPSP scope.
        transpose_batch_sequence: If truthy, the sequence dimension is taken
            to be axis 0 of ``data_spec``; otherwise axis 1.
        mesh: Mesh forwarded to ``lax_paral_op``.

    Returns:
        The result of ``lax_paral_op(amax, jax.lax.pmax, <resource>, mesh)``
        for the FSDP scope, or for the TPSP scope when the sequence-dimension
        entry of ``data_spec`` equals the TPSP resource. In every other case
        ``amax`` is returned unchanged.
    """
    resources = global_mesh_resource()

    # FSDP scope always reduces along the FSDP resource.
    if self is AmaxScope.FSDP:
        return lax_paral_op(amax, jax.lax.pmax, resources.fsdp_resource, mesh)

    # Any scope other than TPSP needs no reduction.
    if self is not AmaxScope.TPSP:
        return amax

    # TPSP scope reduces only when the sequence dimension of the input spec
    # is actually sharded along the TPSP resource.
    seq_axis = 0 if transpose_batch_sequence else 1
    tpsp_axis = resources.tpsp_resource
    if data_spec[seq_axis] == tpsp_axis:
        return lax_paral_op(amax, jax.lax.pmax, tpsp_axis, mesh)
    return amax
class AmaxCalculationPrimitive(BasePrimitive): class AmaxCalculationPrimitive(BasePrimitive):
""" """
...@@ -554,7 +566,7 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -554,7 +566,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
impl_static_args = ( impl_static_args = (
1, 1,
2, 2,
) # amax_scope, batch_sequence_transpose ) # amax_scope, transpose_batch_sequence
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -563,12 +575,12 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -563,12 +575,12 @@ class AmaxCalculationPrimitive(BasePrimitive):
x_aval, x_aval,
*, *,
amax_scope, amax_scope,
batch_sequence_transpose, transpose_batch_sequence,
): ):
""" """
amax calculation abstract amax calculation abstract
""" """
del amax_scope, batch_sequence_transpose del amax_scope, transpose_batch_sequence
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -580,19 +592,19 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -580,19 +592,19 @@ class AmaxCalculationPrimitive(BasePrimitive):
def impl( def impl(
x, x,
amax_scope, amax_scope,
batch_sequence_transpose, transpose_batch_sequence,
): ):
""" """
amax calculation implementation amax calculation implementation
""" """
del amax_scope, batch_sequence_transpose del amax_scope, transpose_batch_sequence
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax return amax
@staticmethod @staticmethod
def infer_sharding_from_operands( def infer_sharding_from_operands(
amax_scope, amax_scope,
batch_sequence_transpose, transpose_batch_sequence,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -600,7 +612,7 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -600,7 +612,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
""" """
amax calculation infer_sharding_from_operands amax calculation infer_sharding_from_operands
""" """
del (amax_scope, batch_sequence_transpose, arg_infos, result_infos) # Unused. del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding( amax_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(None), PartitionSpec(None),
...@@ -611,7 +623,7 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -611,7 +623,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
@staticmethod @staticmethod
def partition( def partition(
amax_scope, amax_scope,
batch_sequence_transpose, transpose_batch_sequence,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -631,16 +643,11 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -631,16 +643,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
amax = AmaxCalculationPrimitive.impl( amax = AmaxCalculationPrimitive.impl(
x, x,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
)
amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
amax, x_spec, transpose_batch_sequence, mesh
) )
gmesh = global_mesh_resource()
sequence_dim = 0 if batch_sequence_transpose else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
# Run AR across FSDP
if amax_scope is AmaxScope.FSDP:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax return amax
...@@ -648,11 +655,11 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -648,11 +655,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
return mesh, sharded_impl, amax_sharding, arg_shardings return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod @staticmethod
def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types): def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types):
""" """
amax calculation shardy_sharding_rule amax calculation shardy_sharding_rule
""" """
del amax_scope, batch_sequence_transpose, mesh, result_types del amax_scope, transpose_batch_sequence, mesh, result_types
prefix = "AmaxCal" prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",) output_spec = (f"{prefix}_amax",)
...@@ -709,7 +716,7 @@ def _quantize_dbias_impl( ...@@ -709,7 +716,7 @@ def _quantize_dbias_impl(
dq_dtype: Optional[jnp.dtype] = None, dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1, flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
""" """
Cast wrapper Cast wrapper
...@@ -755,12 +762,12 @@ def _quantize_dbias_impl( ...@@ -755,12 +762,12 @@ def _quantize_dbias_impl(
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias return out, dbias
scale = jnp.empty((), jnp.float32) scale = jnp.empty((1,), jnp.float32)
amax = None amax = None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale. # Globally reduce amax across all devices for current scaling so we have a single global scale.
...@@ -771,7 +778,7 @@ def _quantize_dbias_impl( ...@@ -771,7 +778,7 @@ def _quantize_dbias_impl(
amax = AmaxCalculationPrimitive.outer_primitive.bind( amax = AmaxCalculationPrimitive.outer_primitive.bind(
x.data, x.data,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
scale = compute_scale_from_amax(amax, quantizer.q_dtype) scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
...@@ -845,7 +852,7 @@ def quantize( ...@@ -845,7 +852,7 @@ def quantize(
quantizer: Quantizer, quantizer: Quantizer,
flatten_axis: int = -1, flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor]: ) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer. """Quantize input tensor according to the quantizer.
...@@ -866,7 +873,7 @@ def quantize( ...@@ -866,7 +873,7 @@ def quantize(
quantizer=quantizer, quantizer=quantizer,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
return out return out
...@@ -877,7 +884,7 @@ def quantize_dbias( ...@@ -877,7 +884,7 @@ def quantize_dbias(
is_dbias: bool = True, is_dbias: bool = True,
flatten_axis: int = -1, flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient. """Quantize input tensor and compute bias gradient.
...@@ -904,7 +911,7 @@ def quantize_dbias( ...@@ -904,7 +911,7 @@ def quantize_dbias(
is_dbias=is_dbias, is_dbias=is_dbias,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
amax_scope=amax_scope, amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
......
...@@ -15,13 +15,14 @@ namespace transformer_engine { ...@@ -15,13 +15,14 @@ namespace transformer_engine {
namespace jax { namespace jax {
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params) { bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS // parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
...@@ -30,7 +31,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -30,7 +31,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data(); auto *colwise_output = colwise_output_buf->untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto input_dims = input_buf.dimensions(); auto input_dims = input_buf.dimensions();
auto m = product(input_dims, 0, input_dims.size() - 2); auto m = product(input_dims, 0, input_dims.size() - 2);
...@@ -45,7 +48,12 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -45,7 +48,12 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m}; auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_CHECK( NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
...@@ -55,10 +63,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -55,10 +63,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv( output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(), scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1}); convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
...@@ -145,26 +150,29 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -145,26 +150,29 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"), .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
JAXX_Scaling_Mode scaling_mode, bool is_2x_int, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params) { ActivationConfig act_params, bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
act_enum, scaling_mode, is_2x_int, act_params); updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params,
output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
...@@ -172,15 +180,17 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, ...@@ -172,15 +180,17 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params")); .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
...@@ -246,15 +256,17 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -246,15 +256,17 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type amax_buf, Result_Type output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, Result_Type dbias_buf, Result_Type workspace_buf,
int64_t act_enum, bool is_2x, bool is_dbias, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
ActivationConfig act_params) { bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS // parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -262,7 +274,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -262,7 +274,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto act_type = static_cast<NVTE_Activation_Type>(act_enum); auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
...@@ -305,13 +319,14 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -305,13 +319,14 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv( output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(), scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1}); convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
...@@ -440,6 +455,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -440,6 +455,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input .Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
...@@ -451,19 +467,22 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -451,19 +467,22 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"), .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type DActLuDBiasQuantizeInitializeFFI( Error_Type DActLuDBiasQuantizeInitializeFFI(
cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
bool is_dbias, ActivationConfig act_params) { int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, output_buf, colwise_output_buf, act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params); workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params,
output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
...@@ -473,18 +492,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, ...@@ -473,18 +492,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Arg<Buffer_Type>() // input .Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input .Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output .Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise .Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum") .Attr<int64_t>("act_enum")
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias") .Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params")); .Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -29,6 +29,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si ...@@ -29,6 +29,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
// WAR: NVTE Norms query is_training from whether columnwise_data is allocated // WAR: NVTE Norms query is_training from whether columnwise_data is allocated
if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
...@@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si ...@@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
} }
Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, Buffer_Type amax_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
int norm_type, bool zero_centered_gamma, double epsilon, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma,
int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode,
bool is_2x, bool output_amax_when_no_scaling) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
...@@ -77,9 +79,12 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -77,9 +79,12 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data(); auto *rsigma = rsigma_buf->untyped_data();
auto *mu = mu_buf->untyped_data(); auto *mu = mu_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto *workspace = wkspace_buf->untyped_data(); auto *workspace = wkspace_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type); auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x); auto _is_2x = static_cast<bool>(is_2x);
...@@ -106,6 +111,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -106,6 +111,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_CHECK( NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
...@@ -123,8 +132,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc ...@@ -123,8 +132,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
} }
if (_is_2x) { if (_is_2x) {
...@@ -162,13 +169,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, ...@@ -162,13 +169,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x .Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma .Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta .Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output .Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv .Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu .Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma .Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
...@@ -177,21 +185,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, ...@@ -177,21 +185,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"), .Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type gamma_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Buffer_Type beta_buf, Result_Type output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type amax_buf, Result_Type mu_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type,
bool zero_centered_gamma, double epsilon, int64_t sm_margin, bool zero_centered_gamma, double epsilon, int64_t sm_margin,
JAXX_Scaling_Mode scaling_mode, bool is_2x) { JAXX_Scaling_Mode scaling_mode, bool is_2x,
return wrapInStreamCapture( bool output_amax_when_no_scaling) {
std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf, return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf, gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x); colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin,
scaling_mode, is_2x, output_amax_when_no_scaling);
} }
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
...@@ -199,13 +211,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ ...@@ -199,13 +211,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Ctx<FFI_Stream_Type>() // stream .Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x .Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale .Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma .Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta .Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output .Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv .Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv .Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu .Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma .Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
...@@ -214,7 +227,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ ...@@ -214,7 +227,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Attr<double>("epsilon") .Attr<double>("epsilon")
.Attr<int64_t>("sm_margin") .Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")); .Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type, DType w_dtype, NVTE_Norm_Type norm_type,
......
...@@ -120,9 +120,11 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -120,9 +120,11 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
if (is_tensor_scaling) { if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax == updated_amax && amax != nullptr,
"amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv( output_tensor.set_rowwise_scale_inv(
......
...@@ -63,7 +63,7 @@ def dense( ...@@ -63,7 +63,7 @@ def dense(
kernel: jnp.ndarray, kernel: jnp.ndarray,
bias: jnp.ndarray = None, bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
input_axes: Tuple[str, ...] = None, input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
output_axes: Tuple[str, ...] = None, output_axes: Tuple[str, ...] = None,
...@@ -81,7 +81,7 @@ def dense( ...@@ -81,7 +81,7 @@ def dense(
kernel: Weight matrix for the dense layer transformation kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
output_axes: Logical axes for sharding the output output_axes: Logical axes for sharding the output
...@@ -91,8 +91,8 @@ def dense( ...@@ -91,8 +91,8 @@ def dense(
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
if batch_sequence_transpose: if transpose_batch_sequence:
warnings.warn("batch_sequence_transpose is not well tested, use with caution!") warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled(): if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype input_dtype = x.dtype
...@@ -103,7 +103,7 @@ def dense( ...@@ -103,7 +103,7 @@ def dense(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -119,7 +119,7 @@ def _dense( ...@@ -119,7 +119,7 @@ def _dense(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -136,7 +136,7 @@ def _dense( ...@@ -136,7 +136,7 @@ def _dense(
kernel: Weight matrix kernel: Weight matrix
bias: Optional bias tensor bias: Optional bias tensor
contracting_dims: Contracting dimensions specification contracting_dims: Contracting dimensions specification
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
output_axes: Logical axes for sharding the output_axes output_axes: Logical axes for sharding the output_axes
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
...@@ -151,7 +151,7 @@ def _dense( ...@@ -151,7 +151,7 @@ def _dense(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -166,7 +166,7 @@ def _dense_fwd_rule( ...@@ -166,7 +166,7 @@ def _dense_fwd_rule(
kernel, kernel,
bias, bias,
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -197,7 +197,7 @@ def _dense_fwd_rule( ...@@ -197,7 +197,7 @@ def _dense_fwd_rule(
flatten_axis=flatten_axis_x, flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x, quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP, amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
...@@ -215,7 +215,7 @@ def _dense_fwd_rule( ...@@ -215,7 +215,7 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS), casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS), casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set.forward, collective_op=collective_op_set.forward,
...@@ -240,7 +240,7 @@ def _dense_fwd_rule( ...@@ -240,7 +240,7 @@ def _dense_fwd_rule(
def _dense_bwd_rule( def _dense_bwd_rule(
contracting_dims, contracting_dims,
batch_sequence_transpose, transpose_batch_sequence,
input_axes, input_axes,
kernel_axes, kernel_axes,
output_axes, output_axes,
...@@ -274,7 +274,7 @@ def _dense_bwd_rule( ...@@ -274,7 +274,7 @@ def _dense_bwd_rule(
flatten_axis=flatten_axis_k, flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad, quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP, amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
# GEMM NT # GEMM NT
...@@ -291,7 +291,7 @@ def _dense_bwd_rule( ...@@ -291,7 +291,7 @@ def _dense_bwd_rule(
casted_grad.get_tensor(usage=TensorUsage.LHS), casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs, casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim), contracting_dims=(g_contracting_dim, k_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set.backward, collective_op=collective_op_set.backward,
) )
...@@ -305,7 +305,7 @@ def _dense_bwd_rule( ...@@ -305,7 +305,7 @@ def _dense_bwd_rule(
casted_x_lhs, casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS), casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim), contracting_dims=(x_contracting_dim, g_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......
...@@ -432,6 +432,8 @@ class DenseGeneral(TransformerEngineBase): ...@@ -432,6 +432,8 @@ class DenseGeneral(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -446,6 +448,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -446,6 +448,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
input_axes: Tuple[str, ...] = () input_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -512,6 +515,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -512,6 +515,7 @@ class DenseGeneral(TransformerEngineBase):
input_axes=self.input_axes, input_axes=self.input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -632,6 +636,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -632,6 +636,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float, default = None depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied. value or None. When None is set, then no scaling is applied.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -657,6 +663,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -657,6 +663,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
layernorm_input_axes: Tuple[str, ...] = None layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None depth_scaling: float = None
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -768,6 +775,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -768,6 +775,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
dot_input_axes=self.dot_input_axes, dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
...@@ -775,6 +783,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -775,6 +783,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
y, y,
kernel, kernel,
contracting_dims=(axis, contract_ind), contracting_dims=(axis, contract_ind),
transpose_batch_sequence=self.transpose_batch_sequence,
input_axes=self.dot_input_axes, input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
...@@ -940,6 +949,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -940,6 +949,8 @@ class LayerNormMLP(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
""" """
intermediate_dim: int = 2048 intermediate_dim: int = 2048
...@@ -974,6 +985,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -974,6 +985,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] = None
ffn1_ckpt_name: str = "ffn1" ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2" ffn2_ckpt_name: str = "ffn2"
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -1160,6 +1172,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1160,6 +1172,7 @@ class LayerNormMLP(TransformerEngineBase):
activation_type=normalized_acts, activation_type=normalized_acts,
activation_params=self.activation_params, activation_params=self.activation_params,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
transpose_batch_sequence=self.transpose_batch_sequence,
) )
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
...@@ -1178,6 +1191,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1178,6 +1191,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_input_axes=self.dot_1_input_axes, dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1, kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
...@@ -1188,6 +1202,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1188,6 +1202,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_1_input_axes, input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1, kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -1260,6 +1275,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1260,6 +1275,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_2_input_axes, input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2, kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set, quantizer_set=ffn2_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
......
...@@ -1207,6 +1207,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1207,6 +1207,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="qkv", name="qkv",
dtype=self.dtype, dtype=self.dtype,
)(inputs_q) )(inputs_q)
...@@ -1234,6 +1235,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1234,6 +1235,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query", name="query",
)(inputs_q) )(inputs_q)
...@@ -1252,6 +1254,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1252,6 +1254,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
enable_low_rank_adaptation=lora_scope.qkv_proj, enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
transpose_batch_sequence=self.transpose_batch_sequence,
name="kv", name="kv",
dtype=self.dtype, dtype=self.dtype,
)(inputs_kv) )(inputs_kv)
...@@ -1292,6 +1295,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1292,6 +1295,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query", name="query",
)(inputs_q) )(inputs_q)
...@@ -2070,6 +2074,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2070,6 +2074,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
transpose_batch_sequence=self.transpose_batch_sequence,
name="mlp", name="mlp",
)(mlp_input, deterministic=deterministic) )(mlp_input, deterministic=deterministic)
......
...@@ -16,6 +16,7 @@ import jax ...@@ -16,6 +16,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .quantize import ( from .quantize import (
QuantizerSet, QuantizerSet,
...@@ -35,6 +36,7 @@ def layernorm_dense( ...@@ -35,6 +36,7 @@ def layernorm_dense(
norm_type: str = "layernorm", norm_type: str = "layernorm",
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
transpose_batch_sequence: bool = False,
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
...@@ -55,6 +57,7 @@ def layernorm_dense( ...@@ -55,6 +57,7 @@ def layernorm_dense(
norm_type: Type of normalization ("layernorm" or "rmsnorm") norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization epsilon: Small constant for numerical stability in normalization
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
...@@ -83,6 +86,7 @@ def layernorm_dense( ...@@ -83,6 +86,7 @@ def layernorm_dense(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -100,6 +104,7 @@ def layernorm_dense( ...@@ -100,6 +104,7 @@ def layernorm_dense(
8, 8,
9, 9,
10, 10,
11,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -111,6 +116,7 @@ def _layernorm_dense( ...@@ -111,6 +116,7 @@ def _layernorm_dense(
norm_type: str, norm_type: str,
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
transpose_batch_sequence: bool,
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...],
...@@ -131,6 +137,7 @@ def _layernorm_dense( ...@@ -131,6 +137,7 @@ def _layernorm_dense(
norm_type: Type of normalization norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for layernorm sharding layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding dot_input_axes: Logical axes for matrix multiplication sharding
quantizer_set: Set of quantizers quantizer_set: Set of quantizers
...@@ -147,6 +154,7 @@ def _layernorm_dense( ...@@ -147,6 +154,7 @@ def _layernorm_dense(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule( ...@@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule( ...@@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule(
epsilon, epsilon,
norm_type, norm_type,
quantizer=quantizer_set.x, quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
...@@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule( ...@@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule(
kernel, kernel,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
quantizer=quantizer_set.kernel, quantizer=quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
...@@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule( ...@@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS), casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
...@@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule( ...@@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
transpose_batch_sequence,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
...@@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule( ...@@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule(
is_dbias=use_bias, is_dbias=use_bias,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad, quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...@@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule( ...@@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel, casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
...@@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule( ...@@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule(
casted_ln_out, casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim), contracting_dims=(x_constracting_dim, g_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
...@@ -41,7 +41,7 @@ def layernorm_mlp( ...@@ -41,7 +41,7 @@ def layernorm_mlp(
norm_type: str, norm_type: str,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
batch_sequence_transpose: bool = False, transpose_batch_sequence: bool = False,
norm_input_axes: Tuple[str, ...] = None, norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None,
...@@ -78,7 +78,7 @@ def layernorm_mlp( ...@@ -78,7 +78,7 @@ def layernorm_mlp(
norm_type: Type of normalization ("layernorm" or "rmsnorm") norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization epsilon: Small constant for numerical stability in normalization
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for sharding the layernorm input norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication
...@@ -130,7 +130,7 @@ def layernorm_mlp( ...@@ -130,7 +130,7 @@ def layernorm_mlp(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
batch_sequence_transpose, transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -158,7 +158,7 @@ def _layernorm_mlp( ...@@ -158,7 +158,7 @@ def _layernorm_mlp(
norm_type: str, norm_type: str,
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
batch_sequence_transpose: bool, transpose_batch_sequence: bool,
norm_input_axes: Tuple[str, ...], norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
...@@ -188,7 +188,7 @@ def _layernorm_mlp( ...@@ -188,7 +188,7 @@ def _layernorm_mlp(
norm_type: Type of normalization norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for layernorm sharding norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding dot_2_input_axes: Logical axes for second matrix multiplication sharding
...@@ -214,7 +214,7 @@ def _layernorm_mlp( ...@@ -214,7 +214,7 @@ def _layernorm_mlp(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
batch_sequence_transpose, transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -241,7 +241,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -241,7 +241,7 @@ def _layernorm_mlp_fwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
batch_sequence_transpose, transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -302,11 +302,16 @@ def _layernorm_mlp_fwd_rule( ...@@ -302,11 +302,16 @@ def _layernorm_mlp_fwd_rule(
norm_type, norm_type,
quantizer=ffn1_quantizer_set.x, quantizer=ffn1_quantizer_set.x,
amax_scope=AmaxScope.TPSP, amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize( casted_kernel_1 = tex.quantize(
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP kernel_1,
flatten_axis=-2,
quantizer=ffn1_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# NN GEMM # NN GEMM
...@@ -315,7 +320,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -315,7 +320,7 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS), casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
bias=bias_1 if not tex.gemm_uses_jax_dot() else None, bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_1.forward, collective_op=collective_op_set_1.forward,
...@@ -345,6 +350,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -345,6 +350,8 @@ def _layernorm_mlp_fwd_rule(
if activation_params if activation_params
else None else None
), ),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
...@@ -353,6 +360,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -353,6 +360,7 @@ def _layernorm_mlp_fwd_rule(
kernel_2, kernel_2,
quantizer=ffn2_quantizer_set.kernel, quantizer=ffn2_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP, amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# NN GEMM # NN GEMM
...@@ -361,7 +369,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -361,7 +369,7 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS), casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS), casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
bias=bias_2 if not tex.gemm_uses_jax_dot() else None, bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_2.forward, collective_op=collective_op_set_2.forward,
...@@ -403,7 +411,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -403,7 +411,7 @@ def _layernorm_mlp_bwd_rule(
norm_type, norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
batch_sequence_transpose, transpose_batch_sequence,
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
...@@ -465,6 +473,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -465,6 +473,7 @@ def _layernorm_mlp_bwd_rule(
is_dbias=use_bias_2, is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad, quantizer=ffn1_quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP, amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...@@ -482,7 +491,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -482,7 +491,7 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2, casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_2.backward, collective_op=collective_op_set_2.backward,
) )
...@@ -498,7 +507,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -498,7 +507,7 @@ def _layernorm_mlp_bwd_rule(
casted_act_out, casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
...@@ -513,6 +522,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -513,6 +522,8 @@ def _layernorm_mlp_bwd_rule(
if activation_params if activation_params
else None else None
), ),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
...@@ -530,7 +541,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -530,7 +541,7 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS), casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1, casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_1.backward, collective_op=collective_op_set_1.backward,
) )
...@@ -542,7 +553,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -542,7 +553,7 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out, casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS), casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose, transpose_batch_sequence=transpose_batch_sequence,
) )
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment