Commit 2b05e121 authored by yuguo

Merge commit 'a69692ac' of...

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
......@@ -188,7 +188,7 @@ class ReorderStrategy(Enum):
- DualChunkSwap: This strategy splits each query into two chunks and do the mirror swap between
GPUs. This is currently used for non-THD load balance. It requires the max_seqlens be the
mulitple of 2 * cp_size.
multiple of 2 * cp_size.
Examples:
- Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15];
- After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]
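A minimal host-side sketch (plain Python, not the library implementation) of the mirror swap described above:

def dual_chunk_swap(chunks):
    # Each GPU keeps the first half of its chunk and takes the second half
    # of its mirror GPU (GPU i <-> GPU cp_size - 1 - i).
    cp_size = len(chunks)
    return [
        chunk[: len(chunk) // 2] + chunks[cp_size - 1 - i][len(chunk) // 2 :]
        for i, chunk in enumerate(chunks)
    ]

# dual_chunk_swap([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]])
# -> [[0, 1, 14, 15], [4, 5, 10, 11], [8, 9, 6, 7], [12, 13, 2, 3]]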
......@@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
def is_fused_attn_kernel_available(
is_training,
q_dtype,
kv_dtype,
qkv_layout,
......@@ -287,7 +288,8 @@ def is_fused_attn_kernel_available(
kv_num_heads,
q_max_seqlen,
kv_max_seqlen,
head_dim,
head_dim_qk,
head_dim_v,
window_size: Optional[Tuple[int, int]] = None,
):
"""
......@@ -296,6 +298,7 @@ def is_fused_attn_kernel_available(
def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
is_training,
q_dtype,
kv_dtype,
qkv_layout,
......@@ -306,7 +309,8 @@ def is_fused_attn_kernel_available(
kv_num_heads,
q_max_seqlen,
kv_max_seqlen,
head_dim,
head_dim_qk,
head_dim_v,
(-1, -1) if window_size is None else window_size,
)
......@@ -489,7 +493,7 @@ def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):
@jax.tree_util.register_pytree_node_class
class SequenceDescriptor:
"""A class to descibe the sequences with flexible initialization.
"""A class to describe the sequences with flexible initialization.
- SequenceDescriptor.from_seqlens
For non-THD (non-packed) cases, where each batch has only 1 sequence.
- SequenceDescriptor.from_seqlens_and_offsets
......
......@@ -453,7 +453,7 @@ register_primitive(ActLuPrimitive)
# TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive):
class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
"""
DActLu DBias Cast Transpose Primitive
"""
......@@ -561,7 +561,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p outer abstract
"""
(out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
......@@ -589,7 +589,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)(
ctx,
dz,
x,
......@@ -618,9 +618,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p impl
"""
del is_outer
assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
(out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
DActLuDBiasQuantizePrimitive.inner_primitive.bind(
BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
dz,
x,
scale,
......@@ -666,7 +666,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
"""
del is_outer
check_valid_batch_dims(batch_dims)
assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
dz, x, scale = batched_args
_, x_bdim, scale_bdim = batch_dims
......@@ -679,7 +679,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
x_bdim, # dbias
)
return (
DActLuDBiasQuantizePrimitive.outer_primitive.bind(
BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
dz,
x,
scale,
......@@ -718,7 +718,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
), "Partitioned current tensor scaling is not yet supported."
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
)
if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
......@@ -728,14 +728,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else:
colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
mesh,
PartitionSpec(*colwise_x_spec),
desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
)
dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias",
desc="BaseDActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
......@@ -748,15 +750,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
)
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_scale_inv_spec),
desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
)
return (
out_sharding,
......@@ -786,7 +788,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
)
if is_2x:
......@@ -797,14 +799,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else:
colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
mesh,
PartitionSpec(*colwise_x_spec),
desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
)
dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias",
desc="BaseDActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
......@@ -827,7 +831,9 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
# Ensure dz and x are partitioned the same way.
arg_shardings[0] = NamedSharding(
mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]), desc="DActLuDBiasQuantizePrimitive.dz"
mesh,
PartitionSpec(*x_spec[:-2], x_spec[-1]),
desc="BaseDActLuDBiasQuantizePrimitive.dz",
)
arg_shardings = tuple(arg_shardings)
out_shardings = (
......@@ -841,7 +847,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
def sharded_impl(dz, x, scale):
(out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
DActLuDBiasQuantizePrimitive.impl(
BaseDActLuDBiasQuantizePrimitive.impl(
dz,
x,
scale,
......@@ -887,7 +893,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
x_rank = len(value_types[1].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank, unique_var="DActLuDbiasQuantizePrimitive_i", flatten_axis=-2
x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2
)
x_axes = scale_rules.input_spec
out = x_axes
......@@ -909,7 +915,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
)
register_primitive(DActLuDBiasQuantizePrimitive)
register_primitive(BaseDActLuDBiasQuantizePrimitive)
class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
......@@ -1099,7 +1113,8 @@ def quantize_dact_dbias(
f" {x.shape} and act_len {act_len}"
)
if not DActLuDBiasQuantizePrimitive.enabled():
PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
if not PrimitiveClass.enabled():
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support colwise-only quantization yet
......@@ -1135,7 +1150,7 @@ def quantize_dact_dbias(
act_type_id = ActivationEnum[activation_type]
if quantizer is None:
output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
dz,
x,
scale,
......@@ -1188,7 +1203,7 @@ def quantize_dact_dbias(
colwise_scale_inv,
updated_amax,
dbias,
) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
) = PrimitiveClass.outer_primitive.bind(
dz,
x,
scale,
......
......@@ -41,6 +41,7 @@ from ..sharding import (
all_reduce_sum_along_dp_fsdp,
get_mesh_axis_size,
get_mesh_axis_rank,
get_mesh_axis_rank_host,
get_all_mesh_axes,
num_of_devices,
with_sharding_constraint,
......@@ -74,6 +75,7 @@ __all__ = [
"window_size",
"context_parallel_load_balanced",
"cp_axis",
"cp_striped_window_size",
],
)
@dataclass(frozen=True)
......@@ -92,6 +94,7 @@ class _FusedAttnConfig:
window_size: Tuple[int, int]
context_parallel_load_balanced: bool
cp_axis: str
cp_striped_window_size: Tuple[int, int] # Only for CP + Ring + THD + SWA
@dataclass(frozen=True)
......@@ -100,6 +103,7 @@ class FusedAttnHelper:
Helper for the fused attention backend
"""
is_training: bool
q_dtype: jnp.dtype
kv_dtype: jnp.dtype
qkv_layout: QKVLayout
......@@ -110,7 +114,8 @@ class FusedAttnHelper:
kv_num_heads: int
q_max_seqlen: int
kv_max_seqlen: int
head_dim: int
head_dim_qk: int
head_dim_v: int
window_size: Tuple[int, int]
def is_fused_attn_kernel_available(self):
......@@ -120,6 +125,7 @@ class FusedAttnHelper:
def get_fused_attn_backend(self):
"""Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend(
self.is_training,
jax_dtype_to_te_dtype(self.q_dtype),
jax_dtype_to_te_dtype(self.kv_dtype),
self.qkv_layout.value,
......@@ -130,7 +136,8 @@ class FusedAttnHelper:
self.kv_num_heads,
self.q_max_seqlen,
self.kv_max_seqlen,
self.head_dim,
self.head_dim_qk,
self.head_dim_v,
self.window_size[0],
self.window_size[1],
)
......@@ -150,23 +157,49 @@ class FusedAttnHelper:
kv_batch_shape = q_batch_shape
kv_max_seqlen = q_max_seqlen
num_gqa_groups = attn_heads
kv_head_dim = q_head_dim
v_head_dim = q_head_dim
assert nqkv == 3
elif qkv_layout.is_kvpacked():
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == v_head_dim
assert nkv == 2
elif qkv_layout.is_separate():
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}"
*k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape
*v_batch_shape, v_max_seqlen, v_num_gqa_groups, v_head_dim = v_aval.shape
assert (
q_head_dim == k_head_dim
), f"Mismatched q_head_dim: {q_head_dim} and k_head_dim: {k_head_dim}"
assert (
k_max_seqlen == v_max_seqlen
), f"Mismatched k_max_seqlen: {k_max_seqlen} and v_max_seqlen: {v_max_seqlen}"
kv_max_seqlen = k_max_seqlen
assert q_batch_shape == k_batch_shape == v_batch_shape, (
f"Mismatched qkv batch size for q_batch_shape: {q_batch_shape}, k_batch_shape:"
f" {k_batch_shape} and v_batch_shape: {v_batch_shape}"
)
assert k_num_gqa_groups == v_num_gqa_groups, (
f"Mismatched k_num_gqa_groups: {k_num_gqa_groups} and v_num_gqa_groups:"
f" {v_num_gqa_groups}"
)
num_gqa_groups = k_num_gqa_groups
else:
raise ValueError(f"Unexpected {qkv_layout=}")
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert q_aval.dtype == k_aval.dtype == v_aval.dtype
return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)
assert q_aval.dtype == k_aval.dtype == v_aval.dtype, (
f"Mismatched data types for q_aval: {q_aval.dtype}, k_aval: {k_aval.dtype}, v_aval:"
f" {v_aval.dtype}"
)
return (
q_batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
q_head_dim,
v_head_dim,
)
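# Hedged illustration (hypothetical sizes, not from this change) for the separate
# BSHD layout with distinct QK and V head dims, which this change now supports:
#   q: (2, 1024, 16, 192) -> (batch, sq,  attn_heads,     head_dim_qk)
#   k: (2, 1024,  4, 192) -> (batch, skv, num_gqa_groups, head_dim_qk)
#   v: (2, 1024,  4, 128) -> (batch, skv, num_gqa_groups, head_dim_v)
# parse_qkv_aval then returns ((2,), 1024, 1024, 16, 4, 192, 128).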
@dataclass(frozen=True)
......@@ -264,15 +297,22 @@ class FusedAttnFwdPrimitive(BasePrimitive):
f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
)
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
q_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim)
out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(
config.is_training,
q_dtype,
k_dtype,
config.qkv_layout,
......@@ -283,7 +323,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
num_gqa_groups,
q_max_seqlen,
kv_max_seqlen,
head_dim,
q_head_dim,
v_head_dim,
config.window_size,
).get_fused_attn_backend()
......@@ -334,7 +375,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
q_head_dim,
v_head_dim,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type.value,
......@@ -386,9 +428,15 @@ class FusedAttnFwdPrimitive(BasePrimitive):
"""
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
q_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
input_batch = reduce(operator.mul, batch_shape)
......@@ -398,6 +446,13 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if config.cp_striped_window_size is not None:
window_size_left = config.cp_striped_window_size[0]
window_size_right = config.cp_striped_window_size[1]
else:
window_size_left = config.window_size[0]
window_size_right = config.window_size[1]
return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
ctx,
q,
......@@ -420,7 +475,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
qk_head_dim=q_head_dim,
v_head_dim=v_head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
......@@ -429,8 +485,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
window_size_left=window_size_left,
window_size_right=window_size_right,
)
@staticmethod
......@@ -698,9 +754,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
qk_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
......@@ -719,7 +781,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
qk_head_dim,
v_head_dim,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type.value,
......@@ -778,9 +841,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
"""
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
qk_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
input_batch = reduce(operator.mul, batch_shape)
......@@ -790,6 +859,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if config.cp_striped_window_size is not None:
window_size_left = config.cp_striped_window_size[0]
window_size_right = config.cp_striped_window_size[1]
else:
window_size_left = config.window_size[0]
window_size_right = config.window_size[1]
return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
ctx,
q,
......@@ -815,7 +891,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
qk_head_dim=qk_head_dim,
v_head_dim=v_head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
......@@ -824,8 +901,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
window_size_left=window_size_left,
window_size_right=window_size_right,
)
@staticmethod
......@@ -1177,6 +1254,7 @@ class _FusedAttnCPWithAllGatherHelper:
window_size=self.config.window_size,
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
)
def all_gather_kv(self, k, v):
......@@ -1616,6 +1694,16 @@ class _FusedAttnCPWithP2PHelper:
" NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment"
)
# If using scanloop, idx in scan_kv_block() will be a traced device value, but
# adjust_cp_striped_window_size() requires all of its parameters to be host values
is_context_parallel = get_mesh_axis_size(self.config.cp_axis, self.mesh) > 1
is_thd_layout = self.config.qkv_layout.is_thd()
is_sliding_window = self.config.window_size[0] != -1
if is_context_parallel and is_thd_layout and is_sliding_window and self.use_scanloop():
raise ValueError(
f"{header} with THD format and sliding window does not support using scan loop"
)
def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
return _FusedAttnConfig(
......@@ -1629,6 +1717,7 @@ class _FusedAttnCPWithP2PHelper:
window_size=self.config.window_size,
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
)
def stack_kv(self, k, v):
......@@ -2100,6 +2189,67 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
register_primitive(FusedRingAttnBwdPrimitive)
def adjust_cp_striped_window_size(q_pos0, kv_pos0, cp_size, window_size):
"""
Adjust window size with cp_size for striped sharding, where both q_pos and
kv_pos are arithmetic sequences like [x, x+cp_size, x+2*cp_size, ...].
Example 1:
q_pos = kv_pos = [0, 8, 16, 24, 32], cp_size = 8, window_size = (15, 0).
q_pos = 32 can look at kv_pos at [24, 32]. The effective mask is:
0 8 16 24 32
----------------
0 | 1 0 0 0 0
8 | 1 1 0 0 0
16 | 0 1 1 0 0
24 | 0 0 1 1 0
32 | 0 0 0 1 1
SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
Adjusted window size = (1, 0).
Example 2:
q_pos = [0, 8, 16, 24, 32], kv_pos = [1, 9, 17, 25, 33], cp_size = 8,
window_size = (15, 0). The effective mask is:
1 9 17 25 33
----------------
0 | 0 0 0 0 0
8 | 1 0 0 0 0
16 | 1 1 0 0 0
24 | 0 1 1 0 0
32 | 0 0 1 1 0
SequenceDescriptor outputs:
q_seqlen = [4, ...], q_seq_offsets = [1, ...],
kv_seqlen = [4, ...], kv_seq_offsets = [0, ...].
If the diagonal entries were all 1, the left window size would be 2. Since the diagonal
entries are all 0 here, we need to use left window size = 2 - 1 = 1 to make cuDNN work.
Example 3:
q_pos = [7, 15, 23, 31, 39], kv_pos = [0, 8, 16, 24, 32], cp_size = 8,
window_size = (22, 0). The effective mask is:
0 8 16 24 32
----------------
7 | 1 0 0 0 0
15 | 1 1 0 0 0
23 | 0 1 1 0 0
31 | 0 0 1 1 0
39 | 0 0 0 1 1
SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
Adjusted window size = (1, 0).
"""
left_limit = q_pos0 - window_size[0]
right_limit = q_pos0 + window_size[1]
# Count how many left/right steps of size cp_size we can take from kv_pos0 -/+ cp_size
left_steps = (kv_pos0 - cp_size - left_limit) // cp_size + 1
right_steps = (right_limit - kv_pos0 - cp_size) // cp_size + 1
left_steps = max(left_steps, 0)
right_steps = max(right_steps, 0)
# If kv_pos0 > q_pos0, we must reduce left window size by 1
shift = 1 if kv_pos0 > q_pos0 else 0
left_steps = left_steps - shift
return left_steps, right_steps
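# Host-side sanity sketch (not part of this change) mirroring the docstring examples above:
assert adjust_cp_striped_window_size(q_pos0=0, kv_pos0=0, cp_size=8, window_size=(15, 0)) == (1, 0)  # Example 1
assert adjust_cp_striped_window_size(q_pos0=0, kv_pos0=1, cp_size=8, window_size=(15, 0)) == (1, 0)  # Example 2
assert adjust_cp_striped_window_size(q_pos0=7, kv_pos0=0, cp_size=8, window_size=(22, 0)) == (1, 0)  # Example 3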
class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
"""
Fused Striped Ring Attention Forward Primitive
......@@ -2108,9 +2258,6 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
......@@ -2156,6 +2303,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
subblock_config = config
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
batch, q_max_seqlen, head, _ = q.shape
......@@ -2176,7 +2324,8 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)
output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
def compute(config):
return FusedAttnFwdPrimitive.impl(
q,
kv,
_not_used,
......@@ -2190,9 +2339,22 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
subblock_config,
config,
)
if config.window_size != (-1, -1):
kv_src_rank = (cp_size + cp_rank - idx) % cp_size
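# e.g. with cp_size=4 and cp_rank=1 (hypothetical values), steps idx=0,1,2,3
# process the KV block that originated on ranks 1, 0, 3, 2 respectively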
# Note: all inputs of adjust_cp_striped_window_size should be host values
cp_striped_window_size = adjust_cp_striped_window_size(
cp_rank, kv_src_rank, cp_size, config.window_size
)
current_config = replace(
subblock_config, cp_striped_window_size=cp_striped_window_size
)
else:
current_config = subblock_config
output_per_step, softmax_aux_per_step, _ = compute(current_config)
softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))
def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
......@@ -2244,9 +2406,6 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
......@@ -2290,13 +2449,15 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
subblock_config = config
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
# We need cp_rank to be a host value for adjust_cp_striped_window_size()
cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
dq = jnp.zeros_like(q)
dkv = jnp.zeros_like(kv)
dbias = jnp.zeros_like(bias)
def scan_kv_block(_idx, carry):
def scan_kv_block(idx, carry):
kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry
# Start communication that feeds the next iteration.
......@@ -2306,7 +2467,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)
def compute():
def compute(config):
dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
q,
kv,
......@@ -2324,11 +2485,22 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
config=subblock_config,
config=config,
)
return dq_per_step, dkv_per_step, dbias_per_step
dq_per_step, dkv_per_step, dbias_per_step = compute()
if config.window_size != (-1, -1):
kv_src_rank = (cp_size + cp_rank - idx) % cp_size
# Note: all inputs of adjust_cp_striped_window_size should be host values
cp_striped_window_size = adjust_cp_striped_window_size(
cp_rank, kv_src_rank, cp_size, config.window_size
)
current_config = replace(
subblock_config, cp_striped_window_size=cp_striped_window_size
)
else:
current_config = subblock_config
dq_per_step, dkv_per_step, dbias_per_step = compute(current_config)
kv_next, dkv = jnp.unstack(kv_dkv)
dq += dq_per_step
......@@ -2462,6 +2634,7 @@ def fused_attn_fwd(
window_size=(-1, -1) if window_size is None else window_size,
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
)
primitive = None
......@@ -2583,6 +2756,7 @@ def fused_attn_bwd(
window_size=(-1, -1) if window_size is None else window_size,
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
)
primitive = None
......
......@@ -33,14 +33,15 @@ class BasePrimitive(metaclass=ABCMeta):
@classmethod
def enabled(cls):
"""
A custom call is marked as disabled if the `cls.name` does not fully match the
A custom call is marked as disabled if the `cls.__name__` does not fully match the
`NVTE_JAX_CUSTOM_CALLS_RE` pattern.
This uses the Python class name of the primitive definitions that inherit from BasePrimitive.
By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`.
For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`.
"""
pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
pattern = re.compile(pattern)
is_enabled = pattern.fullmatch(cls.name) is not None
is_enabled = pattern.fullmatch(cls.__name__) is not None
return is_enabled
@staticmethod
......
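A standalone sketch (not part of this change) of how the NVTE_JAX_CUSTOM_CALLS_RE pattern
now selects primitives by their Python class name instead of the FFI name:

import os
import re

def primitive_enabled(cls_name: str) -> bool:
    # Mirrors BasePrimitive.enabled(): full-match the class name against the pattern.
    pattern = re.compile(os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*"))
    return pattern.fullmatch(cls_name) is not None

# With NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$':
#   primitive_enabled("DBiasQuantizePrimitive") -> False (disabled)
#   primitive_enabled("QuantizePrimitive")      -> True  (still enabled)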
......@@ -6,25 +6,31 @@
from typing import Tuple, Sequence, Union, Dict
from functools import partial, reduce
import operator
import math
import jax
import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability
from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams
from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
from ..quantize import (
ScaledTensor,
GroupedScaledTensor1x,
ScalingMode,
Quantizer,
GroupedQuantizer,
QuantizeConfig,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
)
__all__ = ["gemm"]
__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"]
num_cublas_streams = 4
num_cublas_streams = get_num_compute_streams()
def get_cublas_workspace_size_bytes() -> None:
......@@ -34,6 +40,11 @@ def get_cublas_workspace_size_bytes() -> None:
return 4_194_304
def is_gemm_with_all_layouts_supported() -> bool:
"""Return True if running on Blackwell (compute capability >= 100), False otherwise."""
return get_device_compute_capability(0) >= 100
class GroupedGemmPrimitive(BasePrimitive):
"""
Primitive for grouped GEMM
......@@ -41,73 +52,144 @@ class GroupedGemmPrimitive(BasePrimitive):
name = "te_grouped_gemm_ffi"
multiple_results = True
impl_static_args = ()
impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
def abstract(
lhs_data_aval,
lhs_scale_inv_aval,
rhs_data_aval,
rhs_scale_inv_aval,
bias_aval,
group_sizes_aval,
group_offset_aval,
*,
M,
N,
K,
lhs_is_trans,
rhs_is_trans,
scaling_mode,
out_dtype,
has_bias,
is_grouped_dense_wgrad,
):
"""
Grouped GEMM operation.
Args:
*args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
args[ 0 : num_gemms] are the lhs tensors,
args[ num_gemms : 2*num_gemms] are the rhs tensors,
args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
num_gemms: Number of GEMM operations to perform.
scaling_mode: Scaling mode for the GEMM operations.
out_dtype: Data type of the output tensors.
has_bias: Boolean indicating if bias tensors are provided.
lhs_data: Left-hand side input matrix data, 1D flattened array
lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array
rhs_data: Right-hand side input matrix data, 1D flattened array
rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array
bias: Bias matrix of shape (G, N)
group_sizes: 1D array containing the sizes of each group
group_offset: 1D array containing offsets for each group (not yet implemented)
M: Number of rows in the output matrix
N: Number of columns in the output matrix
K: Number of columns in the left-hand side matrix
lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed
rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed
scaling_mode: Scaling mode for the GEMM operations
out_dtype: Data type of the output tensors
has_bias: Boolean indicating if bias tensors are provided
is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation
where both lhs and rhs are 2D matrices and output is (G, M, N)
Returns:
A tuple of ShapedArray objects of size num_gemms+1:
ret[0 : num_gemms]: GEMM output tensors,
ret[num_gemms]:workspace tensor.
A jnp.ndarray containing the result of the grouped GEMM operation
"""
del scaling_mode
expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
assert (
len(args) == expected_num_args
), f"Expected {expected_num_args} input arguments, but got {len(args)}"
A_list = args[0:num_gemms]
B_list = args[num_gemms : 2 * num_gemms]
# A and B have shapes [1, m, k] and [1, n, k]
out_list_aval = tuple(
jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
for A, B in zip(A_list, B_list)
)
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
del lhs_scale_inv_aval, rhs_scale_inv_aval
# TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
# JAX buffer pointers are 128-aligned
# 255 is added to the workspace size to ensure workspace ptr is 256-aligned
workspace_size += 255
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return (*out_list_aval, workspace_aval)
# TODO(phuong): We should make separate tmp buffers for swizzled scales to avoid unaligned-by-256 workspace ptr issue
out_shape = (M, N)
if is_grouped_dense_wgrad:
out_shape = (group_sizes_aval.size, M, N)
out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
return (out_aval, workspace_aval)
@staticmethod
def outer_abstract(*args, **kwargs):
(out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
return out_aval
return (out_aval,)
@staticmethod
def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
def lowering(
ctx,
*args,
M,
N,
K,
lhs_is_trans,
rhs_is_trans,
scaling_mode,
out_dtype,
has_bias,
is_grouped_dense_wgrad,
):
del out_dtype
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
ctx,
*args,
num_gemms=num_gemms,
scaling_mode=int(scaling_mode),
M=M,
N=N,
K=K,
lhs_is_trans=lhs_is_trans,
rhs_is_trans=rhs_is_trans,
scaling_mode=scaling_mode.value,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
)
@staticmethod
def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
def impl(
lhs_data,
lhs_scale_inv,
rhs_data,
rhs_scale_inv,
bias,
group_sizes,
group_offset,
M,
N,
K,
lhs_is_trans,
rhs_is_trans,
scaling_mode,
out_dtype,
has_bias,
is_grouped_dense_wgrad,
):
assert GroupedGemmPrimitive.inner_primitive is not None
out = GroupedGemmPrimitive.inner_primitive.bind(
*args,
num_gemms=num_gemms,
scaling_mode=scaling_mode.value,
(out, _) = GroupedGemmPrimitive.inner_primitive.bind(
lhs_data,
lhs_scale_inv,
rhs_data,
rhs_scale_inv,
bias,
group_sizes,
group_offset,
M=M,
N=N,
K=K,
lhs_is_trans=lhs_is_trans,
rhs_is_trans=rhs_is_trans,
scaling_mode=scaling_mode,
out_dtype=out_dtype,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
)
return out[:-1] # out is [out_list, wkspace], only return out_list
return (out,)
register_primitive(GroupedGemmPrimitive)
......@@ -142,65 +224,31 @@ def _calculate_remaining_shape(shape, contracting_dims):
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)
def _dequantize(x, scale_inv, dq_dtype):
return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
def _transpose_contract_dims(ndim, contracting_dims):
return tuple(ndim - i - 1 for i in contracting_dims)[::-1]
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(
jax.jit,
static_argnums=(
2,
3,
4,
),
)
def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
)
return out_3d
def _jax_gemm_tensor_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""FP8 GEMM for XLA pattern match"""
assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"
@partial(jax.jit, static_argnums=(2, 3))
def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract)
if rhs.data_layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract)
lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract)
precision = (
jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
out_fp8 = jax.lax.dot_general(
lhs.data, rhs.data, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
)
out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
scale_inv = lhs.scale_inv * rhs.scale_inv
out = (out_fp8 * scale_inv).astype(lhs.dq_dtype)
# Reshape [B, M, N] -> [..., M, N]
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
return out
@partial(jax.jit, static_argnums=(2,))
def _jax_gemm_mxfp8_1d(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
......@@ -210,7 +258,6 @@ def _jax_gemm_mxfp8_1d(
assert (
rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode"
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
......@@ -241,7 +288,7 @@ def _jax_gemm_mxfp8_1d(
# * Expected shape:
# * lhs_data (B, M, K) * rhs_data (B, N, K)
# * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
out_3d = scaled_matmul_wrapper(
out_3d = jax.nn.scaled_matmul(
lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
)
# Reshape [1, reduce(..., M), N] -> [..., M, N]
......@@ -268,9 +315,16 @@ def _jax_gemm(
dim_nums = (contracting_dims, ((), ()))
def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode.is_tensor_scaling():
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
assert (
rhs.scaling_mode == lhs.scaling_mode
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
precision = (
jax.lax.Precision.HIGHEST
if QuantizeConfig.FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
......@@ -313,7 +367,7 @@ def gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""General matrix multiplication with optional quantization.
......@@ -338,130 +392,190 @@ def gemm(
return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)
"""
def swizzled_scale(scales):
# Swizzle the scale tensor for FP8 GEMM
assert scales.ndim == 2
rows, cols = scales.shape
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
scales = scales.reshape(rows, cols)
return scales
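# Shape sanity sketch (assumes jnp and swizzled_scale above are in scope): the swizzle
# reorders scale elements within each 128-row block, so rows must be a multiple of 128
# and cols a multiple of 4, while the 2D shape itself is preserved.
_scales = jnp.zeros((128, 8), jnp.float32)
assert swizzled_scale(_scales).shape == (128, 8)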
def grouped_gemm(
lhs: Union[jnp.ndarray, GroupedScaledTensor1x],
rhs: Union[jnp.ndarray, GroupedScaledTensor1x],
group_sizes: jnp.ndarray,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)),
bias: jnp.ndarray = None,
precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
preferred_element_type: jnp.dtype = None,
group_offset: jnp.array = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""
Grouped GEMM operation.
Args:
lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
group_sizes: 1D array containing the sizes of each group
contracting_dims: Tuple of two sequences representing the contracting dimensions
bias: Bias tensor of shape (G, N)
precision: JAX precision for the GEMM operation
preferred_element_type: Preferred data type for the output tensor
group_offset: 1D array containing offsets for each group (not yet implemented)
quantizer_set: Set of quantizers for FP8 quantization of the input and output
def grouped_gemm(
lhs_list: List[Union[jnp.ndarray, ScaledTensor]],
rhs_list: List[Union[jnp.ndarray, ScaledTensor]],
contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
bias_list: List[jnp.ndarray] = None,
) -> List[jnp.ndarray]:
# Grouped GEMM for multiple pairs of tensors.
assert (
len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length"
num_gemms = len(lhs_list)
lhs_list_ = []
rhs_list_ = []
lhs_sinv_list_ = []
rhs_sinv_list_ = []
bias_list_ = []
for i in range(num_gemms):
lhs = lhs_list[i]
rhs = rhs_list[i]
contracting_dims = contracting_dims_list[i]
dim_nums = (contracting_dims, ((), ()))
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
scaling_mode = lhs.scaling_mode
lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape
Returns:
A jnp.ndarray containing the result of the grouped GEMM operation
Note:
Tested shapes:
lhs: [M, K] or [K, N]
rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K]
"""
# TODO(Phuong): implement the group_offset
group_offset = group_offset or jnp.zeros((1,), jnp.int32)
# TODO(Phuong): implement the precision
del precision
if isinstance(lhs, jnp.ndarray):
assert isinstance(rhs, jnp.ndarray)
out_dtype = lhs.dtype
lhs_shape = lhs.shape
rhs_shape = rhs.shape
lhs_data = lhs
rhs_data = rhs
lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32)
scaling_mode = ScalingMode.NO_SCALING
elif isinstance(lhs, GroupedScaledTensor1x):
assert isinstance(rhs, GroupedScaledTensor1x)
out_dtype = lhs.dq_dtype
# For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode.is_tensor_scaling():
lhs_shape = lhs.original_shape
rhs_shape = rhs.original_shape
lhs_data = lhs.data
rhs_data = rhs.data
lhs_scale_inv = lhs.scale_inv
rhs_scale_inv = rhs.scale_inv
assert lhs.scaling_mode == rhs.scaling_mode
scaling_mode = lhs.scaling_mode
else:
raise TypeError("Unsupported lhs type object!")
out_dtype = preferred_element_type or out_dtype
lhs_contract_dim, rhs_contract_dim = contracting_dims
lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1
lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1)
# rhs_shape [G, K, N]
rhs_is_trans = rhs_contract_dim[0] != 1
rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim)
is_grouped_dense_wgrad = False
if len(rhs_shape) == 2:
rhs_is_trans = rhs_contract_dim[0] != 0
is_grouped_dense_wgrad = True
# TODO(Hua): these are for fp16 dense wgrad, any better way to handle this?
if (
is_grouped_dense_wgrad
and not isinstance(lhs, ScaledTensor)
and not isinstance(rhs, ScaledTensor)
):
lhs_is_trans = True
rhs_is_trans = False
lhs_flatten_axis = 1
rhs_flatten_axis = 1
if (
not isinstance(lhs, ScaledTensor)
and not isinstance(rhs, ScaledTensor)
and quantizer_set != noop_quantizer_set
):
assert isinstance(quantizer_set.x, GroupedQuantizer)
assert type(quantizer_set.x) is type(quantizer_set.kernel)
scaling_mode = quantizer_set.x.scaling_mode
if (
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
# scaling_mode.is_tensor_scaling()
# and is_gemm_with_all_layouts_supported()
scaling_mode.is_1d_block_scaling()
):
lhs_is_rowwise = rhs_is_rowwise = True
else:
lhs_is_rowwise = not lhs_is_trans
rhs_is_rowwise = lhs_is_trans
quantizer_set.x.q_layout = (
QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
)
quantizer_set.kernel.q_layout = (
QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE
)
lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis)
rhs_q = grouped_quantize(
rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis
)
lhs_data = lhs_q.data
rhs_data = rhs_q.data
lhs_scale_inv = lhs_q.scale_inv
rhs_scale_inv = rhs_q.scale_inv
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.data_layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.data_layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else:
# For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NO_SCALING
lhs_shape = lhs.shape
rhs_shape = rhs.shape
out_dtype = lhs.dtype
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
# Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
if scaling_mode == ScalingMode.NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode.is_tensor_scaling():
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
# swizzled_scale requires a matrix
lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
# Only support FP8 GEMM with NT layout on Hopper and earlier GPUs,
# thus an additional transpose is required
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
if scaling_mode.is_tensor_scaling(): # and not is_gemm_with_all_layouts_supported():
lhs_is_trans = False
rhs_is_trans = True
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
lhs_layout_is_T = lhs.data_layout == "T"
rhs_layout_is_T = rhs.data_layout == "T"
else:
raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}")
# Note: already_transposed doesn't matter for the output shape
# x.shape = [B, D1, D2]
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
# x.shape = [D1, D2]
# contracting_dims = (1, ) --> output.shape = [1, D1, D2]
# contracting_dims = (0, ) --> output.shape = [1, D2, D1]
bm = lhs_remain_shape[0]
bn = rhs_remain_shape[0]
kl = lhs_3d.shape[-1]
kr = rhs_3d.shape[-1]
assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}"
if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0):
print("grouped_gemm input pair {i} has invalid problem shape for lowering: ")
print(f"m = {bm}, n = {bn}, k = {kl}; ")
print("cuBLAS requires the problem shapes being multiples of 16")
assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0)
lhs_list_.append(lhs_3d)
rhs_list_.append(rhs_3d)
if scaling_mode == ScalingMode.NO_SCALING:
lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode.is_tensor_scaling():
lhs_sinv_list_.append(lhs.scale_inv)
rhs_sinv_list_.append(rhs.scale_inv)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_sinv_list_.append(lhs_scale_inv)
rhs_sinv_list_.append(rhs_scale_inv)
if bias_list is not None:
bias_list_.append(bias_list[i])
out_list = GroupedGemmPrimitive.outer_primitive.bind(
*lhs_list_,
*rhs_list_,
*lhs_sinv_list_,
*rhs_sinv_list_,
*bias_list_,
num_gemms=num_gemms,
scaling_mode=scaling_mode,
lhs_layout_is_T = lhs_q.data_layout == "T"
rhs_layout_is_T = rhs_q.data_layout == "T"
lhs_ndim = len(lhs_shape)
rhs_ndim = len(rhs_shape)
if lhs_layout_is_T:
lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
if rhs_layout_is_T:
rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)
lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T)
rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T)
# Calling GroupedGEMM Custom Call
K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim)
assert K_lhs == K_rhs
M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim))
N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G
if is_grouped_dense_wgrad:
N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim))
else:
assert group_sizes.size == rhs_shape[0]
assert group_offset.size == 1
has_bias = bias is not None
assert not has_bias or bias.shape == (group_sizes.size, N)
bias = jnp.empty((), jnp.float32) if bias is None else bias
# TODO(Phuong): support MXFP8_1D_SCALING
assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported"
(out,) = GroupedGemmPrimitive.outer_primitive.bind(
lhs_data,
lhs_scale_inv,
rhs_data,
rhs_scale_inv,
bias,
group_sizes,
group_offset,
M=M,
N=N,
K=K_lhs,
lhs_is_trans=lhs_is_trans,
rhs_is_trans=rhs_is_trans,
scaling_mode=scaling_mode.value,
out_dtype=out_dtype,
has_bias=1 if bias_list is not None else 0,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
)
return out_list
"""
return out
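A hedged usage sketch of the new array-based grouped_gemm API (hypothetical sizes,
default noop quantizer_set, no bias; not taken from this change):

import jax.numpy as jnp

G, M, K, N = 4, 256, 128, 64
lhs = jnp.zeros((M, K), jnp.bfloat16)            # rows are grouped by group_sizes
rhs = jnp.zeros((G, N, K), jnp.bfloat16)         # per-group weights, NT layout
group_sizes = jnp.full((G,), M // G, jnp.int32)  # sums to M

# Default contracting_dims=((1,), (2,)) contracts over K for both operands.
out = grouped_gemm(lhs, rhs, group_sizes)        # out.shape == (M, N)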
......@@ -183,6 +183,16 @@ def get_xla_flag(flag: str, default=None, cast=str):
return default
def get_min_device_compute_capability():
"""
Returns the minimum compute capability of all local devices.
"""
return min(
transformer_engine_jax.get_device_compute_capability(local_gpu_id)
for local_gpu_id in range(len(jax.local_devices()))
)
def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
"""
Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
......
......@@ -5,6 +5,7 @@
import operator
from functools import reduce
from typing import Tuple, Optional
import math
from packaging import version
import jax
......@@ -23,12 +24,17 @@ from .misc import (
jax_dtype_to_te_dtype,
multidim_transpose,
should_apply_1x_fused_dbias_war_for_arch_l_100,
get_min_device_compute_capability,
NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import (
ScaledTensor2x,
ScaledTensor,
ScaledTensorFactory,
GroupedScaledTensor1x,
Quantizer,
GroupedQuantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
......@@ -41,10 +47,10 @@ else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["quantize", "quantize_dbias"]
__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
class DBiasQuantizePrimitive(BasePrimitive):
class BaseDBiasQuantizePrimitive(BasePrimitive):
"""
Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
"""
......@@ -155,7 +161,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
updated_amax,
dbias,
_,
) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
@staticmethod
......@@ -179,7 +185,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(DBiasQuantizePrimitive.name)(
return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)(
ctx,
x,
scale,
......@@ -205,7 +211,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
te_dbias_quantize_p implementation
"""
del is_outer
assert DBiasQuantizePrimitive.inner_primitive is not None
assert BaseDBiasQuantizePrimitive.inner_primitive is not None
(
out,
colwise_out,
......@@ -214,7 +220,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
updated_amax,
dbias,
_,
) = DBiasQuantizePrimitive.inner_primitive.bind(
) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
x,
scale,
out_dtype=out_dtype,
......@@ -262,14 +268,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
"""
del is_outer
check_valid_batch_dims(batch_dims)
assert DBiasQuantizePrimitive.outer_primitive is not None
assert BaseDBiasQuantizePrimitive.outer_primitive is not None
x, scale = batched_args
x_bdim, scale_bdim = batch_dims
amax_bdim = scale_bdim
out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
return (
DBiasQuantizePrimitive.outer_primitive.bind(
BaseDBiasQuantizePrimitive.outer_primitive.bind(
x,
scale,
out_dtype=out_dtype,
......@@ -302,7 +308,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding",
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling():
......@@ -314,14 +320,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
)
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding",
desc="BaseDBiasQuantizePrimitive.dbias_sharding",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
......@@ -334,15 +340,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
)
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv",
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
)
return (
......@@ -374,7 +380,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding",
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling():
......@@ -386,14 +392,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
)
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding",
desc="BaseDBiasQuantizePrimitive.dbias_sharding",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
......@@ -406,15 +412,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
)
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv",
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
......@@ -435,7 +441,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
local_colwise_scale_inv,
local_amax,
local_dbias,
) = DBiasQuantizePrimitive.impl(
) = BaseDBiasQuantizePrimitive.impl(
x,
scale,
out_dtype=out_dtype,
......@@ -485,7 +491,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape),
unique_var="DBiasQuantizePrimitive_i",
unique_var="BaseDBiasQuantizePrimitive_i",
flatten_axis=flatten_axis,
)
......@@ -512,7 +518,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
)
register_primitive(DBiasQuantizePrimitive)
register_primitive(BaseDBiasQuantizePrimitive)
class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
class QuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
def _jax_quantize(
......@@ -565,7 +579,8 @@ def _quantize_dbias_impl(
dq_dtype = dq_dtype or x.dtype
if not DBiasQuantizePrimitive.enabled():
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
if not PrimitiveClass.enabled():
if is_dbias:
return _jax_quantize_dbias(
x,
......@@ -620,6 +635,18 @@ def _quantize_dbias_impl(
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
# It is faster to use 1x quantization for tensor scaling
force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling()
and quantizer.is_2x2x()
and is_1x_kernel_supported
)
q_layout = quantizer.q_layout
if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE
(
rowwise_casted_output,
colwise_casted_output,
......@@ -627,12 +654,12 @@ def _quantize_dbias_impl(
colwise_scale_inv,
updated_amax,
dbias,
) = DBiasQuantizePrimitive.outer_primitive.bind(
) = PrimitiveClass.outer_primitive.bind(
x,
scale,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_layout=quantizer.q_layout.value,
q_layout=q_layout.value,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias,
......@@ -642,6 +669,15 @@ def _quantize_dbias_impl(
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
if q_layout == QuantizeLayout.ROWWISE:
# Quantizer requires 2x quantization, but we are using 1x quantization
# for performance reasons, so we need to generate the colwise data in JAX
if flatten_axis < 0:
flatten_axis += x.ndim
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
)
quantizer.update(updated_amax)
out = ScaledTensorFactory.create(
......@@ -709,3 +745,313 @@ def quantize_dbias(
return _quantize_dbias_impl(
dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
)
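The force_1x_quantization fast path above skips the second (colwise) quantization kernel and rebuilds the colwise view in JAX by transposing the rowwise output around flatten_axis. A minimal sketch of that transpose, using a plain float array as a stand-in for the quantized data:

import jax.numpy as jnp

rowwise = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)  # stand-in for rowwise_casted_output
flatten_axis = -1 % rowwise.ndim  # normalize -1 -> 2 for a 3-D tensor
colwise = jnp.transpose(rowwise, (*range(flatten_axis, rowwise.ndim), *range(flatten_axis)))
print(colwise.shape)  # (4, 2, 3): trailing dims moved to the front, i.e. the colwise layout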
class GroupedQuantizePrimitive(BasePrimitive):
"""
    Grouped Quantize Primitive wrapping the grouped quantize FFI (te_grouped_quantize_ffi)
"""
name = "te_grouped_quantize_ffi"
multiple_results = True
impl_static_args = (
3,
4,
5,
6,
7,
8,
    ) # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
scale_aval,
group_sizes_aval,
*,
out_dtype,
scaling_mode,
q_layout,
flatten_axis,
group_axis,
scale_dtype,
):
"""
        te_grouped_quantize_p abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = math.prod(x_aval.shape)
# TODO(Phuong): can scale_aval be None?
assert scale_aval is None or scale_aval.dtype == jnp.float32
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_grouped_scale_shape_2x(
x_aval.shape,
group_sizes_aval.size,
group_axis,
is_padded=True,
flatten_axis=flatten_axis,
)
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
rowwise_out_shape = out_shape
else:
rowwise_out_shape = (1,)
rowwise_scale_inv_shape = (1,)
rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_out_shape = out_shape
else:
colwise_out_shape = (1,)
colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
rowwise_scale_inv_aval = jax.core.ShapedArray(
shape=rowwise_scale_inv_shape, dtype=scale_dtype
)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
return (
rowwise_out_aval,
colwise_out_aval,
rowwise_scale_inv_aval,
colwise_scale_inv_aval,
amax_aval,
)
@staticmethod
def outer_abstract(*args, **kwargs):
"""
        te_grouped_quantize_p outer primitive abstract
"""
        # Phuong: keeping the outer abstract so that we can add fused dbias later
(
rowwise_out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax
@staticmethod
def lowering(
ctx,
x,
scale,
group_sizes,
*,
out_dtype,
scaling_mode,
q_layout,
flatten_axis,
group_axis,
scale_dtype,
):
"""
        te_grouped_quantize_p lowering rules
"""
del out_dtype, scale_dtype
x_aval, scale_aval, group_sizes_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32
assert group_sizes_aval.dtype == jnp.int32
assert group_axis == 0
return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
ctx,
x,
scale,
group_sizes,
scaling_mode=scaling_mode.value,
q_layout=q_layout,
flatten_axis=flatten_axis,
)
@staticmethod
def impl(
x,
scale,
group_sizes,
out_dtype,
scaling_mode,
q_layout,
flatten_axis,
group_axis,
scale_dtype,
):
"""
        te_grouped_quantize_p implementation
"""
assert GroupedQuantizePrimitive.inner_primitive is not None
(
rowwise_out,
colwise_out,
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
) = GroupedQuantizePrimitive.inner_primitive.bind(
x,
scale,
group_sizes,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
flatten_axis=flatten_axis,
group_axis=group_axis,
scale_dtype=scale_dtype,
)
return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)
register_primitive(GroupedQuantizePrimitive)
def grouped_quantize(
x: jnp.ndarray,
quantizer: GroupedQuantizer,
group_sizes: jnp.ndarray = None,
flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
"""Quantize a tensor in grouped manner.
This function quantizes a tensor by splitting it into groups along a specified axis
and applying quantization to each group separately. The groups can be either specified
explicitly through group_sizes or automatically split along the group_axis.
Args:
x: Input tensor to quantize
quantizer: The quantizer to use for quantization
group_sizes: Array of ints containing the size of each group (default: None)
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
Returns:
A GroupedScaledTensor1x containing the quantized data
Note:
        - If group_sizes is not provided, each row along the group_axis becomes its own
          group (group_sizes defaults to all ones)
- The group_axis is currently fixed to 0
- The quantizer's q_layout determines whether row-wise, column-wise, or both
quantization is applied
"""
if quantizer is None:
return x
# TODO(Phuong): add support for flatten_axis = -2
assert flatten_axis in (
-1,
x.ndim - 1,
), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
group_axis = 0
if group_sizes is None:
group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)
if not GroupedQuantizePrimitive.enabled():
return quantizer.quantize(
x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
)
n_groups = group_sizes.size
original_shape = x.shape
assert n_groups == len(
quantizer.quantizers
), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
scale = jnp.empty((n_groups,), jnp.float32)
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
for i, quantizer_i in enumerate(quantizer.quantizers):
scale = scale.at[i].set(quantizer_i.scale[0])
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
segment_ids = jnp.repeat(
jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
)
grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
for i in range(n_groups):
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
scale = scale.at[i].set(tmp_scale[0])
is_tensor_scaling = quantizer.scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
)
# WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
    # So we perform ROWWISE_COLWISE and use the colwise_tensor_output
apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
(
rowwise_casted_output,
colwise_casted_output,
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
) = GroupedQuantizePrimitive.outer_primitive.bind(
x,
scale,
group_sizes,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value,
flatten_axis=flatten_axis,
group_axis=group_axis,
scale_dtype=quantizer.get_scale_dtype(),
)
# For DelayedScaling2x and CurrentScaling2x, the scale buffer
# is shared between rowwise and colwise
    if (is_tensor_scaling and quantizer.is_2x2x()) or apply_colwise_war:
colwise_scale_inv = rowwise_scale_inv
# TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
for i, quantizer_i in enumerate(quantizer.quantizers):
quantizer_i.update(updated_amax[i].reshape((1,)))
out = ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
)
return out
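A minimal usage sketch for grouped_quantize. The GroupedQuantizer construction is assumed here (it would normally come from the package's quantizer factory) and is only indicated in comments:

import jax.numpy as jnp

x = jnp.ones((12, 64), dtype=jnp.bfloat16)        # rows grouped along axis 0
group_sizes = jnp.array([4, 8], dtype=jnp.int32)  # sum(group_sizes) == x.shape[0]

# grouped_quantizer = ...  # assumed: a GroupedQuantizer holding 2 per-group quantizers
# grouped_tensor = grouped_quantize(x, grouped_quantizer, group_sizes=group_sizes)
# grouped_tensor carries per-group data and scale_inv plus group_sizes / original_shape,
# which the grouped GEMM path consumes.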
def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
"""
Compute the grouped bias gradient.
Args:
grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape (num_groups,), with sum(group_sizes) == M
Returns:
dbias: jnp.ndarray of shape (num_groups, N)
"""
assert grad.ndim == 2, "Input grad must be a 2D tensor."
assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."
segment_ids = jnp.repeat(
jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
)
grad_fp32 = grad.astype(jnp.float32)
dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
dbias = dbias_fp32.astype(grad.dtype)
return dbias
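A small worked example of the segment-sum reduction grouped_dbias performs (a sketch, not a test from the repository):

import jax
import jax.numpy as jnp

grad = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=jnp.float32)  # M=3, N=2
group_sizes = jnp.array([2, 1], dtype=jnp.int32)                           # rows 0-1, then row 2

segment_ids = jnp.repeat(
    jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
)  # -> [0, 0, 1]
dbias = jax.ops.segment_sum(grad, segment_ids, num_segments=group_sizes.size)
print(dbias)  # [[4. 6.] [5. 6.]] -- the same reduction grouped_dbias applies (in fp32, then cast back)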
......@@ -809,13 +809,7 @@ def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fact
"""
JAX based implementation of scaled and masked softmax
"""
if mask is not None:
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
return jax.nn.softmax(logits * scale_factor, where=mask != 1)
def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
......@@ -823,12 +817,7 @@ def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: fl
JAX based implementation of scaled and upper triangle masked softmax
"""
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
return jax_scaled_masked_softmax(logits, mask, scale_factor)
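The rewritten softmax helpers rely on jax.nn.softmax's where= argument instead of adding a large negative bias. For 0/1 masks (1 = masked out) the two formulations agree, as this small sketch illustrates:

import jax
import jax.numpy as jnp

logits = jnp.array([[0.5, 1.0, -0.3, 2.0]])
mask = jnp.array([[0, 0, 1, 0]])  # 1 marks positions to drop
scale_factor = 0.5

additive = jax.nn.softmax(logits * scale_factor + jnp.where(mask > 0, -1e10, 0.0))
where_based = jax.nn.softmax(logits * scale_factor, where=mask != 1)
print(jnp.allclose(additive, where_based, atol=1e-6))  # True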
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
......
......@@ -30,6 +30,7 @@
#include "extensions/misc.h"
#include "extensions/utils.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/multi_stream.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
......@@ -68,6 +69,8 @@ pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_s
// Quantization
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
......@@ -93,25 +96,25 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, int64_t window_size_left,
int64_t window_size_right);
size_t qk_head_dim, size_t v_head_dim,
int64_t window_size_left, int64_t window_size_right);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
......
......@@ -7,7 +7,7 @@
#include <cuda_runtime.h>
#include "extensions.h"
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
......
......@@ -4,23 +4,24 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace jax {
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_attn_heads, size_t kv_attn_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, int64_t window_size_left,
int64_t window_size_right) {
size_t qk_head_dim, size_t v_head_dim,
int64_t window_size_left, int64_t window_size_right) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
return backend;
}
......@@ -40,33 +41,46 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
// all backends need softmax but expect different shapes/dtypes
// start with the max512 sequence length softmax shape/dtype and correct later
tensor_pack->size = 1;
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.dptr = softmax_buf;
softmax_aux->data.shape =
std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
softmax_aux->data.dtype = dtype;
NVTETensor &softmax_aux = tensor_pack->tensors[0];
NVTEBasicTensor softmax_aux_data;
softmax_aux_data.data_ptr = softmax_buf;
softmax_aux_data.shape.ndim = 4;
softmax_aux_data.shape.data[0] = input_batch;
softmax_aux_data.shape.data[1] = attn_heads;
softmax_aux_data.shape.data[2] = q_max_seqlen;
softmax_aux_data.shape.data[3] = kv_max_seqlen;
softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
tensor_pack->size = 2;
Tensor *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]);
rng_state_aux->data.dptr = rng_state_buf;
rng_state_aux->data.shape = std::vector<size_t>{2};
rng_state_aux->data.dtype = DType::kInt64;
NVTETensor &rng_state_aux = tensor_pack->tensors[1];
NVTEBasicTensor rng_state_aux_data;
rng_state_aux_data.data_ptr = rng_state_buf;
rng_state_aux_data.shape = {};
rng_state_aux_data.shape.ndim = 2;
rng_state_aux_data.dtype = static_cast<NVTEDType>(DType::kInt64);
nvte_set_tensor_param(&rng_state_aux, kNVTERowwiseData, &rng_state_aux_data);
// correct softmax shape/dtype
softmax_aux->data.shape.at(3) = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux->data.dtype = DType::kFloat32;
softmax_aux_data.shape.data[3] = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);
// include bias if enabled
if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
tensor_pack->size = 3;
Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
bias_aux->data.dptr = bias_buf;
bias_aux->data.shape =
std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
bias_aux->data.dtype = dtype;
NVTETensor &bias_aux = tensor_pack->tensors[2];
NVTEBasicTensor bias_aux_data;
bias_aux_data.data_ptr = bias_buf;
bias_aux_data.shape.ndim = 4;
bias_aux_data.shape.data[0] = bias_batch;
bias_aux_data.shape.data[1] = bias_heads;
bias_aux_data.shape.data[2] = q_max_seqlen;
bias_aux_data.shape.data[3] = kv_max_seqlen;
bias_aux_data.dtype = static_cast<NVTEDType>(dtype);
nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data);
}
}
nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data);
}
/*
......@@ -93,32 +107,34 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
// correct softmax shape for max512 sequence length kernel
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.shape.at(3) = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux->data.dtype = dtype;
NVTEBasicTensor softmax_aux_data =
nvte_get_tensor_param(tensor_pack->tensors[0], kNVTERowwiseData);
softmax_aux_data.shape.data[3] = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
nvte_set_tensor_param(&(tensor_pack->tensors[0]), kNVTERowwiseData, &softmax_aux_data);
}
}
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
......@@ -183,6 +199,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
}
}
nvte_tensor_pack_destroy(&aux_output_tensors);
auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}
......@@ -219,17 +237,17 @@ static void FusedAttnForwardImpl(
void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux,
void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads,
size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right) {
size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
if (is_ragged) {
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
auto output_size = input_batch * q_max_seqlen * attn_heads * v_head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
// Memset to 0xF0 for filling large negative numbers
......@@ -239,15 +257,15 @@ static void FusedAttnForwardImpl(
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
auto o_tensor = TensorWrapper(output, o_shape, dtype);
/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -260,7 +278,7 @@ static void FusedAttnForwardImpl(
/* Call the underlying NVTE API */
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
......@@ -269,8 +287,9 @@ static void FusedAttnForwardImpl(
qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
......@@ -281,9 +300,9 @@ static void FusedAttnForwardImpl(
is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
......@@ -309,7 +328,8 @@ static void FusedAttnForwardImpl(
size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads"); \
size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups"); \
size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads"); \
size_t head_dim = get_attr_value<int64_t>(attrs, "head_dim"); \
size_t qk_head_dim = get_attr_value<int64_t>(attrs, "qk_head_dim"); \
size_t v_head_dim = get_attr_value<int64_t>(attrs, "v_head_dim"); \
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
......@@ -344,9 +364,9 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(),
softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(),
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
window_size_right);
qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, is_training,
deterministic, window_size_left, window_size_right);
return ffi_with_cuda_error_check();
}
......@@ -373,33 +393,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
......@@ -469,6 +489,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
}
}
nvte_tensor_pack_destroy(&aux_input_tensors);
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
......@@ -478,15 +500,15 @@ static void FusedAttnBackwardImpl(
void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets,
void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
......@@ -498,16 +520,16 @@ static void FusedAttnBackwardImpl(
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);
  /* Call the underlying NVTE API */
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype);
if (is_ragged) {
......@@ -523,8 +545,9 @@ static void FusedAttnBackwardImpl(
bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
......@@ -544,9 +567,9 @@ static void FusedAttnBackwardImpl(
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
......@@ -594,9 +617,9 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(),
dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(),
workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype,
is_training, deterministic, window_size_left, window_size_right);
attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq,
wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype,
wkspace_dtype, is_training, deterministic, window_size_left, window_size_right);
return ffi_with_cuda_error_check();
}
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
#include "transformer_engine/gemm.h"
#include "xla/ffi/api/c_api.h"
......
......@@ -6,7 +6,7 @@
#include "transformer_engine/cudnn.h"
#include "extensions.h"
#include "../extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......
......@@ -7,48 +7,130 @@
#include <memory>
#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "xla/ffi/api/c_api.h"
#define MXFP8_BLOCK_SIZE 32
namespace transformer_engine {
namespace jax {
Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
Variadic_Result_Type output_list, int64_t num_gemms,
JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
bool is_grouped_dense_wgrad) {
// Notes on matrix layouts and transpose:
// Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose,
// A: row-major [m, k] for N - [k, m] for T
// B: row-major [k, n] for N - [n, k] for T
  // on exiting this function, JAX expects:
// C: row-major with size [m, n].
// cuBLAS uses column-major data_layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n].
// A: column-major with size [k, m] for T - [m, k] for N
// B: column-major with size [n, k] for T - [k, n] for N
//
// If we call cuBLAS GEMM for A * B, the output will be:
// C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
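  // Worked example for the non-transposed (NN) case: with A row-major [m, k] and
  // B row-major [k, n], the same buffers read as column-major are A^T [k, m] and
  // B^T [n, k]. Computing B^T * A^T in cuBLAS gives C^T [n, m] column-major, whose
  // buffer is exactly C [m, n] row-major as JAX expects -- hence lhs and rhs are
  // swapped in the nvte_multi_stream_cublas_gemm call below.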
if (num_gemms <= 0) {
return ffi_with_cuda_error_check();
int num_streams = nvte_get_num_compute_streams();
// Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
NVTE_CHECK(group_sizes.dimensions().size() == 1);
size_t num_gemms = group_sizes.dimensions()[0];
// Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
  // Round the workspace pointer up to the next 256-byte boundary (add 255, then clear the low 8 bits)
auto workspace_ptr =
reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(workspace->untyped_data()) + 255) &
~static_cast<uintptr_t>(255));
auto workspace_total_size = product(workspace->dimensions()) - 255;
auto workspace_size = workspace_total_size / num_streams;
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
size_t expected_lhs_size = m * k;
size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
size_t actual_lhs_size = product(lhs_data.dimensions());
size_t actual_rhs_size = product(rhs_data.dimensions());
size_t actual_out_size = product(output->dimensions());
NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
expected_lhs_size, ", got ", actual_lhs_size);
if (!is_grouped_dense_wgrad) {
NVTE_CHECK(expected_rhs_size == actual_rhs_size,
"Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
" = ", expected_rhs_size, ", got ", actual_rhs_size);
NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
" * ", n, " = ", expected_out_size, ", got ", actual_out_size);
} else {
NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
" * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
NVTE_CHECK(expected_out_size == actual_out_size,
"Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
" = ", expected_out_size, ", got ", actual_out_size);
}
size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
size_t expected_output_size = num_gemms + 1;
size_t actual_input_size = input_list.size();
size_t actual_output_size = output_list.size();
NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu",
expected_input_size, actual_input_size);
NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu",
expected_output_size, actual_output_size);
bool trans_lhs = true;
bool trans_rhs = false;
size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
std::vector<int32_t> dim_list_host(num_gemms);
auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
if (!is_grouped_dense_wgrad) {
NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
", got sum(group_sizes)=", sum_group_sizes);
} else {
NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
", got sum(group_sizes)=", sum_group_sizes);
}
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
bool grad = false;
bool accumulate = false;
bool use_split_accumulator = false;
auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
const int arch = cuda::sm_arch();
  // It is weird that TE/Common GEMM only uses colwise data for MXFP8
const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
if (arch < 100 && is_fp8_gemm) {
NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
"For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
"got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
}
// These lists are to keep the TensorWrapper objects alive
std::vector<TensorWrapper> lhs_wrapper_list;
......@@ -66,96 +148,83 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
std::vector<NVTETensor> out_list;
std::vector<NVTETensor> workspace_list;
int lhs_list_offset = 0;
int rhs_list_offset = num_gemms;
int lhs_sinv_list_offset = 2 * num_gemms;
int rhs_sinv_list_offset = 3 * num_gemms;
int bias_list_offset = 4 * num_gemms;
int out_list_offset = 0;
for (int i = 0; i < num_gemms; i++) {
Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value();
Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value();
Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value();
Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value();
Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value();
DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
void *lhs_ptr = lhs_i.untyped_data();
void *rhs_ptr = rhs_i.untyped_data();
void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
void *out_ptr = out_i->untyped_data();
// Placeholder for bias since it can be empty
DType bias_dtype = DType::kFloat32;
void *bias_ptr = nullptr;
auto lhs_shape_ = lhs_i.dimensions();
auto rhs_shape_ = rhs_i.dimensions();
// lhs and rhs has shape [1, m, k] and [1, n, k]
size_t m = lhs_shape_[1];
size_t n = rhs_shape_[1];
size_t k = lhs_shape_[2];
auto lhs_shape = std::vector<size_t>{m, k};
auto rhs_shape = std::vector<size_t>{n, k};
auto out_shape = std::vector<size_t>{n, m};
auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
float *amax_dptr = nullptr;
float *scale_dptr = nullptr;
auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(lhs_sinv_ptr));
auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(rhs_sinv_ptr));
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Note: the scale_inv array should have been swizzled in Python before lowering
auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
for (int i = 0; i < 2; i++) {
lhs_sinv_shape[i] = lhs_sinv_shape_[i];
rhs_sinv_shape[i] = rhs_sinv_shape_[i];
for (size_t i = 0; i < num_gemms; i++) {
// Matrix data shapes
size_t m_i = dim_list_host[i];
auto lhs_shape = std::vector<size_t>{m_i, k};
auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
auto out_shape = std::vector<size_t>{m_i, n};
if (is_grouped_dense_wgrad) {
size_t k_i = dim_list_host[i];
lhs_shape[0] = lhs_is_trans ? k_i : m;
lhs_shape[1] = lhs_is_trans ? m : k_i;
rhs_shape[0] = rhs_is_trans ? n : k_i;
rhs_shape[1] = rhs_is_trans ? k_i : n;
out_shape[0] = m;
out_shape[1] = n;
}
NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
TensorWrapper lhs_i_(nvte_scaling_mode);
TensorWrapper rhs_i_(nvte_scaling_mode);
lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
// Set matrix data pointers
auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
void *lhs_vptr = static_cast<void *>(lhs_ptr);
void *rhs_vptr = static_cast<void *>(rhs_ptr);
if (rhs_use_colwise) // MatA to enter cuBLAS
rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape);
else
rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape);
if (lhs_use_colwise) // MatB to enter cuBLAS
lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape);
else
lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape);
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else {
NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
// Scale_inv shapes
auto lhs_sinv_size = std::vector<size_t>{1};
auto rhs_sinv_size = std::vector<size_t>{1};
if (is_mxfp8_scaling) {
      NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires K to be divisible by ",
                 MXFP8_BLOCK_SIZE, " (got ", k, ")");
size_t scale_k = k / MXFP8_BLOCK_SIZE;
lhs_sinv_size[0] = m_i * scale_k;
rhs_sinv_size[0] = n * scale_k;
// Need to add swizzle here
}
auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);
void *pre_gelu_ptr = nullptr;
auto bias_shape = std::vector<size_t>{0};
auto pre_gelu_shape = std::vector<size_t>{0};
if (has_bias) {
auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i);
Buffer_Type bias_i = bias_i_get.value();
bias_ptr = bias_i.untyped_data();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
bias_shape[0] = n;
// Set scale_inv pointers
void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
if (is_fp8_gemm) {
if (rhs_use_colwise) // MatA to enter cuBLAS
rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
else
rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
if (lhs_use_colwise) // MatB to enter cuBLAS
lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
else
lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
} else {
NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode));
}
auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);
auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);
// Update pointer for the next GEMM pair
lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes;
rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes;
out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes;
if (is_fp8_gemm) {
lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes;
}
if (has_bias) bias_ptr += n * bias_dtype_bytes;
out_wrapper_list.push_back(std::move(out_i_));
// Move objects to the lists to keep them alive
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
out_wrapper_list.push_back(std::move(out_i));
bias_wrapper_list.push_back(std::move(bias_i));
pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
......@@ -166,10 +235,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
out_list.push_back(out_wrapper_list.back().data());
}
auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
Result_Type workspace = workspace_get.value();
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
size_t workspace_size = workspace->dimensions()[0] / num_streams;
auto workspace_shape = std::vector<size_t>{workspace_size};
for (int i = 0; i < num_streams; i++) {
auto workspace_i =
......@@ -180,7 +245,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
}
nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad,
pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad,
workspace_list.data(), accumulate, use_split_accumulator,
num_math_sm, stream);
......@@ -190,11 +255,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.RemainingArgs() // input list
.RemainingRets() // output list
.Attr<int64_t>("num_gemms")
.Arg<Buffer_Type>() // lhs_data
.Arg<Buffer_Type>() // lhs_sinv
.Arg<Buffer_Type>() // rhs_data
.Arg<Buffer_Type>() // rhs_sinv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // group_sizes
.Arg<Buffer_Type>() // group_offset
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // workspace
.Attr<int64_t>("M")
.Attr<int64_t>("N")
.Attr<int64_t>("K")
.Attr<bool>("lhs_is_trans")
.Attr<bool>("rhs_is_trans")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("has_bias"),
.Attr<bool>("has_bias")
.Attr<bool>("is_grouped_dense_wgrad"),
FFI_CudaGraph_Traits);
} // namespace jax
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
namespace transformer_engine {
namespace jax {
......@@ -26,5 +26,19 @@ std::vector<size_t> Shape::to_vector() const {
return shape;
}
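// Computes the padded MXFP8 scale-inverse shape. Worked example for a rowwise
// (is_colwise = false) 256 x 1024 tensor with block sizes {1, 32} and alignments {128, 4}:
// scale_x = DIVUP(256 / 1, 128) * 128 = 256 and scale_y = DIVUP(1024 / 32, 4) * 4 = 32,
// i.e. a padded {256, 32} scale tensor.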
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise) {
auto block_x = is_colwise ? MXFP8_BLOCK_SIZE.y : MXFP8_BLOCK_SIZE.x;
auto block_y = is_colwise ? MXFP8_BLOCK_SIZE.x : MXFP8_BLOCK_SIZE.y;
auto alignment_x = is_colwise ? MXFP8_ALIGNMENT.y : MXFP8_ALIGNMENT.x;
auto alignment_y = is_colwise ? MXFP8_ALIGNMENT.x : MXFP8_ALIGNMENT.y;
  NVTE_CHECK(M % block_x == 0, "M must be divisible by ", block_x, " (got ", M, ")");
  NVTE_CHECK(N % block_y == 0, "N must be divisible by ", block_y, " (got ", N, ")");
size_t scale_x = DIVUP((M / block_x), alignment_x) * alignment_x;
size_t scale_y = DIVUP((N / block_y), alignment_y) * alignment_y;
return {scale_x, scale_y};
}
} // namespace jax
} // namespace transformer_engine
......@@ -67,5 +67,16 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
}
}
constexpr struct BlockSize {
size_t x;
size_t y;
} MXFP8_BLOCK_SIZE{1, 32};
constexpr struct Alignment {
size_t x;
size_t y;
} MXFP8_ALIGNMENT{128, 4};
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
} // namespace jax
} // namespace transformer_engine
......@@ -7,7 +7,7 @@
#include <cuda_runtime.h>
#include "extensions.h"
#include "../extensions.h"
namespace transformer_engine {
namespace jax {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "../extensions.h"
namespace transformer_engine {
namespace jax {
......@@ -25,6 +25,7 @@ pybind11::dict Registrations() {
// Quantization
dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler);
dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
// Softmax
......@@ -68,6 +69,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_cudnn_version", &GetCudnnRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_num_compute_streams", &nvte_get_num_compute_streams);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes);
m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes);
......
......@@ -5,9 +5,10 @@
************************************************************************/
#include <cuda_runtime.h>
#include "extensions.h"
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......@@ -226,5 +227,182 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
.Ret<Buffer_Type>(), // output
FFI_CudaGraph_Traits);
Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scales,
Buffer_Type group_sizes, Result_Type outputs,
Result_Type colwise_outputs, Result_Type scale_invs,
Result_Type colwise_scale_invs, Result_Type amaxs,
JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
int64_t flatten_axis) {
NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode));
auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(outputs->element_type());
NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");
auto scale_dtype = convert_ffi_datatype_to_te_dtype(scales.element_type());
auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type());
auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type());
auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type());
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *input_ptr = reinterpret_cast<uint8_t *>(inputs.untyped_data());
auto *scale_ptr = reinterpret_cast<uint8_t *>(scales.untyped_data());
auto *output_ptr = reinterpret_cast<uint8_t *>(outputs->untyped_data());
auto *colwise_output_ptr = reinterpret_cast<uint8_t *>(colwise_outputs->untyped_data());
auto *sinv_ptr = reinterpret_cast<uint8_t *>(scale_invs->untyped_data());
auto *colwise_sinv_ptr = reinterpret_cast<uint8_t *>(colwise_scale_invs->untyped_data());
auto *amax_ptr = reinterpret_cast<uint8_t *>(amaxs->untyped_data());
bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool has_colwise = quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
size_t input_dtype_bytes = te_dtype_bytes(in_dtype);
size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype);
size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype);
size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0;
size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0;
size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0;
size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0;
auto input_dims = inputs.dimensions();
int64_t input_ndim = input_dims.size();
if (flatten_axis < 0) flatten_axis += input_ndim;
NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");
auto m = product(input_dims, 0, flatten_axis);
auto n = product(input_dims, flatten_axis, input_ndim);
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m * n};
// These lists are to keep the TensorWrapper objects alive
std::vector<TensorWrapper> input_holders;
std::vector<TensorWrapper> output_holders;
  // These lists are the actual NVTETensor (void *) handles passed to nvte_multi_tensor_quantize
std::vector<NVTETensor> input_list;
std::vector<NVTETensor> output_list;
size_t num_groups = group_sizes.dimensions()[0];
size_t dim_list_bytes = group_size_dtype_bytes * num_groups;
std::vector<int32_t> dim_list_host(num_groups);
auto *group_size_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
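  // The per-group sizes are needed on the host below, both to size each group's TensorWrapper
  // and to advance the raw input/output/scale pointers between groups.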
cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
  // Note: this host-side synchronization may break CUDA graph capture.
cudaStreamSynchronize(stream);
  size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), size_t{0});
  NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes,
             "Unexpected group_sizes! Got ", sum_group_sizes, " (M = ", m,
             ", input_dims[0] = ", input_dims[0], ")");
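  // Under delayed tensor scaling each group owns a float32 amax slot; zero them so the
  // quantize kernels accumulate a fresh per-group amax.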
if (is_delayed_scaling) {
NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups,
", got ", amaxs->dimensions()[0]);
NVTE_CHECK(amax_dtype == DType::kFloat32 && scale_dtype == DType::kFloat32);
cudaMemsetAsync(amax_ptr, 0, sizeof(float) * num_groups, stream);
}
size_t sinv_size = 0;
size_t colwise_sinv_size = 0;
size_t non_group_m = flatten_axis > 1 ? product(input_dims, 1, flatten_axis) : 1;
size_t num_non_empty_groups = 0;
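  // Walk the flat buffers group by group: each non-empty group's slice is wrapped in a
  // TensorWrapper, and every raw pointer is then advanced by that group's byte footprint.
  // Hypothetical illustration (values not from this commit): with group_sizes = {2, 0, 3},
  // flatten_axis = 1 and an input of shape (5, 64), two wrappers of shapes (2, 64) and
  // (3, 64) are built; the empty middle group is skipped, though under tensor scaling its
  // per-group scale slot is still consumed so host-side scale indexing stays aligned.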
for (size_t i = 0; i < num_groups; i++) {
size_t m_i = dim_list_host[i] * non_group_m;
    // Skip zero-size inputs, but still advance the per-group scale pointer (tensor scaling only)
if (m_i == 0) {
if (is_tensor_scaling) scale_ptr += scale_dtype_bytes;
continue;
}
num_non_empty_groups++;
auto shape_i = std::vector<size_t>{m_i, n};
auto shape_trans_i = std::vector<size_t>{n, m_i};
auto inp_i = TensorWrapper(static_cast<void *>(input_ptr), shape_i, in_dtype);
auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
if (has_rowwise) {
out_i.set_rowwise_data(static_cast<void *>(output_ptr), out_dtype, shape_i);
if (is_fp8_dtype(out_dtype)) {
if (is_tensor_scaling) {
out_i.set_scale(static_cast<void *>(scale_ptr), DType::kFloat32, std::vector<size_t>{1});
out_i.set_amax(static_cast<void *>(amax_ptr), DType::kFloat32, std::vector<size_t>{1});
out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype,
std::vector<size_t>{1});
sinv_size = 1;
} else {
const bool is_colwise = false;
auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype, sinv_shape_i);
sinv_size = product(sinv_shape_i);
}
}
}
if (has_colwise) {
auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i;
out_i.set_columnwise_data(static_cast<void *>(colwise_output_ptr), out_dtype, tmp_shape);
      // For 2x tensor scaling, the rowwise and columnwise outputs share the same scale_inv buffer
auto &tmp_sinv_ptr = is_tensor_scaling ? sinv_ptr : colwise_sinv_ptr;
if (is_tensor_scaling) {
out_i.set_columnwise_scale_inv(static_cast<void *>(tmp_sinv_ptr), sinv_dtype,
std::vector<size_t>{1});
colwise_sinv_size = 1;
} else {
const bool is_colwise = true;
auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
out_i.set_columnwise_scale_inv(static_cast<void *>(colwise_sinv_ptr), sinv_dtype,
sinv_shape_i);
colwise_sinv_size = product(sinv_shape_i);
}
}
input_holders.push_back(std::move(inp_i));
output_holders.push_back(std::move(out_i));
input_list.push_back(input_holders.back().data());
output_list.push_back(output_holders.back().data());
input_ptr += m_i * n * input_dtype_bytes;
scale_ptr += scale_dtype_bytes;
output_ptr += m_i * n * output_dtype_bytes;
colwise_output_ptr += m_i * n * colwise_output_dtype_bytes;
sinv_ptr += sinv_size * sinv_dtype_bytes;
colwise_sinv_ptr += colwise_sinv_size * colwise_sinv_dtype_bytes;
amax_ptr += amax_dtype_bytes;
}
QuantizationConfigWrapper quant_config;
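  // Quantize all non-empty groups in a single call on the provided stream; quant_config is
  // left at its defaults.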
nvte_multi_tensor_quantize(input_list.data(), output_list.data(), quant_config,
num_non_empty_groups, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // group_sizes
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout")
.Attr<int64_t>("flatten_axis"),
FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
......@@ -6,7 +6,7 @@
#include "transformer_engine/softmax.h"
#include "extensions.h"
#include "../extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......