Unverified Commit 127b6d3a authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Activation/Normalization to output amax for later quantization in CurrentScaling (#2238)



* reuse amax for current scaling
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 9f3e79bf
......@@ -148,7 +148,6 @@ class ActLuPrimitive(BasePrimitive):
name = "te_act_lu_ffi"
multiple_results = True
impl_static_args = (
2,
3,
4,
5,
......@@ -156,7 +155,11 @@ class ActLuPrimitive(BasePrimitive):
7,
8,
9,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params
10,
11,
12,
13,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
inner_primitive = None
outer_primitive = None
......@@ -164,6 +167,7 @@ class ActLuPrimitive(BasePrimitive):
def abstract(
x_aval,
scale_aval,
amax_aval,
*,
out_dtype,
act_enum,
......@@ -171,16 +175,23 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
te_act_lu_p abstract
"""
del act_enum, act_params
del act_enum, act_params, amax_scope, transpose_batch_sequence
assert (
not output_amax_when_no_scaling or scaling_mode == ScalingMode.NO_SCALING.value
), f"scaling_mode = {scaling_mode}"
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
assert amax_aval is None or amax_aval.dtype == jnp.float32
assert x_aval.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x_aval.shape} and act_len {act_len}"
......@@ -215,6 +226,7 @@ class ActLuPrimitive(BasePrimitive):
ctx,
x,
scale,
amax,
*,
out_dtype,
act_enum,
......@@ -222,24 +234,34 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
te_gated_act_lu_p lowering rules
"""
del out_dtype, scale_dtype, act_len, is_outer
x_aval, scale_aval = ctx.avals_in
del out_dtype, scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence
x_aval, scale_aval, amax_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
out = ffi.ffi_lowering(ActLuPrimitive.name)(
assert amax_aval.dtype == jnp.float32
out = ffi.ffi_lowering(
ActLuPrimitive.name,
operand_output_aliases={2: 4}, # donate amax buffer to updated_amax
)(
ctx,
x,
scale,
amax,
act_enum=act_enum,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
act_params=act_params.to_ffi_lowering_dict(),
output_amax_when_no_scaling=output_amax_when_no_scaling,
)
return out
......@@ -247,14 +269,18 @@ class ActLuPrimitive(BasePrimitive):
def impl(
x,
scale,
amax,
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
to describe implementation
......@@ -266,14 +292,18 @@ class ActLuPrimitive(BasePrimitive):
ActLuPrimitive.inner_primitive.bind(
x,
scale,
amax,
out_dtype=out_dtype,
act_enum=act_enum,
act_len=act_len,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
is_outer=False,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=False,
)
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
......@@ -301,17 +331,19 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
to describe batch rules for vmap
"""
del act_len, is_outer
check_valid_batch_dims(batch_dims)
assert ActLuPrimitive.outer_primitive is not None
x, scale = batched_args
x_bdim, scale_bdim = batch_dims
x, scale, amax = batched_args
x_bdim, scale_bdim, _ = batch_dims
amax_bdim = scale_bdim
out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
......@@ -319,12 +351,18 @@ class ActLuPrimitive(BasePrimitive):
ActLuPrimitive.outer_primitive.bind(
x,
scale,
amax,
out_dtype=out_dtype,
act_enum=act_enum,
act_len=act_len,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=is_outer,
),
out_bdims,
)
......@@ -337,8 +375,11 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
arg_infos,
result_infos,
......@@ -349,8 +390,11 @@ class ActLuPrimitive(BasePrimitive):
act_enum,
scale_dtype,
act_len,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
) # Unused.
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
......@@ -402,13 +446,16 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
arg_infos,
result_infos,
):
del result_infos, is_outer # Unused.
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
......@@ -452,26 +499,40 @@ class ActLuPrimitive(BasePrimitive):
amax_sharding,
)
def sharded_impl(x, scale):
local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
ActLuPrimitive.impl(
def sharded_impl(x, scale, amax):
(
local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
local_updated_amax,
) = ActLuPrimitive.impl(
x,
scale,
amax,
out_dtype=out_dtype,
act_enum=act_enum,
act_len=act_len,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
is_outer=True,
act_params=act_params,
)
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(
local_updated_amax, mesh
)
elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
local_updated_amax, out_spec, transpose_batch_sequence, mesh
)
else:
global_updated_amax = local_amax
global_updated_amax = local_updated_amax
return (
local_x,
......@@ -491,13 +552,28 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
value_types,
result_types,
):
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params
del (
out_dtype,
act_enum,
act_len,
scale_dtype,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
result_types,
)
prefix = "ActLu_"
input_shape = value_types[0].shape
output_shape = input_shape[:-2] + input_shape[-1:]
......@@ -526,6 +602,7 @@ class ActLuPrimitive(BasePrimitive):
(
x_axes,
("…1",),
amax,
),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax),
**scale_rules.factor_sizes,
......@@ -543,8 +620,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
name = "te_dact_dbias_quantize_ffi"
multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
# out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
inner_primitive = None
outer_primitive = None
......@@ -553,6 +630,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz_aval,
x_aval,
scale_aval,
amax_aval,
*,
out_dtype,
scaling_mode,
......@@ -561,13 +639,16 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias,
act_enum,
act_len,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
te_dact_dbias_quantize_p abstract
"""
del act_enum, act_params
del act_enum, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling
dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_dtype
......@@ -576,6 +657,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
f" {x_aval.shape} and act_len {act_len}"
)
assert scale_aval.dtype == jnp.float32
assert amax_aval.dtype == jnp.float32
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not supported for fused dact and quantization. Please do"
......@@ -655,6 +737,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz,
x,
scale,
amax,
*,
out_dtype,
scaling_mode,
......@@ -663,27 +746,42 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias,
act_enum,
act_len,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
te_dact_dbias_quantize_p lowering rules
"""
del out_dtype, scale_dtype, act_len, is_outer
dz_aval, x_aval, scale_aval = ctx.avals_in
del (
out_dtype,
scale_dtype,
act_len,
is_outer,
amax_scope,
transpose_batch_sequence,
)
dz_aval, x_aval, scale_aval, amax_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)(
assert scale_aval.dtype == amax_aval.dtype == jnp.float32
return ffi.ffi_lowering(
BaseDActLuDBiasQuantizePrimitive.name,
operand_output_aliases={3: 4}, # donate amax buffer to updated_amax
)(
ctx,
dz,
x,
scale,
amax,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
is_dbias=is_dbias,
act_enum=int(act_enum),
act_params=act_params.to_ffi_lowering_dict(),
output_amax_when_no_scaling=output_amax_when_no_scaling,
)
@staticmethod
......@@ -691,6 +789,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz,
x,
scale,
amax,
out_dtype,
scaling_mode,
is_2x,
......@@ -698,8 +797,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias,
act_enum,
act_len,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
te_dact_dbias_quantize_p impl
......@@ -711,6 +813,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz,
x,
scale,
amax,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
......@@ -718,8 +821,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
is_outer=False,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=False,
)
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
......@@ -747,17 +853,19 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias,
act_enum,
act_len,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
to describe batch rules for vmap
"""
del is_outer
check_valid_batch_dims(batch_dims)
assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
dz, x, scale = batched_args
_, x_bdim, scale_bdim = batch_dims
dz, x, scale, amax = batched_args
_, x_bdim, scale_bdim, _ = batch_dims
out_bdims = (
x_bdim, # rowwise output
......@@ -772,6 +880,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dz,
x,
scale,
amax,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
......@@ -780,6 +889,10 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
act_enum=act_enum,
act_len=act_len,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=is_outer,
),
out_bdims,
)
......@@ -793,14 +906,18 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias,
act_enum,
act_len,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
arg_infos,
result_infos,
):
del out_dtype, result_infos, act_enum, act_params
del scale_dtype, act_len, is_outer
del out_dtype, result_infos, act_enum, act_params, output_amax_when_no_scaling
del scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence
x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2])
......@@ -869,8 +986,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias,
act_enum,
act_len,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
arg_infos,
result_infos,
......@@ -937,12 +1057,13 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
dbias_sharding,
)
def sharded_impl(dz, x, scale):
(out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
def sharded_impl(dz, x, scale, amax):
(out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias) = (
BaseDActLuDBiasQuantizePrimitive.impl(
dz,
x,
scale,
amax,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
......@@ -950,8 +1071,11 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
is_outer=True,
act_params=act_params,
output_amax_when_no_scaling=output_amax_when_no_scaling,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
is_outer=True,
)
)
if is_dbias:
......@@ -960,9 +1084,15 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
global_dbias = local_dbias
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(
local_updated_amax, mesh
)
elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
local_updated_amax, x_spec, transpose_batch_sequence, mesh
)
else:
global_updated_amax = local_amax
global_updated_amax = local_updated_amax
return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias
......@@ -977,14 +1107,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
is_dbias,
act_enum,
act_len,
is_outer,
act_params,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
value_types,
result_types,
):
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params
del (
out_dtype,
scale_dtype,
act_enum,
act_len,
act_params,
is_outer,
output_amax_when_no_scaling,
mesh,
result_types,
amax_scope,
transpose_batch_sequence,
)
prefix = "DActLuDBias_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
......@@ -1006,7 +1152,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
amax = (prefix + "amax",)
return SdyShardingRule(
(dz_axes, x_axes, ("…2",)),
(dz_axes, x_axes, ("…2",), amax),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
)
......@@ -1092,6 +1238,8 @@ def act_lu(
quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization.
......@@ -1108,6 +1256,8 @@ def act_lu(
If quantizer is provided:
A ScaledTensor containing the quantized activated input.
"""
# TODO(Phuong): remove the output_amax_when_no_scaling exposure by introducing _act_lu_impl()
# Do the same with dact_dbias_quantize() and layernorm_fwd()
act_type_id = ActivationEnum[activation_type].value
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
......@@ -1123,30 +1273,44 @@ def act_lu(
return _jax_act_lu(x, activation_type, quantizer, act_params)
# TE/common does not support 2x quantization for DelayedScaling yet
war_output = try_apply_delayed_scaling_2x_war(
f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params
f=act_lu,
x=x,
activation_type=activation_type,
quantizer=quantizer,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
)
if war_output is not None:
return war_output
scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-2], x.shape[-1])
amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,)
if quantizer is None:
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
out, _, _, _, updated_amax = ActLuPrimitive.outer_primitive.bind(
x,
scale,
amax,
out_dtype=x.dtype,
act_enum=act_type_id,
act_len=act_len,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
is_outer=True,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
out = out.reshape(output_shape)
# TODO(Phuong): ScaledTensorFactory to create NoScaledTensor
out = NoScaleTensor(
data=out,
amax=None,
amax=updated_amax if output_amax_when_no_scaling else None,
)
return out
......@@ -1157,6 +1321,9 @@ def act_lu(
activation_type=activation_type,
quantizer=None,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
)
out, _ = _quantize_dbias_impl(
out,
......@@ -1164,6 +1331,7 @@ def act_lu(
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return out
if isinstance(quantizer, DelayedScaleQuantizer):
......@@ -1178,14 +1346,18 @@ def act_lu(
) = ActLuPrimitive.outer_primitive.bind(
x,
scale,
amax,
out_dtype=quantizer.q_dtype,
act_enum=act_type_id,
act_len=act_len,
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
is_outer=True,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
quantizer.update(updated_amax)
......@@ -1209,6 +1381,9 @@ def quantize_dact_dbias(
is_dbias: bool = True,
quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization.
......@@ -1232,7 +1407,8 @@ def quantize_dact_dbias(
f" {x.shape} and act_len {act_len}"
)
scale = jnp.empty((), jnp.float32)
scale = jnp.empty((1,), jnp.float32)
amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,)
act_type_id = ActivationEnum[activation_type]
PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
if not PrimitiveClass.enabled() or (
......@@ -1240,10 +1416,11 @@ def quantize_dact_dbias(
):
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params)
if quantizer is None:
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind(
dz,
x,
scale,
amax,
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset
......@@ -1253,8 +1430,11 @@ def quantize_dact_dbias(
is_dbias=False,
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
output = output.astype(x.dtype)
dbias = None
......@@ -1263,7 +1443,7 @@ def quantize_dact_dbias(
output = NoScaleTensor(
data=output,
amax=None,
amax=updated_amax if output_amax_when_no_scaling else None,
)
return output, dbias
......@@ -1275,9 +1455,18 @@ def quantize_dact_dbias(
activation_type,
quantizer=None,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
)
return _quantize_dbias_impl(
out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
out.data,
quantizer,
is_dbias=True,
dq_dtype=x.dtype,
flatten_axis=-2,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
is_gated = act_len == 2
......@@ -1292,6 +1481,9 @@ def quantize_dact_dbias(
quantizer=quantizer,
flatten_axis=-2,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
)
if war_output is not None:
return war_output
......@@ -1304,9 +1496,18 @@ def quantize_dact_dbias(
activation_type=activation_type,
quantizer=None,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
)
out, dbias = _quantize_dbias_impl(
out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
out,
is_dbias=is_dbias,
quantizer=quantizer,
dq_dtype=x.dtype,
flatten_axis=-2,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return out, dbias
......@@ -1320,9 +1521,17 @@ def quantize_dact_dbias(
x.astype(jnp.float32),
activation_type=activation_type,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
out, dbias = _quantize_dbias_impl(
dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
dgated,
quantizer,
is_dbias=True,
dq_dtype=x.dtype,
flatten_axis=-2,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return out, dbias
......@@ -1337,6 +1546,7 @@ def quantize_dact_dbias(
dz,
x,
scale,
amax,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
......@@ -1344,8 +1554,11 @@ def quantize_dact_dbias(
is_dbias=is_dbias,
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
......@@ -1375,6 +1588,9 @@ def dact_lu(
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
act_params: Optional[ActivationParams] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
"""
Backward pass for activation with optional quantization.
......@@ -1396,5 +1612,8 @@ def dact_lu(
is_dbias=False,
quantizer=quantizer,
act_params=act_params,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
)
return output
......@@ -92,7 +92,7 @@ class NormFwdPrimitive(BasePrimitive):
name = "te_norm_forward_ffi"
multiple_results = True
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11)
impl_static_args = (5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
inner_primitive = None
outer_primitive = None
......@@ -100,6 +100,7 @@ class NormFwdPrimitive(BasePrimitive):
def abstract(
x_aval,
scale_aval,
amax_aval,
gamma_aval,
beta_aval,
*,
......@@ -110,15 +111,27 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
LayerNorm fwd inner primitive abstract
"""
del amax_scope, transpose_batch_sequence
assert not output_amax_when_no_scaling or (
scaling_mode == ScalingMode.NO_SCALING.value
and not is_norm_fwd_cudnn_enabled(scaling_mode)
), (
f"scaling_mode = {scaling_mode},"
f" use_cudnn_norm_fwd={is_norm_fwd_cudnn_enabled(scaling_mode)}"
)
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
assert amax_aval is None or amax_aval.dtype == jnp.float32
assert (
scaling_mode != ScalingMode.MXFP8_1D_SCALING.value
......@@ -220,6 +233,7 @@ class NormFwdPrimitive(BasePrimitive):
ctx,
x,
scale,
amax,
gamma,
beta,
*,
......@@ -230,16 +244,20 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
LayerNorm fwd lowering rules
"""
del out_dtype, scale_dtype, is_outer
x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in
del out_dtype, scale_dtype, is_outer, amax_scope, transpose_batch_sequence
x_aval, scale_aval, amax_aval, gamma_aval, beta_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
assert amax_aval is None or amax_aval.dtype == jnp.float32
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
......@@ -251,10 +269,14 @@ class NormFwdPrimitive(BasePrimitive):
assert g_shape == b_shape
sm_margin = get_forward_sm_margin()
return ffi.ffi_lowering(NormFwdPrimitive.name)(
return ffi.ffi_lowering(
NormFwdPrimitive.name,
operand_output_aliases={2: 4}, # amax <-> updated_amax
)(
ctx,
x,
scale,
amax,
gamma,
beta,
norm_type=norm_type.value,
......@@ -263,12 +285,14 @@ class NormFwdPrimitive(BasePrimitive):
sm_margin=sm_margin,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
output_amax_when_no_scaling=output_amax_when_no_scaling,
)
@staticmethod
def impl(
x,
scale,
amax,
gamma,
beta,
norm_type,
......@@ -278,6 +302,9 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
......@@ -297,6 +324,7 @@ class NormFwdPrimitive(BasePrimitive):
) = NormFwdPrimitive.inner_primitive.bind(
x,
scale,
amax,
gamma,
beta,
norm_type=norm_type,
......@@ -306,6 +334,9 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=False,
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
......@@ -341,16 +372,18 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
):
"""
to describe batch rules for vmap
"""
del is_outer
check_valid_batch_dims(batch_dims)
assert NormFwdPrimitive.outer_primitive is not None
x, scale, gamma, beta = batched_args
x_bdim, scale_bdim, _, _ = batch_dims
x, scale, amax, gamma, beta = batched_args
x_bdim, scale_bdim, _, _, _ = batch_dims
out_bdims = (
x_bdim, # rowwise output
......@@ -363,8 +396,9 @@ class NormFwdPrimitive(BasePrimitive):
)
return (
NormFwdPrimitive.outer_primitive.bind(
scale,
x,
scale,
amax,
gamma,
beta,
norm_type=norm_type,
......@@ -374,6 +408,10 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=is_outer,
),
out_bdims,
)
......@@ -387,15 +425,19 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
arg_infos,
result_infos,
):
del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, is_outer
del scale_dtype, is_outer, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
amax_spec = get_padded_spec(arg_infos[2])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None:
warnings.warn(
......@@ -415,9 +457,9 @@ class NormFwdPrimitive(BasePrimitive):
mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
scale_inv_spec = amax_spec = (None,)
scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
scale_inv_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
......@@ -445,6 +487,9 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
arg_infos,
......@@ -453,8 +498,9 @@ class NormFwdPrimitive(BasePrimitive):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
g_spec = get_padded_spec(arg_infos[2])
b_spec = get_padded_spec(arg_infos[3])
amax_spec = get_padded_spec(arg_infos[2])
g_spec = get_padded_spec(arg_infos[3])
b_spec = get_padded_spec(arg_infos[4])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None:
......@@ -485,9 +531,9 @@ class NormFwdPrimitive(BasePrimitive):
mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
scale_inv_spec = amax_spec = (None,)
scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
scale_inv_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
......@@ -499,10 +545,10 @@ class NormFwdPrimitive(BasePrimitive):
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
# Enforce no sharding of hidden dim for x, gamma and beta
arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.x")
arg_shardings[2] = NamedSharding(
arg_shardings[3] = NamedSharding(
mesh, PartitionSpec(*g_spec[:-1], None), desc="NormFwdPrimitive.gamma"
)
arg_shardings[3] = NamedSharding(
arg_shardings[4] = NamedSharding(
mesh, PartitionSpec(*b_spec[:-1], None), desc="NormFwdPrimitive.beta"
)
arg_shardings = tuple(arg_shardings)
......@@ -516,19 +562,20 @@ class NormFwdPrimitive(BasePrimitive):
rsigma_sharding,
)
def sharded_impl(x, scale, gamma, beta):
def sharded_impl(x, scale, amax, gamma, beta):
# expect tp and dp giving same shape, or tp being same shape as global
(
local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
local_amax,
local_updated_amax,
local_mu,
local_rsigma,
) = NormFwdPrimitive.impl(
x,
scale,
amax,
gamma,
beta,
norm_type=norm_type,
......@@ -538,12 +585,21 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(
local_updated_amax, mesh
)
elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
local_updated_amax, x_spec, transpose_batch_sequence, mesh
)
else:
global_updated_amax = local_amax
global_updated_amax = local_updated_amax
return (
local_x,
......@@ -566,6 +622,9 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
value_types,
......@@ -576,6 +635,9 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scale_dtype,
amax_scope,
transpose_batch_sequence,
output_amax_when_no_scaling,
is_outer,
mesh,
result_types,
......@@ -594,7 +656,7 @@ class NormFwdPrimitive(BasePrimitive):
amax = (prefix + "amax",)
return SdyShardingRule(
(x_axes, ("…1",), ("…2",), ("…3",)),
(x_axes, ("…1",), amax, ("…2",), ("…3",)),
(
out,
colwise_out,
......@@ -882,6 +944,8 @@ def layernorm_fwd(
epsilon: float,
quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]:
"""Layer normalization forward pass with optional quantization.
......@@ -896,6 +960,7 @@ def layernorm_fwd(
epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.
transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False.
Returns:
A tuple containing:
......@@ -918,10 +983,12 @@ def layernorm_fwd(
if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32)
)
amax = jnp.zeros((1,), dtype=jnp.float32)
if quantizer is None:
output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
output, _, _, _, updated_amax, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
x,
scale,
amax,
gamma,
beta,
norm_type=NVTE_Norm_Type.LayerNorm,
......@@ -931,18 +998,37 @@ def layernorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=False,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
return NoScaleTensor(data=output, amax=None), mu, rsigma
# cuDNN does not support amax output for non quantized output
updated_amax = (
updated_amax
if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING)
else None
)
return NoScaleTensor(data=output, amax=updated_amax), mu, rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
):
out, mu, rsigma = layernorm_fwd(
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None
x,
gamma,
beta,
zero_centered_gamma,
epsilon,
quantizer=None,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False,
)
out, _ = _quantize_dbias_impl(
out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence
)
out, _ = _quantize_dbias_impl(out, quantizer)
return out, mu, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
......@@ -954,6 +1040,9 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
quantizer=None,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
)
out, _ = _quantize_dbias_impl(
out,
......@@ -961,6 +1050,7 @@ def layernorm_fwd(
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return out, mu, rsigma
......@@ -979,6 +1069,7 @@ def layernorm_fwd(
) = NormFwdPrimitive.outer_primitive.bind(
x,
scale,
amax,
gamma,
beta,
norm_type=NVTE_Norm_Type.LayerNorm,
......@@ -988,6 +1079,9 @@ def layernorm_fwd(
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
quantizer.update(updated_amax)
......@@ -1091,7 +1185,9 @@ def rmsnorm_fwd(
zero_centered_gamma: bool,
epsilon: float,
quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL,
amax_scope: AmaxScope = AmaxScope.TPSP,
transpose_batch_sequence: bool = False,
output_amax_when_no_scaling: bool = False,
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]:
"""Root mean square normalization forward pass with optional quantization.
......@@ -1104,6 +1200,7 @@ def rmsnorm_fwd(
epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.
transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False.
Returns:
A tuple containing:
......@@ -1127,12 +1224,14 @@ def rmsnorm_fwd(
if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32)
)
amax = jnp.zeros((1,), dtype=jnp.float32)
beta = jnp.ones((1,), dtype=jnp.float32)
if quantizer is None:
output, _, _, _, _, _, rsigma = NormFwdPrimitive.outer_primitive.bind(
output, _, _, _, updated_amax, _, rsigma = NormFwdPrimitive.outer_primitive.bind(
x,
scale,
amax,
gamma,
beta,
norm_type=NVTE_Norm_Type.RMSNorm,
......@@ -1142,16 +1241,39 @@ def rmsnorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
return NoScaleTensor(data=output, amax=None), rsigma
# cuDNN does not support amax output for non quantized output
updated_amax = (
updated_amax
if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING)
else None
)
return NoScaleTensor(data=output, amax=updated_amax), rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
):
out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None)
out, _ = _quantize_dbias_impl(out.data, quantizer)
out, rsigma = rmsnorm_fwd(
x,
gamma,
zero_centered_gamma,
epsilon,
quantizer=None,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=False,
)
out, _ = _quantize_dbias_impl(
out.data,
quantizer,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return out, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
......@@ -1162,13 +1284,17 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
quantizer=None,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=True,
)
out, _ = _quantize_dbias_impl(
out.data,
out,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
return out, rsigma
......@@ -1187,6 +1313,7 @@ def rmsnorm_fwd(
) = NormFwdPrimitive.outer_primitive.bind(
x,
scale,
amax,
gamma,
beta,
norm_type=NVTE_Norm_Type.RMSNorm,
......@@ -1196,6 +1323,9 @@ def rmsnorm_fwd(
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
output_amax_when_no_scaling=output_amax_when_no_scaling,
is_outer=True,
)
quantizer.update(updated_amax)
......@@ -1294,6 +1424,7 @@ def normalization_fwd(
norm_type: str,
quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL,
transpose_batch_sequence: bool = False,
):
"""Common wrapper for normalization forward pass.
......@@ -1311,6 +1442,7 @@ def normalization_fwd(
- 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.
transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False.
Returns:
A tuple containing:
......@@ -1336,6 +1468,7 @@ def normalization_fwd(
epsilon,
quantizer,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
elif norm_type == "rmsnorm":
assert (
......@@ -1348,6 +1481,7 @@ def normalization_fwd(
epsilon,
quantizer,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
mu = None
else:
......
......@@ -543,6 +543,18 @@ class AmaxScope(Enum):
TPSP = 2
FSDP = 3
def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh):
"""Reduce the amax based on its scope"""
gmesh = global_mesh_resource()
sequence_dim = 0 if transpose_batch_sequence else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource:
return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
# Run AR across FSDP
if self is AmaxScope.FSDP:
return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
class AmaxCalculationPrimitive(BasePrimitive):
"""
......@@ -554,7 +566,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
impl_static_args = (
1,
2,
) # amax_scope, batch_sequence_transpose
) # amax_scope, transpose_batch_sequence
inner_primitive = None
outer_primitive = None
......@@ -563,12 +575,12 @@ class AmaxCalculationPrimitive(BasePrimitive):
x_aval,
*,
amax_scope,
batch_sequence_transpose,
transpose_batch_sequence,
):
"""
amax calcuation abstract
"""
del amax_scope, batch_sequence_transpose
del amax_scope, transpose_batch_sequence
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -580,19 +592,19 @@ class AmaxCalculationPrimitive(BasePrimitive):
def impl(
x,
amax_scope,
batch_sequence_transpose,
transpose_batch_sequence,
):
"""
amax calcuation implementation
"""
del amax_scope, batch_sequence_transpose
del amax_scope, transpose_batch_sequence
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
batch_sequence_transpose,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
......@@ -600,7 +612,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
"""
amax calcuation infer_sharding_from_operands
"""
del (amax_scope, batch_sequence_transpose, arg_infos, result_infos) # Unused.
del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
......@@ -611,7 +623,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
@staticmethod
def partition(
amax_scope,
batch_sequence_transpose,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
......@@ -631,16 +643,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
amax = AmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
amax, x_spec, transpose_batch_sequence, mesh
)
gmesh = global_mesh_resource()
sequence_dim = 0 if batch_sequence_transpose else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
# Run AR across FSDP
if amax_scope is AmaxScope.FSDP:
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
......@@ -648,11 +655,11 @@ class AmaxCalculationPrimitive(BasePrimitive):
return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types):
def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types):
"""
amax calcuation shardy_sharding_rule
"""
del amax_scope, batch_sequence_transpose, mesh, result_types
del amax_scope, transpose_batch_sequence, mesh, result_types
prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",)
......@@ -709,7 +716,7 @@ def _quantize_dbias_impl(
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""
Cast wrapper
......@@ -755,12 +762,12 @@ def _quantize_dbias_impl(
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias
scale = jnp.empty((), jnp.float32)
scale = jnp.empty((1,), jnp.float32)
amax = None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale.
......@@ -771,7 +778,7 @@ def _quantize_dbias_impl(
amax = AmaxCalculationPrimitive.outer_primitive.bind(
x.data,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
......@@ -845,7 +852,7 @@ def quantize(
quantizer: Quantizer,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
......@@ -866,7 +873,7 @@ def quantize(
quantizer=quantizer,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
return out
......@@ -877,7 +884,7 @@ def quantize_dbias(
is_dbias: bool = True,
flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
......@@ -904,7 +911,7 @@ def quantize_dbias(
is_dbias=is_dbias,
flatten_axis=flatten_axis,
amax_scope=amax_scope,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
......
......@@ -15,13 +15,14 @@ namespace transformer_engine {
namespace jax {
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params) {
Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
......@@ -30,7 +31,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto *output = output_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto input_dims = input_buf.dimensions();
auto m = product(input_dims, 0, input_dims.size() - 2);
......@@ -45,7 +48,12 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto output_trans_shape = std::vector<size_t>{static_cast<size_t>(n), m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
......@@ -55,10 +63,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
......@@ -145,26 +150,29 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"),
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum,
JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf,
act_enum, scaling_mode, is_2x_int, act_params);
Buffer_Type amax_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int,
ActivationConfig act_params, bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf,
output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf,
updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params,
output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
......@@ -172,15 +180,17 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x")
.Attr<ActivationConfig>("act_params"));
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
......@@ -246,15 +256,17 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias,
ActivationConfig act_params) {
Buffer_Type amax_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
......@@ -262,7 +274,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
......@@ -305,13 +319,14 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
......@@ -440,6 +455,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
......@@ -451,19 +467,22 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"),
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
Error_Type DActLuDBiasQuantizeInitializeFFI(
cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x,
bool is_dbias, ActivationConfig act_params) {
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params);
act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params,
output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
......@@ -473,18 +492,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias")
.Attr<ActivationConfig>("act_params"));
.Attr<ActivationConfig>("act_params")
.Attr<bool>("output_amax_when_no_scaling"));
} // namespace jax
} // namespace transformer_engine
......@@ -29,6 +29,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape);
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
// WAR: NVTE Norms query the is_training from whereas columwise_data is allocated
if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
......@@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
}
Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf,
int norm_type, bool zero_centered_gamma, double epsilon,
int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) {
Buffer_Type amax_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma,
double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode,
bool is_2x, bool output_amax_when_no_scaling) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
......@@ -77,9 +79,12 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto *output = output_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *mu = mu_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto *workspace = wkspace_buf->untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased");
auto _norm_type = static_cast<NVTE_Norm_Type>(norm_type);
auto _is_2x = static_cast<bool>(is_2x);
......@@ -106,6 +111,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
(scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) {
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
......@@ -123,8 +132,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
if (_is_2x) {
......@@ -162,13 +169,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
......@@ -177,21 +185,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"),
.Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"),
FFI_CudaGraph_Traits);
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
Buffer_Type gamma_buf, Buffer_Type beta_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type mu_buf,
Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type,
Buffer_Type amax_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Result_Type output_buf,
Result_Type colwise_output_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, int norm_type,
bool zero_centered_gamma, double epsilon, int64_t sm_margin,
JAXX_Scaling_Mode scaling_mode, bool is_2x) {
return wrapInStreamCapture(
std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x);
JAXX_Scaling_Mode scaling_mode, bool is_2x,
bool output_amax_when_no_scaling) {
return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf,
gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf,
colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf,
wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin,
scaling_mode, is_2x, output_amax_when_no_scaling);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
......@@ -199,13 +211,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // updated_amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
......@@ -214,7 +227,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"));
.Attr<bool>("is_2x")
.Attr<bool>("output_amax_when_no_scaling"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type,
......
......@@ -120,9 +120,11 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
if (is_fp8_dtype(out_dtype)) {
if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
NVTE_CHECK(amax == updated_amax && amax != nullptr,
"amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
......
......@@ -63,7 +63,7 @@ def dense(
kernel: jnp.ndarray,
bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
output_axes: Tuple[str, ...] = None,
......@@ -81,7 +81,7 @@ def dense(
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
output_axes: Logical axes for sharding the output
......@@ -91,8 +91,8 @@ def dense(
Returns:
Transformed output tensor
"""
if batch_sequence_transpose:
warnings.warn("batch_sequence_transpose is not well tested, use with caution!")
if transpose_batch_sequence:
warnings.warn("transpose_batch_sequence is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
......@@ -103,7 +103,7 @@ def dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -119,7 +119,7 @@ def _dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -136,7 +136,7 @@ def _dense(
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
output_axes: Logical axes for sharding the output_axes
kernel_axes: Logical axes for sharding the weight matrix
......@@ -151,7 +151,7 @@ def _dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -166,7 +166,7 @@ def _dense_fwd_rule(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -197,7 +197,7 @@ def _dense_fwd_rule(
flatten_axis=flatten_axis_x,
quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
......@@ -215,7 +215,7 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set.forward,
......@@ -240,7 +240,7 @@ def _dense_fwd_rule(
def _dense_bwd_rule(
contracting_dims,
batch_sequence_transpose,
transpose_batch_sequence,
input_axes,
kernel_axes,
output_axes,
......@@ -274,7 +274,7 @@ def _dense_bwd_rule(
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
batch_sequence_transpose=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
# GEMM NT
......@@ -291,7 +291,7 @@ def _dense_bwd_rule(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set.backward,
)
......@@ -305,7 +305,7 @@ def _dense_bwd_rule(
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......
......@@ -432,6 +432,8 @@ class DenseGeneral(TransformerEngineBase):
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
"""
features: Union[Iterable[int], int]
......@@ -446,6 +448,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
input_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
def __post_init__(self):
if self.kernel_init is None:
......@@ -512,6 +515,7 @@ class DenseGeneral(TransformerEngineBase):
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
if self.enable_low_rank_adaptation:
......@@ -632,6 +636,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
"""
features: Union[Iterable[int], int]
......@@ -657,6 +663,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None
transpose_batch_sequence: bool = False
def __post_init__(self):
if self.kernel_init is None:
......@@ -768,6 +775,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
......@@ -775,6 +783,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
y,
kernel,
contracting_dims=(axis, contract_ind),
transpose_batch_sequence=self.transpose_batch_sequence,
input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
......@@ -940,6 +949,8 @@ class LayerNormMLP(TransformerEngineBase):
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence: bool, default = False
Indicate whether to transpose the batch and sequence dimensions of the input tensor.
"""
intermediate_dim: int = 2048
......@@ -974,6 +985,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None
ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2"
transpose_batch_sequence: bool = False
def __post_init__(self):
if self.kernel_init is None:
......@@ -1160,6 +1172,7 @@ class LayerNormMLP(TransformerEngineBase):
activation_type=normalized_acts,
activation_params=self.activation_params,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
transpose_batch_sequence=self.transpose_batch_sequence,
)
out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple)
......@@ -1178,6 +1191,7 @@ class LayerNormMLP(TransformerEngineBase):
dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes)
......@@ -1188,6 +1202,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
if self.enable_low_rank_adaptation:
......@@ -1260,6 +1275,7 @@ class LayerNormMLP(TransformerEngineBase):
input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set,
transpose_batch_sequence=self.transpose_batch_sequence,
)
if self.enable_low_rank_adaptation:
......
......@@ -1207,6 +1207,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="qkv",
dtype=self.dtype,
)(inputs_q)
......@@ -1234,6 +1235,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query",
)(inputs_q)
......@@ -1252,6 +1254,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
transpose_batch_sequence=self.transpose_batch_sequence,
name="kv",
dtype=self.dtype,
)(inputs_kv)
......@@ -1292,6 +1295,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
transpose_batch_sequence=self.transpose_batch_sequence,
name="query",
)(inputs_q)
......@@ -2070,6 +2074,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
transpose_batch_sequence=self.transpose_batch_sequence,
name="mlp",
)(mlp_input, deterministic=deterministic)
......
......@@ -16,6 +16,7 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .quantize import (
QuantizerSet,
......@@ -35,6 +36,7 @@ def layernorm_dense(
norm_type: str = "layernorm",
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
transpose_batch_sequence: bool = False,
layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
......@@ -55,6 +57,7 @@ def layernorm_dense(
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
......@@ -83,6 +86,7 @@ def layernorm_dense(
norm_type,
zero_centered_gamma,
epsilon,
transpose_batch_sequence,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
......@@ -100,6 +104,7 @@ def layernorm_dense(
8,
9,
10,
11,
),
)
def _layernorm_dense(
......@@ -111,6 +116,7 @@ def _layernorm_dense(
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
transpose_batch_sequence: bool,
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
......@@ -131,6 +137,7 @@ def _layernorm_dense(
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding
quantizer_set: Set of quantizers
......@@ -147,6 +154,7 @@ def _layernorm_dense(
norm_type,
zero_centered_gamma,
epsilon,
transpose_batch_sequence,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
......@@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
transpose_batch_sequence,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
......@@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule(
epsilon,
norm_type,
quantizer=quantizer_set.x,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
......@@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule(
kernel,
flatten_axis=flatten_axis,
quantizer=quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
......@@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=transpose_batch_sequence,
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
......@@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
transpose_batch_sequence,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
......@@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule(
is_dbias=use_bias,
flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
......@@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
......@@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule(
casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim),
transpose_batch_sequence=transpose_batch_sequence,
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
......@@ -41,7 +41,7 @@ def layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
batch_sequence_transpose: bool = False,
transpose_batch_sequence: bool = False,
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
......@@ -78,7 +78,7 @@ def layernorm_mlp(
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
......@@ -130,7 +130,7 @@ def layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
transpose_batch_sequence,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -158,7 +158,7 @@ def _layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
batch_sequence_transpose: bool,
transpose_batch_sequence: bool,
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
......@@ -188,7 +188,7 @@ def _layernorm_mlp(
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
transpose_batch_sequence: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding
......@@ -214,7 +214,7 @@ def _layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
transpose_batch_sequence,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -241,7 +241,7 @@ def _layernorm_mlp_fwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
transpose_batch_sequence,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -302,11 +302,16 @@ def _layernorm_mlp_fwd_rule(
norm_type,
quantizer=ffn1_quantizer_set.x,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(
kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP
kernel_1,
flatten_axis=-2,
quantizer=ffn1_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
)
# NN GEMM
......@@ -315,7 +320,7 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_1.forward,
......@@ -345,6 +350,8 @@ def _layernorm_mlp_fwd_rule(
if activation_params
else None
),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
......@@ -353,6 +360,7 @@ def _layernorm_mlp_fwd_rule(
kernel_2,
quantizer=ffn2_quantizer_set.kernel,
amax_scope=AmaxScope.FSDP,
transpose_batch_sequence=transpose_batch_sequence,
)
# NN GEMM
......@@ -361,7 +369,7 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_2.forward,
......@@ -403,7 +411,7 @@ def _layernorm_mlp_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
transpose_batch_sequence,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -465,6 +473,7 @@ def _layernorm_mlp_bwd_rule(
is_dbias=use_bias_2,
quantizer=ffn1_quantizer_set.dgrad,
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......@@ -482,7 +491,7 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_2.backward,
)
......@@ -498,7 +507,7 @@ def _layernorm_mlp_bwd_rule(
casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
......@@ -513,6 +522,8 @@ def _layernorm_mlp_bwd_rule(
if activation_params
else None
),
amax_scope=AmaxScope.TPSP,
transpose_batch_sequence=transpose_batch_sequence,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
......@@ -530,7 +541,7 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op_set_1.backward,
)
......@@ -542,7 +553,7 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
transpose_batch_sequence=transpose_batch_sequence,
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment