"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "46d4eecf2e20ed970fa4f1dbfcf6b146c19a7597"
Unverified commit 4d444db1, authored by zlsh80826 and committed by GitHub

[JAX] Prepare cross flash attention (#525)



* Add rng_state output for cross fused attention
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add rng_state and output for the flash attention backward
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add bias for the jax cross attn API
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix a minor bug
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add bias in the backward for the arbitrary fused attn backend
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 387397a2
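For context, this change threads an optional additive bias through the cross-attention path and exposes the RNG state needed by the flash-attention backward. Below is a minimal pure-JAX sketch of the math the kv-packed fused kernel computes; the helper name, shapes, and dtypes are illustrative assumptions and this is not the Transformer Engine API itself.

# Illustrative reference only: kv-packed cross attention with an optional bias and a
# boolean padding mask (True = masked out), mirroring the semantics exercised below.
import jax
import jax.numpy as jnp

def cross_attn_reference(q, kv, bias, mask, scaling_factor):
    k, v = kv[:, :, 0], kv[:, :, 1]                        # unpack packed key/value
    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k) * scaling_factor
    if bias is not None:                                   # bias slot is optional (NO_BIAS)
        logits = logits + bias                             # bias broadcasts from [1, h, s_q, s_kv]
    if mask is not None:
        logits = jnp.where(mask, -1e9, logits)
    weights = jax.nn.softmax(logits, axis=-1)
    # dropout omitted; the fused kernel drives it from the (seed, offset) rng_state
    return jnp.einsum('bhqk,bkhd->bqhd', weights, v)

b, s_q, s_kv, h, d = 2, 16, 32, 4, 64
key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (b, s_q, h, d))
kv = jax.random.normal(key, (b, s_kv, 2, h, d))
bias = jnp.zeros((1, h, s_q, s_kv))
mask = jnp.zeros((b, 1, s_q, s_kv), dtype=bool)
print(cross_attn_reference(q, kv, bias, mask, d ** -0.5).shape)   # (2, 16, 4, 64)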
@@ -192,6 +192,7 @@ class TestDistributedCrossAttn:
return jnp.mean(
cross_fused_attn(q,
kv,
None,
mask,
None,
attn_bias_type=attn_bias_type,
......
@@ -163,7 +163,7 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
# mask invert
mask = (mask == 0)
return cross_fused_attn(q, kv, mask, dropout_rng, **kwargs)
return cross_fused_attn(q, kv, None, mask, dropout_rng, **kwargs)
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
......
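A small sketch of the mask handling this test helper performs before calling the fused op; the token arrays and shapes below are made up for illustration. The 0/1 attention mask is inverted so that True marks positions to drop, and the new bias slot is filled with None.

import jax.numpy as jnp

q_token = jnp.array([[1, 7, 4, 0]])            # 0 marks padding
kv_token = jnp.array([[5, 3, 9, 0, 0]])
attn_mask = (q_token > 0)[:, None, :, None] & (kv_token > 0)[:, None, None, :]
mask = (attn_mask == 0)                         # invert: True = masked out
print(mask.shape)                               # (1, 1, 4, 5)
# cross_fused_attn(q, kv, None, mask, dropout_rng, **kwargs)  # bias now passed explicitly as None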
This diff is collapsed.
@@ -837,15 +837,16 @@ void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaq
// input
void *qkv = buffers[0];
void *softmax_aux = buffers[1];
void *rng_state = buffers[2];
void *output = buffers[3];
void *doutput = buffers[4];
void *cu_seqlens = buffers[5];
void *bias = buffers[1];
void *softmax_aux = buffers[2];
void *rng_state = buffers[3];
void *output = buffers[4];
void *doutput = buffers[5];
void *cu_seqlens = buffers[6];
// output
void *dqkv = buffers[6];
void *dbias = buffers[7];
void *dqkv = buffers[7];
void *dbias = buffers[8];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
@@ -881,13 +882,15 @@
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 2;
aux_output_tensors.size = 3;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
rng_state_tensor->data.shape = std::vector<size_t>{2};
rng_state_tensor->data.dtype = DType::kInt64;
rng_state_tensor->data.dptr = rng_state;
auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);
TensorWrapper query_workspace_tensor;
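Reading the reordered indices above, the self-attention backward custom call now appears to take bias as its second operand and to carry (softmax_aux, rng_state, bias) in the auxiliary tensor pack. A descriptive summary of that layout; the names below are placeholders, not the primitive's actual parameter names.

# Operand/result order implied by the buffer indices above (descriptive names only).
self_fused_attn_bwd_operands = [
    'qkv',          # buffers[0]
    'bias',         # buffers[1]  (new)
    'softmax_aux',  # buffers[2]
    'rng_state',    # buffers[3]
    'output',       # buffers[4]
    'doutput',      # buffers[5]
    'cu_seqlens',   # buffers[6]
]
self_fused_attn_bwd_results = ['dqkv', 'dbias']   # buffers[7], buffers[8]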
@@ -923,13 +926,15 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
// input
void *q = buffers[0];
void *kv = buffers[1];
void *q_cu_seqlens = buffers[2];
void *kv_cu_seqlens = buffers[3];
void *seed = buffers[4];
void *bias = buffers[2];
void *q_cu_seqlens = buffers[3];
void *kv_cu_seqlens = buffers[4];
void *seed = buffers[5];
// output
void *output = buffers[5];
void *softmax_aux = buffers[6];
void *output = buffers[6];
void *softmax_aux = buffers[7];
void *rng_state = buffers[8];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
@@ -946,23 +951,32 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
auto kv_shape = std::vector<size_t>{batch * kv_max_seqlen, 2, num_head, head_dim};
auto bias_shape = std::vector<size_t>{1, num_head, q_max_seqlen, kv_max_seqlen};
// input tensors
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
// TODO(rewang): add bias for cross attn?
auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// FP16/BF16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
// output tensors
auto o_tensor =
TensorWrapper(output, std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim}, dtype);
// aux tensors
// F16 doesn't use s_tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
@@ -972,30 +986,18 @@ void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaq
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
// fused attn workspace + workspace for rng_state
auto plan_workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto rng_workspace_size = 2 * sizeof(int64_t);
auto total_workspace_size = plan_workspace_size + rng_workspace_size;
auto *workspace = WorkspaceManager::Instance().GetWorkspace(total_workspace_size);
auto workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_max_seqlen, kv_max_seqlen, head_dim);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
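Putting the forward-side changes together: the cross-attention forward call gains a bias operand, emits rng_state as a third result, and populates that RNG state before the kernel launch instead of carving it out of the workspace. A descriptive summary of the layout implied above; placeholder names, not the primitive's real signature.

cross_fused_attn_fwd_operands = [
    'q',              # buffers[0]
    'kv',             # buffers[1]
    'bias',           # buffers[2]  (new)
    'q_cu_seqlens',   # buffers[3]
    'kv_cu_seqlens',  # buffers[4]
    'seed',           # buffers[5]
]
cross_fused_attn_fwd_results = [
    'output',         # buffers[6]
    'softmax_aux',    # buffers[7]
    'rng_state',      # buffers[8]  two int64 values, filled via PopulateRngStateAsync
]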
@@ -1014,21 +1016,28 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
// input
void *q = buffers[0];
void *kv = buffers[1];
void *softmax_aux = buffers[2];
void *doutput = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
void *bias = buffers[2];
void *softmax_aux = buffers[3];
void *rng_state = buffers[4];
void *output = buffers[5];
void *doutput = buffers[6];
void *q_cu_seqlens = buffers[7];
void *kv_cu_seqlens = buffers[8];
// output
void *dq = buffers[6];
void *dkv = buffers[7];
void *dp = softmax_aux;
void *dq = buffers[9];
void *dkv = buffers[10];
void *dbias = buffers[11];
auto batch = descriptor.batch;
auto num_head = descriptor.num_head;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto head_dim = descriptor.head_dim;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;
auto dtype = descriptor.dtype;
auto q_shape = std::vector<size_t>{batch * q_max_seqlen, num_head, head_dim};
@@ -1038,33 +1047,33 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
// It's a little trick that the flash attn needs fwd output
// But when seqlen <= 512, it is not needed
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// FP16/BF16 doesn't use this tensor
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
// TODO(rewang): generalize cross attn
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
// Currently, no rng_state required for bwd
auto rng_state = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt64);
// TODO(rewang): need to think about how to pass aux_output_tensors
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
aux_output_tensors.size = 1;
aux_output_tensors.size = 3;
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.shape = std::vector<size_t>{batch * num_head, q_max_seqlen, kv_max_seqlen};
output_s->data.dptr = softmax_aux;
auto *rng_state_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[1]);
rng_state_tensor->data.shape = std::vector<size_t>{2};
rng_state_tensor->data.dtype = DType::kInt64;
rng_state_tensor->data.dptr = rng_state;
auto *bias_tensor = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[2]);
bias_tensor->data = SimpleTensor(bias, bias_shape, dtype);
TensorWrapper query_workspace_tensor;
@@ -1074,11 +1083,10 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, descriptor.dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD,
descriptor.bias_type, descriptor.mask_type, query_workspace_tensor.data(), stream);
descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type,
mask_type, query_workspace_tensor.data(), stream);
size_t workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
size_t workspace_size = query_workspace_tensor.shape().data[0];
auto *workspace = WorkspaceManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
@@ -1090,8 +1098,8 @@ void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opa
s_tensor.data(), // not used for FP16/BF16
&aux_output_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.scaling_factor, descriptor.dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD,
descriptor.bias_type, descriptor.mask_type, workspace_tensor.data(), stream);
descriptor.scaling_factor, dropout_probability, NVTE_QKV_Layout::NVTE_BSHD_BS2HD, bias_type,
mask_type, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
......
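The backward now returns dbias alongside dq and dkv. Because bias is broadcast over the batch dimension (bias_shape is [1, num_head, q_max_seqlen, kv_max_seqlen]), its gradient is the logit gradient reduced over batch. Here is a self-contained pure-JAX check of that shape relationship with toy sizes, no masking or dropout; it is not the fused kernel.

import jax
import jax.numpy as jnp

def attn_with_bias(q, kv, bias):
    k, v = kv[:, :, 0], kv[:, :, 1]
    logits = jnp.einsum('bqhd,bkhd->bhqk', q, k) + bias    # bias broadcasts over batch
    return jnp.einsum('bhqk,bkhd->bqhd', jax.nn.softmax(logits, axis=-1), v)

b, s_q, s_kv, h, d = 2, 4, 6, 2, 8
kq, kkv = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (b, s_q, h, d))
kv = jax.random.normal(kkv, (b, s_kv, 2, h, d))
bias = jnp.zeros((1, h, s_q, s_kv))

grads = jax.grad(lambda q, kv, bias: attn_with_bias(q, kv, bias).sum(), argnums=(0, 1, 2))
dq, dkv, dbias = grads(q, kv, bias)
print(dq.shape, dkv.shape, dbias.shape)   # (2, 4, 2, 8) (2, 6, 2, 2, 8) (1, 2, 4, 6)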
@@ -667,6 +667,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
x = cross_fused_attn(query,
kv_proj,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
......
@@ -81,7 +81,7 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
squeezed_mask = mask[:, :, :, 0]
squeezed_mask = mask[..., 0]
output, softmax_aux, rng_state = self_fused_attn_fwd(qkv,
bias,
squeezed_mask,
@@ -91,14 +91,15 @@
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (qkv, softmax_aux, rng_state, output, squeezed_mask)
return output, (qkv, bias, softmax_aux, rng_state, output, squeezed_mask)
def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
qkv, softmax_aux, rng_state, output, squeezed_mask = ctx
qkv, bias, softmax_aux, rng_state, output, squeezed_mask = ctx
grad_qkv, grad_bias = self_fused_attn_bwd(qkv,
bias,
softmax_aux,
rng_state,
output,
@@ -119,8 +120,8 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr
_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule)
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
"""
Cross multi-head attention wrapper
@@ -128,6 +129,7 @@ def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: j
output = _cross_fused_attn(q,
kv,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
@@ -139,52 +141,60 @@ def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: j
return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
output, _ = _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type,
output, _ = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training)
return output
def _cross_fused_attn_fwd_rule(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
q_squeezed_mask = mask[:, :, :, 0]
kv_squeezed_mask = mask[:, :, 0, :]
q_squeezed_mask = mask[..., 0]
kv_squeezed_mask = mask[..., 0, :]
output, softmax_aux = cross_fused_attn_fwd(q,
kv,
q_squeezed_mask,
kv_squeezed_mask,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask)
output, softmax_aux, rng_state = cross_fused_attn_fwd(q,
kv,
bias,
q_squeezed_mask,
kv_squeezed_mask,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return output, (q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask)
def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
is_training, ctx, dz):
softmax_aux, q, kv, q_squeezed_mask, kv_squeezed_mask = ctx
grad_q, grad_kv = cross_fused_attn_bwd(q,
kv,
softmax_aux,
dz,
q_squeezed_mask,
kv_squeezed_mask,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
return grad_q, grad_kv, None, None
q, kv, bias, softmax_aux, rng_state, output, q_squeezed_mask, kv_squeezed_mask = ctx
grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q,
kv,
bias,
softmax_aux,
rng_state,
output,
dz,
q_squeezed_mask,
kv_squeezed_mask,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_q, grad_kv, grad_bias, None, None
_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule)
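The new rules follow the standard jax.custom_vjp recipe: the forward rule stashes bias plus the extra forward artifacts (softmax_aux, rng_state, output) as residuals, and the backward rule returns a bias cotangent, which is dropped again for NO_BIAS. A compact, self-contained toy showing the same wiring on an assumed trivial function rather than the fused kernel:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def biased_scale(x, bias, scaling_factor):
    out, _ = _biased_scale_fwd(x, bias, scaling_factor)
    return out

def _biased_scale_fwd(x, bias, scaling_factor):
    out = x * scaling_factor + bias
    residuals = (x, bias, out)          # analogous to (q, kv, bias, softmax_aux, rng_state, output, ...)
    return out, residuals

def _biased_scale_bwd(scaling_factor, residuals, dz):
    x, bias, out = residuals
    return dz * scaling_factor, dz      # cotangents for (x, bias); the fused path also emits grad_bias

biased_scale.defvjp(_biased_scale_fwd, _biased_scale_bwd)

gx, gbias = jax.grad(lambda x, b: biased_scale(x, b, 2.0).sum(), argnums=(0, 1))(
    jnp.ones(3), jnp.zeros(3))
print(gx, gbias)    # [2. 2. 2.] [1. 1. 1.]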