Unverified Commit b88f727b authored by Paweł Gadziński, committed by GitHub

[JAX] Make all jax attention calls use non-packed common calls (#2358)



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* add notes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* small fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 262c184e
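
The diff below routes all of the JAX fused-attention code paths through the non-packed nvte_fused_attn_fwd / nvte_fused_attn_bwd entry points. For the packed NVTE_3HD and NVTE_HD_2HD layout groups, the K and V tensors (and their gradients) become byte-offset views into the packed buffer instead of arguments to the dedicated *_qkvpacked / *_kvpacked calls. A minimal standalone sketch of that unpacking idea (illustrative names only, not code from this commit):

#include <cstddef>
#include <cstdint>

struct QkvViews { void *q; void *k; void *v; };

// For a packed [tokens, 3, heads, dim] buffer, each token's K slice starts one
// heads*dim slab after its Q slice, and its V slice two slabs after it. The
// qkv_layout argument of the attention call still describes the interleaved strides.
inline QkvViews unpack_qkv(void *packed_qkv, std::size_t heads, std::size_t dim,
                           std::size_t dtype_bytes) {
  const std::size_t slab = dtype_bytes * heads * dim;  // one Q/K/V slab, in bytes
  auto *base = static_cast<std::int8_t *>(packed_qkv);
  return {base, base + slab, base + 2 * slab};
}

The real change additionally carries the logical shapes and the qkv_layout through to the single NVTE call, as the hunks below show.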
@@ -29,7 +29,7 @@ transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *so
return view;
}
// Helper function to calculate stride for packed QKV tensor unpacking
// Helper function to calculate stride in bytes for packed QKV tensor unpacking
size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype,
size_t h, size_t d) {
size_t stride = 0;
......
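
The body of calculate_qkv_stride is collapsed in the hunk above. A plausible sketch of such a stride-in-bytes helper, covering only the layout groups that appear in this diff (an assumption, not the verbatim implementation):

// Sketch only: the real calculate_qkv_stride may handle additional layout groups.
size_t calculate_qkv_stride_sketch(NVTE_QKV_Layout_Group layout_group,
                                   transformer_engine::DType dtype, size_t h, size_t d) {
  size_t stride = 0;
  switch (layout_group) {
    case NVTE_QKV_Layout_Group::NVTE_3HD:     // QKV packed as [..., 3, h, d]
    case NVTE_QKV_Layout_Group::NVTE_HD_2HD:  // KV packed as  [..., 2, h, d]
      stride = typeToSize(dtype) * h * d;     // one h*d slab, in bytes
      break;
    default:                                  // separate Q, K, V: no unpacking offset
      break;
  }
  return stride;
}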
@@ -123,17 +123,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
@@ -156,7 +147,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graphs at JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
@@ -174,37 +164,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(),
nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
}
nvte_tensor_pack_destroy(&aux_output_tensors);
@@ -291,36 +258,49 @@ static void FusedAttnForwardImpl(
/* Call the underlying NVTE API */
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false,
false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
// Prepare Q, K, V pointers and shapes based on layout
// Python passes dummy tensors for unused slots, so we extract from the actual packed data
void *q_ptr = q;
void *k_ptr = k;
void *v_ptr = v;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
// QKV packed in q: [batch*seqlen, 3, heads, dim]
// Python passes: q=packed_qkv, k=dummy, v=dummy
// Extract K and V pointers from the packed q data
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
NVTE_CHECK(qk_head_dim == v_head_dim,
"For QKV packed layout, qk_head_dim must equal v_head_dim");
size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
q_ptr = q;
k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
// For packed QKV, all have same shape since they're views into the same packed tensor
k_shape = q_shape;
v_shape = q_shape;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
// Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
// Python passes: q=query, k=packed_kv, v=dummy
// Extract V pointer from the packed k data
NVTE_CHECK(qk_head_dim == v_head_dim,
"For KV packed layout, qk_head_dim must equal v_head_dim");
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
q_ptr = q;
k_ptr = k;
v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
// V has same shape as K since they're packed together
v_shape = k_shape;
}
// else NVTE_HD_HD_HD: pointers and shapes already correct
auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
@@ -329,9 +309,6 @@ static void FusedAttnForwardImpl(
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
nvte_tensor_pack_destroy(&aux_output_tensors);
}
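
In the NVTE_3HD branch above, the K tensor handed to nvte_fused_attn_fwd uses the packed Q buffer shifted by one heads*dim slab as its data pointer (V by two slabs), while its logical shape stays [tokens, heads, dim]; the interleaved token stride is left to the qkv_layout argument. A small host-side check of that addressing (plain CPU memory, illustrative only, not part of the commit):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const std::size_t tokens = 4, heads = 2, dim = 3;
  std::vector<float> qkv(tokens * 3 * heads * dim);
  for (std::size_t i = 0; i < qkv.size(); ++i) qkv[i] = static_cast<float>(i);

  const std::size_t slab = heads * dim;       // one Q, K, or V slice of a single token
  const std::size_t token_stride = 3 * slab;  // distance between tokens in the packed buffer
  const float *k_view = qkv.data() + slab;    // K view = base pointer shifted by one slab

  for (std::size_t t = 0; t < tokens; ++t)
    for (std::size_t h = 0; h < heads; ++h)
      for (std::size_t d = 0; d < dim; ++d) {
        const float packed_k = qkv[(t * 3 + 1) * slab + h * dim + d];  // qkv[t][K][h][d]
        assert(k_view[t * token_stride + h * dim + d] == packed_k);
      }
  return 0;
}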
@@ -414,20 +391,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
@@ -450,7 +416,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper query_workspace_tensor;
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graphs at JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
@@ -471,28 +436,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, false, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
@@ -504,9 +448,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
}
nvte_tensor_pack_destroy(&aux_input_tensors);
@@ -552,75 +493,81 @@ static void FusedAttnBackwardImpl(
softmax_aux, rng_state, bias);
/* Call the underlying NVTE API */
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype),
stream);
}
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, deterministic,
false, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype),
stream);
}
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right, deterministic, false,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
// Prepare Q, K, V pointers and shapes based on layout
void *q_ptr = q;
void *k_ptr = k;
void *v_ptr = v;
void *dq_ptr = dq;
void *dk_ptr = dk;
void *dv_ptr = dv;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
// QKV packed in q: [batch*seqlen, 3, heads, dim]
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
NVTE_CHECK(qk_head_dim == v_head_dim,
"For QKV packed layout, qk_head_dim must equal v_head_dim");
size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
q_ptr = q;
k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
dq_ptr = dq;
dk_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + stride);
dv_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + 2 * stride);
k_shape = q_shape;
v_shape = q_shape;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
// Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
NVTE_CHECK(qk_head_dim == v_head_dim,
"For KV packed layout, qk_head_dim must equal v_head_dim");
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
q_ptr = q;
k_ptr = k;
v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
dq_ptr = dq;
dk_ptr = dk;
dv_ptr = static_cast<void *>(static_cast<int8_t *>(dk) + stride);
// V has same shape as K since they're packed together
v_shape = k_shape;
}
auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype);
auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype);
auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
size_t dtype_size = typeToSize(dtype);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
// For packed QKV, dq contains all gradients (dq, dk, dv) - clear all at once
cudaMemsetAsync(dq, 0, 3 * transformer_engine::jax::product(q_shape) * dtype_size, stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
// Clear dq
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
// For packed KV, dk contains both dk and dv - clear all at once
cudaMemsetAsync(dk, 0, 2 * transformer_engine::jax::product(k_shape) * dtype_size, stream);
} else {
// All separate - clear each individually
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * dtype_size, stream);
cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * dtype_size, stream);
}
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
}
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right, deterministic,
false, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
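
In the ragged (THD) path above, the gradient buffers are zeroed with sizes that account for the packing: for NVTE_3HD the dq buffer physically holds dq, dk, and dv, so a single memset of 3 * product(q_shape) * dtype_size bytes covers all three; for NVTE_HD_2HD, dk holds both dk and dv (factor 2). A plain-arithmetic sketch of those sizes (example dimensions are made up):

#include <cstddef>
#include <cstdio>

int main() {
  // Example dimensions (made up): 8 tokens, 4 query heads, 2 GQA groups, head dim 64, bf16.
  const std::size_t tokens = 8, heads = 4, gqa_groups = 2, dim = 64, dtype_bytes = 2;
  const std::size_t q_bytes = tokens * heads * dim * dtype_bytes;
  const std::size_t k_bytes = tokens * gqa_groups * dim * dtype_bytes;
  const std::size_t v_bytes = k_bytes;  // same shape as K when qk_head_dim == v_head_dim

  // NVTE_3HD: dq, dk, dv share one packed buffer behind the dq pointer.
  std::printf("3HD      : memset(dq, 0, %zu)\n", 3 * q_bytes);
  // NVTE_HD_2HD: dq is separate; dk and dv share the packed buffer behind dk.
  std::printf("HD_2HD   : memset(dq, 0, %zu)  memset(dk, 0, %zu)\n", q_bytes, 2 * k_bytes);
  // NVTE_HD_HD_HD: three separate buffers, cleared individually.
  std::printf("HD_HD_HD : memset(dq, 0, %zu)  memset(dk, 0, %zu)  memset(dv, 0, %zu)\n",
              q_bytes, k_bytes, v_bytes);
  return 0;
}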
......