"include/vscode:/vscode.git/clone" did not exist on "38513a8bb154f0b6d0a4088814fe92552696d465"
Unverified Commit b88f727b authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[JAX] Make all jax attention calls use non-packed common calls (#2358)



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* add notes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* small fixes
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 262c184e
...@@ -29,7 +29,7 @@ transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *so ...@@ -29,7 +29,7 @@ transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *so
return view; return view;
} }
// Helper function to calculate stride for packed QKV tensor unpacking // Helper function to calculate stride in bytes for packed QKV tensor unpacking
size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype,
size_t h, size_t d) { size_t h, size_t d) {
size_t stride = 0; size_t stride = 0;
......
...@@ -123,17 +123,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -123,17 +123,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
...@@ -156,7 +147,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -156,7 +147,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graph at the JIT compile time // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
...@@ -174,37 +164,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -174,37 +164,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto ragged_offset_tensor = auto ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { nvte_fused_attn_fwd(
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
nvte_fused_attn_fwd_qkvpacked( dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
} }
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
...@@ -291,47 +258,57 @@ static void FusedAttnForwardImpl( ...@@ -291,47 +258,57 @@ static void FusedAttnForwardImpl(
/* Call the underlying NVTE API */ /* Call the underlying NVTE API */
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32); auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
// Prepare Q, K, V pointers and shapes based on layout
// Python passes dummy tensors for unused slots, so we extract from the actual packed data
void *q_ptr = q;
void *k_ptr = k;
void *v_ptr = v;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; // QKV packed in q: [batch*seqlen, 3, heads, dim]
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); // Python passes: q=packed_qkv, k=dummy, v=dummy
nvte_fused_attn_fwd_qkvpacked( // Extract K and V pointers from the packed q data
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), NVTE_CHECK(qk_head_dim == v_head_dim,
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false, "For QKV packed layout, qk_head_dim must equal v_head_dim");
false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
window_size_left, window_size_right, workspace_tensor.data(), stream); q_ptr = q;
k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
// For packed QKV, all have same shape since they're views into the same packed tensor
k_shape = q_shape;
v_shape = q_shape;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
auto kv_shape = // Python passes: q=query, k=packed_kv, v=dummy
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; // Extract V pointer from the packed k data
auto q_tensor = TensorWrapper(q, q_shape, dtype); NVTE_CHECK(qk_head_dim == v_head_dim,
auto kv_tensor = TensorWrapper(k, kv_shape, dtype); "For KV packed layout, qk_head_dim must equal v_head_dim");
nvte_fused_attn_fwd_kvpacked( size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), q_ptr = q;
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), k_ptr = k;
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), // V has same shape as K since they're packed together
q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, v_shape = k_shape;
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
} }
// else NVTE_HD_HD_HD: pointers and shapes already correct
auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
} }
...@@ -414,20 +391,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -414,20 +391,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) { int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
...@@ -450,7 +416,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -450,7 +416,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// It is a WAR to pre-create all possible cuDNN graph at the JIT compile time // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
...@@ -471,42 +436,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -471,42 +436,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto dummy_ragged_offset_tensor = auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
s_tensor.data(), // not used for F16 doutput_tensor.data(),
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), s_tensor.data(), // not used for F16
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
deterministic, false, query_workspace_tensor.data(), nullptr); dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
nvte_fused_attn_bwd_kvpacked( window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
} }
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
...@@ -552,76 +493,82 @@ static void FusedAttnBackwardImpl( ...@@ -552,76 +493,82 @@ static void FusedAttnBackwardImpl(
softmax_aux, rng_state, bias); softmax_aux, rng_state, bias);
/* Call the underly NVTE API */ /* Call the underly NVTE API */
// Prepare Q, K, V pointers and shapes based on layout
void *q_ptr = q;
void *k_ptr = k;
void *v_ptr = v;
void *dq_ptr = dq;
void *dk_ptr = dk;
void *dv_ptr = dv;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; // QKV packed in q: [batch*seqlen, 3, heads, dim]
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen");
auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); NVTE_CHECK(qk_head_dim == v_head_dim,
if (is_ragged) { "For QKV packed layout, qk_head_dim must equal v_head_dim");
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype), size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
stream); q_ptr = q;
} k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
s_tensor.data(), // not used for F16 dq_ptr = dq;
s_tensor.data(), // not used for F16 dk_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + stride);
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), dv_ptr = static_cast<void *>(static_cast<int8_t *>(dq) + 2 * stride);
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), k_shape = q_shape;
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, v_shape = q_shape;
dropout_probability, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, deterministic,
false, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
auto kv_shape = NVTE_CHECK(qk_head_dim == v_head_dim,
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; "For KV packed layout, qk_head_dim must equal v_head_dim");
auto q_tensor = TensorWrapper(q, q_shape, dtype); size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype); q_ptr = q;
auto dq_tensor = TensorWrapper(dq, q_shape, dtype); k_ptr = k;
auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
if (is_ragged) { dq_ptr = dq;
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); dk_ptr = dk;
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype), dv_ptr = static_cast<void *>(static_cast<int8_t *>(dk) + stride);
stream); // V has same shape as K since they're packed together
} v_shape = k_shape;
nvte_fused_attn_bwd_kvpacked( }
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16 auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
s_tensor.data(), // not used for F16 auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype);
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype);
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype);
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype);
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype);
mask_type, softmax_type, window_size_left, window_size_right, deterministic, false,
workspace_tensor.data(), stream); if (is_ragged) {
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { size_t dtype_size = typeToSize(dtype);
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; // For packed QKV, dq contains all gradients (dq, dk, dv) - clear all at once
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; cudaMemsetAsync(dq, 0, 3 * transformer_engine::jax::product(q_shape) * dtype_size, stream);
auto q_tensor = TensorWrapper(q, q_shape, dtype); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto k_tensor = TensorWrapper(k, k_shape, dtype); // Clear dq
auto v_tensor = TensorWrapper(v, v_shape, dtype); cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype); // For packed KV, dk contains both dk and dv - clear all at once
auto dk_tensor = TensorWrapper(dk, k_shape, dtype); cudaMemsetAsync(dk, 0, 2 * transformer_engine::jax::product(k_shape) * dtype_size, stream);
auto dv_tensor = TensorWrapper(dv, v_shape, dtype); } else {
if (is_ragged) { // All separate - clear each individually
cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream);
cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * dtype_size, stream);
cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream); cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * dtype_size, stream);
} }
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right, deterministic,
false, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
} }
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment