"vscode:/vscode.git/clone" did not exist on "46075b98abac3f6472908ca67484c1a10887da15"
Unverified commit b17f3f4e authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Make primitive names more granular for better disabling granularity (#1811)



Make primitive names more granular for better disabling granularity
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 00328ac7
......@@ -453,7 +453,7 @@ register_primitive(ActLuPrimitive)
# TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive):
class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
"""
DActLu DBias Cast Transpose Primitive
"""
......@@ -561,7 +561,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p outer abstract
"""
(out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
......@@ -589,7 +589,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)(
ctx,
dz,
x,
......@@ -618,9 +618,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p impl
"""
del is_outer
assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
(out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
DActLuDBiasQuantizePrimitive.inner_primitive.bind(
BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
dz,
x,
scale,
......@@ -666,7 +666,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
"""
del is_outer
check_valid_batch_dims(batch_dims)
assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
dz, x, scale = batched_args
_, x_bdim, scale_bdim = batch_dims
......@@ -679,7 +679,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
x_bdim, # dbias
)
return (
DActLuDBiasQuantizePrimitive.outer_primitive.bind(
BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
dz,
x,
scale,
......@@ -718,7 +718,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
), "Partitioned current tensor scaling is not yet supported."
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
)
if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
......@@ -728,14 +728,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else:
colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
mesh,
PartitionSpec(*colwise_x_spec),
desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
)
dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias",
desc="BaseDActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
......@@ -748,15 +750,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
)
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_scale_inv_spec),
desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
)
return (
out_sharding,
......@@ -786,7 +788,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
)
if is_2x:
......@@ -797,14 +799,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else:
colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
mesh,
PartitionSpec(*colwise_x_spec),
desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
)
dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias",
desc="BaseDActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
......@@ -827,7 +831,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
# Ensure dz and x are partitioned the same way.
arg_shardings[0] = NamedSharding(
mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]), desc="DActLuDBiasQuantizePrimitive.dz"
mesh,
PartitionSpec(*x_spec[:-2], x_spec[-1]),
desc="BaseDActLuDBiasQuantizePrimitive.dz",
)
arg_shardings = tuple(arg_shardings)
out_shardings = (
......@@ -841,7 +847,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
def sharded_impl(dz, x, scale):
(out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
DActLuDBiasQuantizePrimitive.impl(
BaseDActLuDBiasQuantizePrimitive.impl(
dz,
x,
scale,
......@@ -887,7 +893,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
x_rank = len(value_types[1].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank, unique_var="DActLuDbiasQuantizePrimitive_i", flatten_axis=-2
x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2
)
x_axes = scale_rules.input_spec
out = x_axes
......@@ -909,7 +915,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
)
register_primitive(DActLuDBiasQuantizePrimitive)
register_primitive(BaseDActLuDBiasQuantizePrimitive)
class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
......@@ -1099,7 +1113,8 @@ def quantize_dact_dbias(
f" {x.shape} and act_len {act_len}"
)
if not DActLuDBiasQuantizePrimitive.enabled():
PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
if not PrimitiveClass.enabled():
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support colwise-only quantization yet
......@@ -1135,7 +1150,7 @@ def quantize_dact_dbias(
act_type_id = ActivationEnum[activation_type]
if quantizer is None:
output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
dz,
x,
scale,
......@@ -1188,7 +1203,7 @@ def quantize_dact_dbias(
colwise_scale_inv,
updated_amax,
dbias,
) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
) = PrimitiveClass.outer_primitive.bind(
dz,
x,
scale,
......
......@@ -33,14 +33,15 @@ class BasePrimitive(metaclass=ABCMeta):
@classmethod
def enabled(cls):
"""
A custom call is marked as disabled if the `cls.name` does not fully match the
A custom call is marked as disabled if the `cls.__name__` does not fully match the
`NVTE_JAX_CUSTOM_CALLS_RE` pattern.
This uses the Python class name of the primitive definitions that inherit from BasePrimitive.
By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`.
For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`.
"""
pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
pattern = re.compile(pattern)
is_enabled = pattern.fullmatch(cls.name) is not None
is_enabled = pattern.fullmatch(cls.__name__) is not None
return is_enabled
@staticmethod
......
......@@ -44,7 +44,7 @@ else:
__all__ = ["quantize", "quantize_dbias"]
class DBiasQuantizePrimitive(BasePrimitive):
class BaseDBiasQuantizePrimitive(BasePrimitive):
"""
Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
"""
......@@ -155,7 +155,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
updated_amax,
dbias,
_,
) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
@staticmethod
......@@ -179,7 +179,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(DBiasQuantizePrimitive.name)(
return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)(
ctx,
x,
scale,
......@@ -205,7 +205,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
te_dbias_quantize_p implementation
"""
del is_outer
assert DBiasQuantizePrimitive.inner_primitive is not None
assert BaseDBiasQuantizePrimitive.inner_primitive is not None
(
out,
colwise_out,
......@@ -214,7 +214,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
updated_amax,
dbias,
_,
) = DBiasQuantizePrimitive.inner_primitive.bind(
) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
x,
scale,
out_dtype=out_dtype,
......@@ -262,14 +262,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
"""
del is_outer
check_valid_batch_dims(batch_dims)
assert DBiasQuantizePrimitive.outer_primitive is not None
assert BaseDBiasQuantizePrimitive.outer_primitive is not None
x, scale = batched_args
x_bdim, scale_bdim = batch_dims
amax_bdim = scale_bdim
out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
return (
DBiasQuantizePrimitive.outer_primitive.bind(
BaseDBiasQuantizePrimitive.outer_primitive.bind(
x,
scale,
out_dtype=out_dtype,
......@@ -302,7 +302,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding",
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling():
......@@ -314,14 +314,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
)
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding",
desc="BaseDBiasQuantizePrimitive.dbias_sharding",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
......@@ -334,15 +334,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
)
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv",
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
)
return (
......@@ -374,7 +374,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding",
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling():
......@@ -386,14 +386,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
)
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding",
desc="BaseDBiasQuantizePrimitive.dbias_sharding",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
......@@ -406,15 +406,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
)
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv",
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
......@@ -435,7 +435,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
local_colwise_scale_inv,
local_amax,
local_dbias,
) = DBiasQuantizePrimitive.impl(
) = BaseDBiasQuantizePrimitive.impl(
x,
scale,
out_dtype=out_dtype,
......@@ -485,7 +485,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape),
unique_var="DBiasQuantizePrimitive_i",
unique_var="BaseDBiasQuantizePrimitive_i",
flatten_axis=flatten_axis,
)
......@@ -512,7 +512,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
)
register_primitive(DBiasQuantizePrimitive)
register_primitive(BaseDBiasQuantizePrimitive)
class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
class QuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
def _jax_quantize(
......@@ -565,7 +573,8 @@ def _quantize_dbias_impl(
dq_dtype = dq_dtype or x.dtype
if not DBiasQuantizePrimitive.enabled():
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
if not PrimitiveClass.enabled():
if is_dbias:
return _jax_quantize_dbias(
x,
......@@ -627,7 +636,7 @@ def _quantize_dbias_impl(
colwise_scale_inv,
updated_amax,
dbias,
) = DBiasQuantizePrimitive.outer_primitive.bind(
) = PrimitiveClass.outer_primitive.bind(
x,
scale,
out_dtype=quantizer.q_dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment