Unverified commit b17f3f4e, authored by jberchtold-nvidia and committed by GitHub
Browse files

[JAX] Make primitive names more granular for better disabling granularity (#1811)



Make primitive names more granular for better disabling granularity
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 00328ac7
...@@ -453,7 +453,7 @@ register_primitive(ActLuPrimitive) ...@@ -453,7 +453,7 @@ register_primitive(ActLuPrimitive)
# TODO(Jeremy): replace is_2x with q_layout # TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive): class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
""" """
DActLu DBias Cast Transpose Primitive DActLu DBias Cast Transpose Primitive
""" """
...@@ -561,7 +561,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -561,7 +561,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p outer abstract te_dact_dbias_quantize_p outer abstract
""" """
(out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs) BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
) )
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
...@@ -589,7 +589,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -589,7 +589,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype assert x_aval.dtype == dz_aval.dtype
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)( return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)(
ctx, ctx,
dz, dz,
x, x,
...@@ -618,9 +618,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -618,9 +618,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p impl te_dact_dbias_quantize_p impl
""" """
del is_outer del is_outer
assert DActLuDBiasQuantizePrimitive.inner_primitive is not None assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
(out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = ( (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
DActLuDBiasQuantizePrimitive.inner_primitive.bind( BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
dz, dz,
x, x,
scale, scale,
...@@ -666,7 +666,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -666,7 +666,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
""" """
del is_outer del is_outer
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert DActLuDBiasQuantizePrimitive.outer_primitive is not None assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
dz, x, scale = batched_args dz, x, scale = batched_args
_, x_bdim, scale_bdim = batch_dims _, x_bdim, scale_bdim = batch_dims
...@@ -679,7 +679,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -679,7 +679,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
x_bdim, # dbias x_bdim, # dbias
) )
return ( return (
DActLuDBiasQuantizePrimitive.outer_primitive.bind( BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
dz, dz,
x, x,
scale, scale,
...@@ -718,7 +718,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -718,7 +718,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
), "Partitioned current tensor scaling is not yet supported." ), "Partitioned current tensor scaling is not yet supported."
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
) )
if is_2x: if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
...@@ -728,14 +728,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -728,14 +728,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else: else:
colwise_x_spec = (None,) colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" mesh,
PartitionSpec(*colwise_x_spec),
desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
) )
dbias_spec = x_spec[-2:] if is_dbias else (None,) dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding( dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*dbias_spec), PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias", desc="BaseDActLuDBiasQuantizePrimitive.dbias",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
...@@ -748,15 +750,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -748,15 +750,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv_spec = scale_inv_spec colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
) )
amax_sharding = NamedSharding( amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax" mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
) )
colwise_scale_inv_sharding = NamedSharding( colwise_scale_inv_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*colwise_scale_inv_spec), PartitionSpec(*colwise_scale_inv_spec),
desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv", desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
) )
return ( return (
out_sharding, out_sharding,
...@@ -786,7 +788,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -786,7 +788,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scale_spec = get_padded_spec(arg_infos[2]) scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
) )
if is_2x: if is_2x:
...@@ -797,14 +799,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -797,14 +799,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else: else:
colwise_x_spec = (None,) colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" mesh,
PartitionSpec(*colwise_x_spec),
desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
) )
dbias_spec = x_spec[-2:] if is_dbias else (None,) dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding( dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*dbias_spec), PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias", desc="BaseDActLuDBiasQuantizePrimitive.dbias",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
...@@ -827,7 +831,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -827,7 +831,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
arg_shardings = list(arg_i.sharding for arg_i in arg_infos) arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
# Ensure dz and x are partitioned the same way. # Ensure dz and x are partitioned the same way.
arg_shardings[0] = NamedSharding( arg_shardings[0] = NamedSharding(
mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]), desc="DActLuDBiasQuantizePrimitive.dz" mesh,
PartitionSpec(*x_spec[:-2], x_spec[-1]),
desc="BaseDActLuDBiasQuantizePrimitive.dz",
) )
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = ( out_shardings = (
...@@ -841,7 +847,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -841,7 +847,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
def sharded_impl(dz, x, scale): def sharded_impl(dz, x, scale):
(out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = ( (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
DActLuDBiasQuantizePrimitive.impl( BaseDActLuDBiasQuantizePrimitive.impl(
dz, dz,
x, x,
scale, scale,
...@@ -887,7 +893,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -887,7 +893,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
x_rank = len(value_types[1].shape) x_rank = len(value_types[1].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank, unique_var="DActLuDbiasQuantizePrimitive_i", flatten_axis=-2 x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
out = x_axes out = x_axes
...@@ -909,7 +915,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -909,7 +915,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
) )
register_primitive(DActLuDBiasQuantizePrimitive) register_primitive(BaseDActLuDBiasQuantizePrimitive)
class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """DBias + fused activation quantization variant of BaseDActLuDBiasQuantizePrimitive.

    Adds no functionality on top of the base primitive. The subclass exists
    only to carry a distinct class name, so that this variant can be enabled
    or disabled on its own through the NVTE_JAX_CUSTOM_CALLS_RE pattern.
    """
class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Fused activation quantization (no dbias) variant of BaseDActLuDBiasQuantizePrimitive.

    Adds no functionality on top of the base primitive. The subclass exists
    only to carry a distinct class name, so that this variant can be enabled
    or disabled on its own through the NVTE_JAX_CUSTOM_CALLS_RE pattern.
    """
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
...@@ -1099,7 +1113,8 @@ def quantize_dact_dbias( ...@@ -1099,7 +1113,8 @@ def quantize_dact_dbias(
f" {x.shape} and act_len {act_len}" f" {x.shape} and act_len {act_len}"
) )
if not DActLuDBiasQuantizePrimitive.enabled(): PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
if not PrimitiveClass.enabled():
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support colwise-only quantization yet # TE/common does not support colwise-only quantization yet
...@@ -1135,7 +1150,7 @@ def quantize_dact_dbias( ...@@ -1135,7 +1150,7 @@ def quantize_dact_dbias(
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type]
if quantizer is None: if quantizer is None:
output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind( output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
dz, dz,
x, x,
scale, scale,
...@@ -1188,7 +1203,7 @@ def quantize_dact_dbias( ...@@ -1188,7 +1203,7 @@ def quantize_dact_dbias(
colwise_scale_inv, colwise_scale_inv,
updated_amax, updated_amax,
dbias, dbias,
) = DActLuDBiasQuantizePrimitive.outer_primitive.bind( ) = PrimitiveClass.outer_primitive.bind(
dz, dz,
x, x,
scale, scale,
......
...@@ -33,14 +33,15 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -33,14 +33,15 @@ class BasePrimitive(metaclass=ABCMeta):
@classmethod @classmethod
def enabled(cls): def enabled(cls):
""" """
A custom call is marked as disabled if the `cls.name` does not fully match the A custom call is marked as disabled if the `cls.__name__` does not fully match the
`NVTE_JAX_CUSTOM_CALLS_RE` pattern. `NVTE_JAX_CUSTOM_CALLS_RE` pattern.
This uses the Python class name of the primitive definitions that inherit from BasePrimitive.
By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names. By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`. For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`.
""" """
pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*") pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
pattern = re.compile(pattern) pattern = re.compile(pattern)
is_enabled = pattern.fullmatch(cls.name) is not None is_enabled = pattern.fullmatch(cls.__name__) is not None
return is_enabled return is_enabled
@staticmethod @staticmethod
......
...@@ -44,7 +44,7 @@ else: ...@@ -44,7 +44,7 @@ else:
__all__ = ["quantize", "quantize_dbias"] __all__ = ["quantize", "quantize_dbias"]
class DBiasQuantizePrimitive(BasePrimitive): class BaseDBiasQuantizePrimitive(BasePrimitive):
""" """
Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
""" """
...@@ -155,7 +155,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -155,7 +155,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
updated_amax, updated_amax,
dbias, dbias,
_, _,
) = DBiasQuantizePrimitive.abstract(*args, **kwargs) ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
@staticmethod @staticmethod
...@@ -179,7 +179,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -179,7 +179,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
x_aval, scale_aval = ctx.avals_in x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(DBiasQuantizePrimitive.name)( return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)(
ctx, ctx,
x, x,
scale, scale,
...@@ -205,7 +205,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -205,7 +205,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
te_dbias_quantize_p implementation te_dbias_quantize_p implementation
""" """
del is_outer del is_outer
assert DBiasQuantizePrimitive.inner_primitive is not None assert BaseDBiasQuantizePrimitive.inner_primitive is not None
( (
out, out,
colwise_out, colwise_out,
...@@ -214,7 +214,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -214,7 +214,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
updated_amax, updated_amax,
dbias, dbias,
_, _,
) = DBiasQuantizePrimitive.inner_primitive.bind( ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
x, x,
scale, scale,
out_dtype=out_dtype, out_dtype=out_dtype,
...@@ -262,14 +262,14 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -262,14 +262,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
""" """
del is_outer del is_outer
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert DBiasQuantizePrimitive.outer_primitive is not None assert BaseDBiasQuantizePrimitive.outer_primitive is not None
x, scale = batched_args x, scale = batched_args
x_bdim, scale_bdim = batch_dims x_bdim, scale_bdim = batch_dims
amax_bdim = scale_bdim amax_bdim = scale_bdim
out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
return ( return (
DBiasQuantizePrimitive.outer_primitive.bind( BaseDBiasQuantizePrimitive.outer_primitive.bind(
x, x,
scale, scale,
out_dtype=out_dtype, out_dtype=out_dtype,
...@@ -302,7 +302,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -302,7 +302,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec), PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding", desc="BaseDBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling(): if ScalingMode(scaling_mode).is_tensor_scaling():
...@@ -314,14 +314,14 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -314,14 +314,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*colwise_out_spec), PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding", desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
) )
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding( dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*dbias_spec), PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding", desc="BaseDBiasQuantizePrimitive.dbias_sharding",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
...@@ -334,15 +334,15 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -334,15 +334,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv_spec = scale_inv_spec colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
) )
amax_sharding = NamedSharding( amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax" mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
) )
colwise_scale_inv_sharding = NamedSharding( colwise_scale_inv_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*colwise_scale_inv_spec), PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv", desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
) )
return ( return (
...@@ -374,7 +374,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -374,7 +374,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec), PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding", desc="BaseDBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling(): if ScalingMode(scaling_mode).is_tensor_scaling():
...@@ -386,14 +386,14 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -386,14 +386,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*colwise_out_spec), PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding", desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
) )
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding( dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*dbias_spec), PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding", desc="BaseDBiasQuantizePrimitive.dbias_sharding",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
...@@ -406,15 +406,15 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -406,15 +406,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv_spec = scale_inv_spec colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
) )
amax_sharding = NamedSharding( amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax" mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
) )
colwise_scale_inv_sharding = NamedSharding( colwise_scale_inv_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*colwise_scale_inv_spec), PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv", desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
) )
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
...@@ -435,7 +435,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -435,7 +435,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
local_colwise_scale_inv, local_colwise_scale_inv,
local_amax, local_amax,
local_dbias, local_dbias,
) = DBiasQuantizePrimitive.impl( ) = BaseDBiasQuantizePrimitive.impl(
x, x,
scale, scale,
out_dtype=out_dtype, out_dtype=out_dtype,
...@@ -485,7 +485,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -485,7 +485,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), len(value_types[0].shape),
unique_var="DBiasQuantizePrimitive_i", unique_var="BaseDBiasQuantizePrimitive_i",
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
...@@ -512,7 +512,15 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -512,7 +512,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
) )
register_primitive(DBiasQuantizePrimitive) register_primitive(BaseDBiasQuantizePrimitive)
class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """DBias quantization variant of BaseDBiasQuantizePrimitive.

    Adds no functionality on top of the base primitive. The subclass exists
    only to carry a distinct class name, so that this variant can be enabled
    or disabled on its own through the NVTE_JAX_CUSTOM_CALLS_RE pattern.
    """
class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Quantization (no dbias) variant of BaseDBiasQuantizePrimitive.

    Adds no functionality on top of the base primitive. The subclass exists
    only to carry a distinct class name, so that this variant can be enabled
    or disabled on its own through the NVTE_JAX_CUSTOM_CALLS_RE pattern.
    """
def _jax_quantize( def _jax_quantize(
...@@ -565,7 +573,8 @@ def _quantize_dbias_impl( ...@@ -565,7 +573,8 @@ def _quantize_dbias_impl(
dq_dtype = dq_dtype or x.dtype dq_dtype = dq_dtype or x.dtype
if not DBiasQuantizePrimitive.enabled(): PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
if not PrimitiveClass.enabled():
if is_dbias: if is_dbias:
return _jax_quantize_dbias( return _jax_quantize_dbias(
x, x,
...@@ -627,7 +636,7 @@ def _quantize_dbias_impl( ...@@ -627,7 +636,7 @@ def _quantize_dbias_impl(
colwise_scale_inv, colwise_scale_inv,
updated_amax, updated_amax,
dbias, dbias,
) = DBiasQuantizePrimitive.outer_primitive.bind( ) = PrimitiveClass.outer_primitive.bind(
x, x,
scale, scale,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment