"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "f28877f4db2a136f26c495e033f1d2b4ea1b405c"
Unverified Commit 4d444db1 authored by zlsh80826, committed by GitHub

[JAX] Prepare cross flash attention (#525)



* Add rng_state output for cross fused attention

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add rng_state and output for the flash attention backward

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add bias for the jax cross attn API

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix a minor bug

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add bias in the backward for the arbitrary fused attn backend

Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 387397a2
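In short, cross_fused_attn gains an explicit bias operand (new third positional argument) and the cross fused attention forward now also returns rng_state, matching the self fused attention path. A minimal call sketch follows; it is hedged, not authoritative: the transformer_engine.jax.fused_attn import path and the PADDING_MASK enum member are assumptions not shown in this commit, and the shapes are placeholders chosen to satisfy the packed-kv layout the primitives assert (nkv == 2).

import jax
import jax.numpy as jnp
# Assumed import path; the commit only shows the module's relative contents.
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, cross_fused_attn

b, s_q, s_kv, h, d = 2, 128, 256, 16, 64
q = jnp.zeros((b, s_q, h, d), dtype=jnp.bfloat16)
kv = jnp.zeros((b, s_kv, 2, h, d), dtype=jnp.bfloat16)  # packed kv: nkv == 2
mask = jnp.zeros((b, 1, s_q, s_kv), dtype=jnp.uint8)    # 0 = attend (already-inverted mask)
seed = jax.random.PRNGKey(0)

# bias is the new third positional argument; pass None together with NO_BIAS.
out = cross_fused_attn(q, kv, None, mask, seed,
                       attn_bias_type=AttnBiasType.NO_BIAS,
                       attn_mask_type=AttnMaskType.PADDING_MASK,  # assumed member name
                       scaling_factor=1.0 / d**0.5,
                       dropout_probability=0.0,
                       is_training=True)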
@@ -192,6 +192,7 @@ class TestDistributedCrossAttn:
         return jnp.mean(
             cross_fused_attn(q,
                              kv,
+                             None,
                              mask,
                              None,
                              attn_bias_type=attn_bias_type,
......
@@ -163,7 +163,7 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
     # mask invert
     mask = (mask == 0)
-    return cross_fused_attn(q, kv, mask, dropout_rng, **kwargs)
+    return cross_fused_attn(q, kv, None, mask, dropout_rng, **kwargs)


 @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
......
@@ -1743,7 +1743,7 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
             softmax_aux_shape = (*batch_shape, num_head, max_seqlen, 1)
             softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
         else:
-            raise ValueError(f'Not supported {backend=}')
+            raise ValueError(f'Unsupported {backend=}')

         checker = _FusedAttnRNGStateChecker()
         seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
@@ -1807,15 +1807,11 @@ class SelfFusedAttnFwdPrimitive(BasePrimitive):
                 dropout_probability, is_training):
         _check_valid_batch_dims(batch_dims)
         assert SelfFusedAttnFwdPrimitive.outer_primitive is not None
-        qkv, bias, cu_seqlen, seed = batched_args
         qkv_bdim, _, _, seed_bdim = batch_dims

         out_bdims = qkv_bdim, qkv_bdim, seed_bdim
         return SelfFusedAttnFwdPrimitive.outer_primitive.bind(
-            qkv,
-            bias,
-            cu_seqlen,
-            seed,
+            *batched_args,
             attn_bias_type=attn_bias_type,
             attn_mask_type=attn_mask_type,
             scaling_factor=scaling_factor,
@@ -1889,12 +1885,12 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
     """
     name = "te_self_fused_attn_backward"
     multiple_results = True
-    impl_static_args = (6, 7, 8, 9, 10)
+    impl_static_args = (7, 8, 9, 10, 11)
     inner_primitive = None
     outer_primitive = None

     @staticmethod
-    def abstract(qkv_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
+    def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
                  mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
                  dropout_probability, is_training):
         """
...@@ -1902,34 +1898,28 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive): ...@@ -1902,34 +1898,28 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
""" """
del softmax_aux_aval, rng_state_aval del softmax_aux_aval, rng_state_aval
# outer_primitve is squeezed_mask, inner_primitive is cu_seqlen # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
del mask_or_cu_seqlen_aval, attn_mask_type del mask_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
del scaling_factor, dropout_probability, is_training del scaling_factor, dropout_probability, is_training
qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
assert qkv_aval.dtype == output_aval.dtype == doutput_aval.dtype bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
*batch_shape, max_seqlen, num_head, _ = output_aval.shape assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_shape = (0,)
else:
bias_shape = (*batch_shape[:-1], 1, num_head, max_seqlen, max_seqlen)
bias_dtype = qkv_dtype
dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype) dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype)
dbias = qkv_aval.update(shape=bias_shape, dtype=bias_dtype) dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
return dqkv_aval, dbias return dqkv_aval, dbias_aval
     @staticmethod
-    def lowering(ctx, qkv, softmax_aux, rng_state, output, doutput, cu_seqlen, *, attn_bias_type,
-                 attn_mask_type, scaling_factor, dropout_probability, is_training):
+    def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, *,
+                 attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
         """
         Self fused attention bwd lowering rules
         """
-        qkv_aval, _, _, _, _, _ = ctx.avals_in
+        qkv_aval, _, _, _, _, _, _ = ctx.avals_in
         *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
         batch = reduce(operator.mul, batch_shape)

-        operands = [qkv, softmax_aux, rng_state, output, doutput, cu_seqlen]
+        operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen]
         operand_shapes = map(lambda x: x.type.shape, operands)
         out_types = [
             ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
@@ -1947,7 +1937,7 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
         return out

     @staticmethod
-    def impl(qkv, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type,
+    def impl(qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type,
              attn_mask_type, scaling_factor, dropout_probability, is_training):
         assert SelfFusedAttnBwdPrimitive.inner_primitive is not None
@@ -1955,6 +1945,7 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
         dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
             qkv,
+            bias,
             softmax_aux,
             rng_state,
             output,
@@ -1972,17 +1963,11 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
                 dropout_probability, is_training):
         _check_valid_batch_dims(batch_dims)
         assert SelfFusedAttnBwdPrimitive.outer_primitive is not None
-        qkv, softmax_aux, rng_state, output, doutput, cu_seqlen = batched_args
         qkv_bdim, *_ = batch_dims

         out_bdims = qkv_bdim, qkv_bdim
         return SelfFusedAttnBwdPrimitive.outer_primitive.bind(
-            qkv,
-            softmax_aux,
-            rng_state,
-            output,
-            doutput,
-            cu_seqlen,
+            *batched_args,
             attn_bias_type=attn_bias_type,
             attn_mask_type=attn_mask_type,
             scaling_factor=scaling_factor,
@@ -1993,14 +1978,12 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
     def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                      dropout_probability, is_training, mesh, arg_infos,
                                      result_infos):
-        del attn_mask_type, scaling_factor, dropout_probability,
+        del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
         del is_training, result_infos
         x_spec = get_padded_spec(arg_infos[0])
+        bias_spec = get_padded_spec(arg_infos[1])
         dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
-        dbias_spec = [None]
-        if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
-            dbias_spec = [*x_spec[:-5], None, x_spec[-2], None, None]
-        dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_spec))
+        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
         return (dx_sharding, dbias_sharding)

     @staticmethod
@@ -2008,17 +1991,16 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
                   mesh, arg_infos, result_infos):
         del result_infos
         x_spec = get_padded_spec(arg_infos[0])
+        bias_spec = get_padded_spec(arg_infos[1])
         dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
-        dbias_spec = [None]
-        if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
-            dbias_spec = [*x_spec[:-5], None, x_spec[-2], None, None]
-        dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_spec))
+        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
         arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
         out_shardings = (dx_sharding, dbias_sharding)

-        def sharded_impl(qkv, softmax_aux, rng_state, output, doutput, cu_seqlen):
+        def sharded_impl(qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen):
             local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl(
                 qkv,
+                bias,
                 softmax_aux,
                 rng_state,
                 output,
@@ -2040,15 +2022,20 @@ class SelfFusedAttnBwdPrimitive(BasePrimitive):
 register_primitive(SelfFusedAttnBwdPrimitive)


-def self_fused_attn_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, rng_state: jnp.ndarray,
-                        output: jnp.ndarray, doutput: jnp.ndarray, squeezed_mask: jnp.ndarray,
-                        attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
-                        scaling_factor: float, dropout_probability: float, is_training: bool):
+def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
+                        rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
+                        squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
+                        attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
+                        dropout_probability: float, is_training: bool):
     """
     Wrapper for TE self fused attention bwd
     Return the gradients of self fused attention with packed qkv input
     """
+    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
+        assert bias is None
+        bias = jnp.zeros(0, dtype=qkv.dtype)
     return SelfFusedAttnBwdPrimitive.outer_primitive.bind(qkv,
+                                                          bias,
                                                           softmax_aux,
                                                           rng_state,
                                                           output,
@@ -2067,18 +2054,19 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
     """
     name = "te_cross_fused_attn_forward"
     multiple_results = True
-    impl_static_args = (5, 6, 7, 8, 9)
+    impl_static_args = (6, 7, 8, 9, 10)
     inner_primitive = None
     outer_primitive = None

     @staticmethod
-    def abstract(q_aval, kv_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval, seed_aval, *,
-                 attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
+    def abstract(q_aval, kv_aval, bias_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval,
+                 seed_aval, *, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
+                 is_training):
         """
         Cross fused attention fwd abstract
         """
-        del seed_aval, attn_bias_type, attn_mask_type
-        del scaling_factor, dropout_probability, is_training
+        # outer_primitive is squeezed_mask, inner_primitive is cu_seqlen
+        del scaling_factor, is_training
         q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
         *q_batch_shape, q_max_seqlen, q_num_head, q_head_dim = q_aval.shape
@@ -2086,37 +2074,57 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
         kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
         *kv_batch_shape, kv_max_seqlen, nkv, kv_num_head, kv_head_dim = kv_aval.shape

-        assert q_dtype == kv_dtype
+        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
+        assert q_dtype == kv_dtype == bias_dtype
         assert q_batch_shape == kv_batch_shape
         assert q_num_head == kv_num_head
         assert q_head_dim == kv_head_dim
         assert nkv == 2
-        # outer_primitive is squeezed_mask, inner_primitive is cu_seqlen
         assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype

         output_shape = q_aval.shape
         output_dtype = q_dtype
-        softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
-        softmax_aux_dtype = q_dtype
+
+        backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
+                                  attn_bias_type, attn_mask_type, dropout_probability, q_max_seqlen,
+                                  kv_max_seqlen, q_head_dim).get_fused_attn_backend()
+
+        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
+            softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
+            softmax_aux_dtype = q_dtype
+        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
+            softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, 1)
+            softmax_aux_dtype = dtypes.canonicalize_dtype(jnp.float32)
+        else:
+            raise ValueError(f'Unsupported {backend=}')
+
+        checker = _FusedAttnRNGStateChecker()
+        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
+        assert seed_dtype == checker.rng_state_dtype
+        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
+        rng_state_dtype = seed_dtype

         out_aval = q_aval.update(shape=output_shape, dtype=output_dtype)
         softmax_aux_aval = q_aval.update(shape=softmax_aux_shape, dtype=softmax_aux_dtype)
+        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=rng_state_dtype)

-        return out_aval, softmax_aux_aval
+        return out_aval, softmax_aux_aval, rng_state_aval

     @staticmethod
-    def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type,
-                 scaling_factor, dropout_probability, is_training):
+    def lowering(ctx, q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
+                 attn_mask_type, scaling_factor, dropout_probability, is_training):
         """
         Cross fused attention fwd lowering rules
         """
-        q_aval, kv_aval, _, _, _ = ctx.avals_in
+        q_aval, kv_aval, *_ = ctx.avals_in
         assert q_aval.dtype == kv_aval.dtype
         *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
         batch = reduce(operator.mul, batch_shape)
         kv_max_seqlen = kv_aval.shape[-4]

-        operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed]
+        operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
         operand_shapes = map(lambda x: x.type.shape, operands)
         out_types = [
             ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
@@ -2134,16 +2142,17 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
         return out

     @staticmethod
-    def impl(q, kv, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type,
+    def impl(q, kv, bias, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type,
              scaling_factor, dropout_probability, is_training):
         assert CrossFusedAttnFwdPrimitive.inner_primitive is not None

         q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
         kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)

-        output, softmax_aux = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
+        output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
             q,
             kv,
+            bias,
             q_cu_seqlen,
             kv_cu_seqlen,
             seed,
@@ -2152,23 +2161,18 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
             scaling_factor=scaling_factor,
             dropout_probability=dropout_probability,
             is_training=is_training)
-        return output, softmax_aux
+        return output, softmax_aux, rng_state

     @staticmethod
     def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
         _check_valid_batch_dims(batch_dims)
         assert CrossFusedAttnFwdPrimitive.outer_primitive is not None
-        q, kv, q_cu_seqlen, kv_cu_seqlen, seed = batched_args
-        q_bdim, *_ = batch_dims
+        q_bdim, *_, seed_bdim = batch_dims

-        out_bdims = q_bdim, q_bdim
+        out_bdims = q_bdim, q_bdim, seed_bdim
         return CrossFusedAttnFwdPrimitive.outer_primitive.bind(
-            q,
-            kv,
-            q_cu_seqlen,
-            kv_cu_seqlen,
-            seed,
+            *batched_args,
             attn_bias_type=attn_bias_type,
             attn_mask_type=attn_mask_type,
             scaling_factor=scaling_factor,
@@ -2186,7 +2190,8 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
         out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
         softmax_aux_sharding = NamedSharding(
             mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
-        return (out_sharding, softmax_aux_sharding)
+        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
+        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

     @staticmethod
     def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
@@ -2197,9 +2202,10 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
         out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
         softmax_aux_sharding = NamedSharding(
             mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
-        seed_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
+        rng_state_sharding = seed_sharding = NamedSharding(mesh,
+                                                           PartitionSpec(get_all_mesh_axes(), None))
         arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
-        out_shardings = (out_sharding, softmax_aux_sharding)
+        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
         impl = partial(CrossFusedAttnFwdPrimitive.impl,
                        attn_bias_type=attn_bias_type,
                        attn_mask_type=attn_mask_type,
@@ -2212,10 +2218,11 @@ class CrossFusedAttnFwdPrimitive(BasePrimitive):
 register_primitive(CrossFusedAttnFwdPrimitive)


-def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
-                         kv_squeezed_mask: jnp.ndarray, seed: jnp.ndarray,
-                         attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
-                         scaling_factor: float, dropout_probability: float, is_training: bool):
+def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
+                         q_squeezed_mask: jnp.ndarray, kv_squeezed_mask: jnp.ndarray,
+                         seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
+                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
+                         dropout_probability: float, is_training: bool):
     """
     Wrapper for TE cross fused attention fwd
     Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
@@ -2223,8 +2230,13 @@ def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
     checker = _FusedAttnRNGStateChecker()
     seed = checker.check_seed(seed, dropout_probability, is_training)

+    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
+        assert bias is None
+        bias = jnp.zeros(0, dtype=q.dtype)
+
     return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q,
                                                            kv,
+                                                           bias,
                                                            q_squeezed_mask,
                                                            kv_squeezed_mask,
                                                            seed,
@@ -2241,45 +2253,46 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
     """
     name = "te_cross_fused_attn_backward"
     multiple_results = True
-    impl_static_args = (6, 7, 8, 9, 10)
+    impl_static_args = (9, 10, 11, 12, 13)
     inner_primitive = None
     outer_primitive = None

     @staticmethod
-    def abstract(q_aval, kv_aval, softmax_aux_aval, doutput_aval, q_cu_seqlen_aval,
-                 kv_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
-                 dropout_probability, is_training):
+    def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
+                 doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
+                 attn_mask_type, scaling_factor, dropout_probability, is_training):
         """
         Cross fused attention bwd abstract
         """
-        del attn_bias_type, attn_mask_type
-        del scaling_factor, dropout_probability, is_training
+        del softmax_aux_aval, rng_state_aval, output_aval
+        del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training
         q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
         kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
-        softmax_aux_dtype = dtypes.canonicalize_dtype(softmax_aux_aval.dtype)
+        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
         doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
-        assert q_dtype == kv_dtype == softmax_aux_dtype == doutput_dtype
+        assert q_dtype == kv_dtype == bias_dtype == doutput_dtype
+        # outer_primitive is squeezed_mask, inner_primitive is cu_seqlen
         assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype

         dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
         dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype)
-        return dq_aval, dkv_aval
+        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
+        return dq_aval, dkv_aval, dbias_aval

     @staticmethod
-    def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, attn_bias_type,
-                 attn_mask_type, scaling_factor, dropout_probability, is_training):
+    def lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
+                 kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
+                 dropout_probability, is_training):
         """
         Cross fused attention bwd lowering rules
         """
-        q_aval, kv_aval, _, _, _, _ = ctx.avals_in
+        q_aval, kv_aval, *_ = ctx.avals_in
         assert q_aval.dtype == kv_aval.dtype
         *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
         batch = reduce(operator.mul, batch_shape)
         kv_max_seqlen = kv_aval.shape[-4]

-        operands = [q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen]
+        operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
         operand_shapes = map(lambda x: x.type.shape, operands)
         out_types = [
             ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
@@ -2300,17 +2313,21 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
         return out

     @staticmethod
-    def impl(q, kv, softmax_aux, doutput, q_squeezed_mask, kv_squeezed_mask, attn_bias_type,
-             attn_mask_type, scaling_factor, dropout_probability, is_training):
+    def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask,
+             kv_squeezed_mask, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
+             is_training):
         assert CrossFusedAttnBwdPrimitive.inner_primitive is not None

         q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
         kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)

-        dq, dkv = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
+        dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
             q,
             kv,
+            bias,
             softmax_aux,
+            rng_state,
+            output,
             doutput,
             q_cu_seqlen,
             kv_cu_seqlen,
@@ -2319,24 +2336,18 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
             scaling_factor=scaling_factor,
             dropout_probability=dropout_probability,
             is_training=is_training)
-        return dq, dkv
+        return dq, dkv, dbias

     @staticmethod
     def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
         _check_valid_batch_dims(batch_dims)
         assert CrossFusedAttnBwdPrimitive.outer_primitive is not None
-        q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen = batched_args
         q_bdim, kv_bdim, *_ = batch_dims

-        out_bdims = q_bdim, kv_bdim
+        out_bdims = q_bdim, kv_bdim, q_bdim
         return CrossFusedAttnBwdPrimitive.outer_primitive.bind(
-            q,
-            kv,
-            softmax_aux,
-            doutput,
-            q_cu_seqlen,
-            kv_cu_seqlen,
+            *batched_args,
             attn_bias_type=attn_bias_type,
             attn_mask_type=attn_mask_type,
             scaling_factor=scaling_factor,
@@ -2351,9 +2362,11 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
         del dropout_probability, is_training, result_infos
         q_spec = get_padded_spec(arg_infos[0])
         kv_spec = get_padded_spec(arg_infos[1])
+        bias_spec = get_padded_spec(arg_infos[2])
         dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
         dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
-        return (dq_sharding, dkv_sharding)
+        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
+        return (dq_sharding, dkv_sharding, dbias_sharding)

     @staticmethod
     def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
@@ -2361,25 +2374,43 @@ class CrossFusedAttnBwdPrimitive(BasePrimitive):
         del result_infos
         q_spec = get_padded_spec(arg_infos[0])
         kv_spec = get_padded_spec(arg_infos[1])
+        bias_spec = get_padded_spec(arg_infos[2])
         dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
         dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
+        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
         arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
-        out_shardings = (dq_sharding, dkv_sharding)
-        impl = partial(CrossFusedAttnBwdPrimitive.impl,
-                       attn_bias_type=attn_bias_type,
-                       attn_mask_type=attn_mask_type,
-                       scaling_factor=scaling_factor,
-                       dropout_probability=dropout_probability,
-                       is_training=is_training)
+        out_shardings = (dq_sharding, dkv_sharding, dbias_sharding)
+
+        def sharded_impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
+                         kv_cu_seqlen):
+            local_dq, local_dkv, local_dbias = CrossFusedAttnBwdPrimitive.impl(
+                q,
+                kv,
+                bias,
+                softmax_aux,
+                rng_state,
+                output,
+                doutput,
+                q_cu_seqlen,
+                kv_cu_seqlen,
+                attn_bias_type=attn_bias_type,
+                attn_mask_type=attn_mask_type,
+                scaling_factor=scaling_factor,
+                dropout_probability=dropout_probability,
+                is_training=is_training)
+            global_dbias = local_dbias
+            if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
+                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
+            return local_dq, local_dkv, global_dbias

-        return mesh, impl, out_shardings, arg_shardings
+        return mesh, sharded_impl, out_shardings, arg_shardings
 register_primitive(CrossFusedAttnBwdPrimitive)


-def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray,
-                         doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
-                         kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
-                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
+def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
+                         softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
+                         doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
+                         kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
+                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
@@ -2388,9 +2419,15 @@ def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray,
     Wrapper for TE cross fused attention bwd
     Return the gradients of cross fused attention with packed kv input
     """
+    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
+        assert bias is None
+        bias = jnp.zeros(0, dtype=q.dtype)
+
     return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q,
                                                            kv,
+                                                           bias,
                                                            softmax_aux,
+                                                           rng_state,
+                                                           output,
                                                            doutput,
                                                            q_squeezed_mask,
                                                            kv_squeezed_mask,
......
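A note on the placeholder convention used throughout the wrappers above: JAX primitives take a fixed operand list and cannot accept None, so every wrapper normalizes a missing bias to a zero-size tensor of the right dtype before binding, and the abstract rules simply mirror the bias aval into the dbias aval (so the no-bias case yields a zero-size dbias). A hedged, illustrative sketch of that convention; the helper name is ours, not the library's:

import jax.numpy as jnp

def _maybe_dummy_bias(bias, ref, no_bias: bool):
    """Return a zero-size stand-in when no bias is supplied (hypothetical helper)."""
    if no_bias:
        assert bias is None
        return jnp.zeros(0, dtype=ref.dtype)  # empty tensor keeps the operand arity fixed
    return bias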
@@ -837,15 +837,16 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
     // input
     void *qkv = buffers[0];
-    void *softmax_aux = buffers[1];
-    void *rng_state = buffers[2];
-    void *output = buffers[3];
-    void *doutput = buffers[4];
-    void *cu_seqlens = buffers[5];
+    void *bias = buffers[1];
+    void *softmax_aux = buffers[2];
+    void *rng_state = buffers[3];
+    void *output = buffers[4];
+    void *doutput = buffers[5];
+    void *cu_seqlens = buffers[6];

     // output
-    void *dqkv = buffers[6];
-    void *dbias = buffers[7];
+    void *dqkv = buffers[7];
+    void *dbias = buffers[8];

     auto batch = descriptor.batch;
     auto num_head = descriptor.num_head;
@@ -881,13 +882,15 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
     NVTETensorPack aux_output_tensors;
     nvte_tensor_pack_create(&aux_output_tensors);
-    aux_output_tensors.size = 2;
+    aux_output_tensors.size = 3;
     auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
     output_s->data.dptr = softmax_aux;
     auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
     rng_state_tensor->data.shape = std::vector<size_t>{2};
     rng_state_tensor->data.dtype = DType::kInt64;
     rng_state_tensor->data.dptr = rng_state;
+    auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
+    bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);

     TensorWrapper query_workspace_tensor;
@@ -923,13 +926,15 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
     // input
     void *q = buffers[0];
     void *kv = buffers[1];
-    void *q_cu_seqlens = buffers[2];
-    void *kv_cu_seqlens = buffers[3];
-    void *seed = buffers[4];
+    void *bias = buffers[2];
+    void *q_cu_seqlens = buffers[3];
+    void *kv_cu_seqlens = buffers[4];
+    void *seed = buffers[5];

     // output
-    void *output = buffers[5];
-    void *softmax_aux = buffers[6];
+    void *output = buffers[6];
+    void *softmax_aux = buffers[7];
+    void *rng_state = buffers[8];

     auto batch = descriptor.batch;
     auto num_head = descriptor.num_head;
@@ -946,23 +951,32 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
     auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
     auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};

+    // input tensors
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
-    // TODO(rewang): add bias for cross attn?
-    auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
-
-    // FP16/BF16 doesn't use this tensor
-    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
-    auto o_tensor =
-        TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
+    auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
     auto q_cu_seqlens_tensor =
         TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
     auto kv_cu_seqlens_tensor =
         TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
-    auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

+    // output tensors
+    auto o_tensor =
+        TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
+
+    // aux tensors
+    // F16 doesn't use s_tensor
+    auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
+    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
+
+    auto backend = nvte_get_fused_attn_backend(
+        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
+        mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
+    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

     NVTETensorPack aux_output_tensors;
     nvte_tensor_pack_create(&aux_output_tensors);
@@ -972,30 +986,18 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
+        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
         descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
         query_workspace_tensor.data(), stream);

     auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
     output_s->data.dptr = softmax_aux;

-    // fused attn workspace + workspace for rng_state
-    auto plan_workspace_size =
-        query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
-    auto rng_workspace_size = 2 * sizeof(int64_t);
-    auto total_workspace_size = plan_workspace_size + rng_workspace_size;
-    auto *workspace = WorkspaceManager::Instance().GetWorkspace(total_workspace_size);
+    auto workspace_size = query_workspace_tensor.shape().data[0];
+    auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);

     auto workspace_tensor =
         TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());

-    auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
-    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
-
-    auto backend = nvte_get_fused_attn_backend(
-        static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
-        mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
-    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
-
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
@@ -1014,21 +1016,28 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
     // input
     void *q = buffers[0];
     void *kv = buffers[1];
-    void *softmax_aux = buffers[2];
-    void *doutput = buffers[3];
-    void *q_cu_seqlens = buffers[4];
-    void *kv_cu_seqlens = buffers[5];
+    void *bias = buffers[2];
+    void *softmax_aux = buffers[3];
+    void *rng_state = buffers[4];
+    void *output = buffers[5];
+    void *doutput = buffers[6];
+    void *q_cu_seqlens = buffers[7];
+    void *kv_cu_seqlens = buffers[8];

     // output
-    void *dq = buffers[6];
-    void *dkv = buffers[7];
-    void *dp = softmax_aux;
+    void *dq = buffers[9];
+    void *dkv = buffers[10];
+    void *dbias = buffers[11];

     auto batch = descriptor.batch;
     auto num_head = descriptor.num_head;
     auto q_max_seqlen = descriptor.q_max_seqlen;
     auto kv_max_seqlen = descriptor.kv_max_seqlen;
     auto head_dim = descriptor.head_dim;
+    auto dropout_probability = descriptor.dropout_probability;
+    auto bias_type = descriptor.bias_type;
+    auto mask_type = descriptor.mask_type;
+    constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
     auto dtype = descriptor.dtype;

     auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
@@ -1038,33 +1047,33 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
+    auto output_tensor = TensorWrapper(output, output_shape, dtype);
     auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
-    // It's a little trick that the flash attn needs fwd output
-    // But when seqlen <= 512, it is not needed
-    auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
-    // FP16/BF16 doesn't use this tensor
+    // F16 doesn't use this tensor
     auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);

     auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
     auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
-    // TODO(rewang): generalize cross attn
-    auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
+    auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);

     auto q_cu_seqlens_tensor =
         TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
     auto kv_cu_seqlens_tensor =
         TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);

-    // Currently, no rng_state required for bwd
-    auto rng_state = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt64);
-
     // TODO(rewang): need to think about how to pass aux_output_tensors
     NVTETensorPack aux_output_tensors;
     nvte_tensor_pack_create(&aux_output_tensors);
-    aux_output_tensors.size = 1;
+    aux_output_tensors.size = 3;
     auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
-    output_s->data.shape = std::vector<size_t>{batch * num_head, q_max_seqlen, kv_max_seqlen};
     output_s->data.dptr = softmax_aux;
+    auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
+    rng_state_tensor->data.shape = std::vector<size_t>{2};
+    rng_state_tensor->data.dtype = DType::kInt64;
+    rng_state_tensor->data.dptr = rng_state;
+    auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
+    bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);

     TensorWrapper query_workspace_tensor;
@@ -1074,11 +1083,10 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
         s_tensor.data(),    // not used for FP16/BF16
         &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-        descriptor.scaling_factor, descriptor.dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD,
-        descriptor.bias_type, descriptor.mask_type, query_workspace_tensor.data(), stream);
+        descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type,
+        mask_type, query_workspace_tensor.data(), stream);

-    size_t workspace_size =
-        query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
+    size_t workspace_size = query_workspace_tensor.shape().data[0];
     auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);

     auto workspace_tensor =
@@ -1090,8 +1098,8 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
         s_tensor.data(),    // not used for FP16/BF16
         &aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-        descriptor.scaling_factor, descriptor.dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD,
-        descriptor.bias_type, descriptor.mask_type, workspace_tensor.data(), stream);
+        descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type,
+        mask_type, workspace_tensor.data(), stream);

     nvte_tensor_pack_destroy(&aux_output_tensors);
 }
......
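For readers cross-checking the two sides: the buffers[] indices above must stay in lockstep with the operand lists built by the Python lowering rules, since XLA hands a custom call its inputs first and then the result buffers. A small sketch of the layout this diff implies for CrossFusedAttnBackward (illustrative only):

# Buffer layout implied by the diff for te_cross_fused_attn_backward.
INPUTS = ['q', 'kv', 'bias', 'softmax_aux', 'rng_state',
          'output', 'doutput', 'q_cu_seqlens', 'kv_cu_seqlens']
OUTPUTS = ['dq', 'dkv', 'dbias']

for i, name in enumerate(INPUTS + OUTPUTS):
    print(f'buffers[{i}] -> {name}')  # buffers[0..8] inputs, buffers[9..11] outputs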
@@ -667,6 +667,7 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                 x = cross_fused_attn(query,
                                      kv_proj,
+                                     bias,
                                      mask,
                                      seed,
                                      attn_bias_type=attn_bias_type,
......
@@ -81,7 +81,7 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                               seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                               attn_mask_type: AttnMaskType, scaling_factor: float,
                               dropout_probability: float, is_training: bool):
-    squeezed_mask = mask[:, :, :, 0]
+    squeezed_mask = mask[..., 0]
     output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
                                                          bias,
                                                          squeezed_mask,
@@ -91,14 +91,15 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
                                                          scaling_factor=scaling_factor,
                                                          dropout_probability=dropout_probability,
                                                          is_training=is_training)
-    return output, (qkv, softmax_aux, rng_state, output, squeezed_mask)
+    return output, (qkv, bias, softmax_aux, rng_state, output, squeezed_mask)


 def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                               is_training, ctx, dz):
-    qkv, softmax_aux, rng_state, output, squeezed_mask = ctx
+    qkv, bias, softmax_aux, rng_state, output, squeezed_mask = ctx
     grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
+                                              bias,
                                               softmax_aux,
                                               rng_state,
                                               output,
_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule) _self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray, def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool): scaling_factor: float, dropout_probability: float, is_training: bool):
""" """
Cross multi-head attention wrapper Cross multi-head attention wrapper
...@@ -128,6 +129,7 @@ def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: j ...@@ -128,6 +129,7 @@ def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: j
output = _cross_fused_attn(q, output = _cross_fused_attn(q,
kv, kv,
bias,
mask, mask,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
@@ -139,52 +141,60 @@ def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
     return output


-@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
-def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
-                      attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
+@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
+def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
+                      seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
                       scaling_factor: float, dropout_probability: float, is_training: bool):
-    output, _ = _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type,
+    output, _ = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
                                            scaling_factor, dropout_probability, is_training)
     return output


-def _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
-                               dropout_probability, is_training):
-    q_squeezed_mask = mask[:, :, :, 0]
-    kv_squeezed_mask = mask[:, :, 0, :]
+def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
+                               scaling_factor, dropout_probability, is_training):
+    q_squeezed_mask = mask[..., 0]
+    kv_squeezed_mask = mask[..., 0, :]

-    output, softmax_aux = cross_fused_attn_fwd(q,
-                                               kv,
-                                               q_squeezed_mask,
-                                               kv_squeezed_mask,
-                                               seed,
-                                               attn_bias_type=attn_bias_type.value,
-                                               attn_mask_type=attn_mask_type.value,
-                                               scaling_factor=scaling_factor,
-                                               dropout_probability=dropout_probability,
-                                               is_training=is_training)
-    return output, (softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask)
+    output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
+                                                          kv,
+                                                          bias,
+                                                          q_squeezed_mask,
+                                                          kv_squeezed_mask,
+                                                          seed,
+                                                          attn_bias_type=attn_bias_type.value,
+                                                          attn_mask_type=attn_mask_type.value,
+                                                          scaling_factor=scaling_factor,
+                                                          dropout_probability=dropout_probability,
+                                                          is_training=is_training)
+    return output, (q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask)


 def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                                is_training, ctx, dz):
-    softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask = ctx
+    q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask = ctx

-    grad_q, grad_kv = cross_fused_attn_bwd(q,
-                                           kv,
-                                           softmax_aux,
-                                           dz,
-                                           q_squeezed_mask,
-                                           kv_squeezed_mask,
-                                           attn_bias_type=attn_bias_type.value,
-                                           attn_mask_type=attn_mask_type.value,
-                                           scaling_factor=scaling_factor,
-                                           dropout_probability=dropout_probability,
-                                           is_training=is_training)
+    grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
+                                                      kv,
+                                                      bias,
+                                                      softmax_aux,
+                                                      rng_state,
+                                                      output,
+                                                      dz,
+                                                      q_squeezed_mask,
+                                                      kv_squeezed_mask,
+                                                      attn_bias_type=attn_bias_type.value,
+                                                      attn_mask_type=attn_mask_type.value,
+                                                      scaling_factor=scaling_factor,
+                                                      dropout_probability=dropout_probability,
+                                                      is_training=is_training)

-    return grad_q, grad_kv, None, None
+    if attn_bias_type == AttnBiasType.NO_BIAS:
+        grad_bias = None
+
+    return grad_q, grad_kv, grad_bias, None, None


 _cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
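The VJP wiring above follows the standard jax.custom_vjp recipe: the forward rule returns (output, residuals), and the backward rule consumes the residuals and must return one cotangent per differentiable primal argument, with None for inputs that carry no gradient (here the mask and seed, and the bias when attn_bias_type is NO_BIAS). A self-contained toy version of the same pattern, with hypothetical names, for readers unfamiliar with the recipe:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def masked_add(x, b, mask):
    return jnp.where(mask, x + b, 0.0)

def masked_add_fwd(x, b, mask):
    out = masked_add(x, b, mask)
    return out, (mask,)                  # residuals saved for the backward pass

def masked_add_bwd(ctx, dz):
    mask, = ctx
    dx = jnp.where(mask, dz, 0.0)
    db = jnp.where(mask, dz, 0.0)
    return dx, db, None                  # None: the mask is not differentiated

masked_add.defvjp(masked_add_fwd, masked_add_bwd)

x = jnp.ones(4)
b = jnp.full(4, 0.5)
mask = jnp.array([True, True, False, True])
dx, db = jax.grad(lambda x, b: masked_add(x, b, mask).sum(), argnums=(0, 1))(x, b)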