Unverified Commit b88f727b authored by Paweł Gadziński, committed by GitHub

[JAX] Make all jax attention calls use non-packed common calls (#2358)



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* add notes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* small fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 262c184e
...@@ -29,7 +29,7 @@ transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *so
   return view;
 }
-// Helper function to calculate stride for packed QKV tensor unpacking
+// Helper function to calculate stride in bytes for packed QKV tensor unpacking
 size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype,
                             size_t h, size_t d) {
   size_t stride = 0;
......
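
The comment change above clarifies that the stride is measured in bytes, not elements. The body of `calculate_qkv_stride` is elided in this view, so what follows is only a minimal sketch of why the unit matters, not the function's actual implementation. `PackedLayout`, `packed_stride_bytes`, and `unpack_ptr` are hypothetical names, and the two layouts shown (`[..., 3, h, d]` and `[..., h, 3, d]`) are assumed for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a layout group; not the Transformer Engine enum.
enum class PackedLayout { k3HD, kH3D };

// Byte distance from the start of Q to the start of K (and from K to V)
// inside a packed QKV buffer. elem_bytes is the size of one element.
std::size_t packed_stride_bytes(PackedLayout layout, std::size_t elem_bytes,
                                std::size_t h, std::size_t d) {
  switch (layout) {
    case PackedLayout::k3HD:  // assumed shape [..., 3, h, d]
      return elem_bytes * h * d;
    case PackedLayout::kH3D:  // assumed shape [..., h, 3, d]
      return elem_bytes * d;
  }
  return 0;
}

// Pointer to the index-th tensor (0 = Q, 1 = K, 2 = V). Because the stride
// is in bytes, the arithmetic is done on a byte pointer, not a typed one.
const std::uint8_t* unpack_ptr(const void* packed, std::size_t stride_bytes,
                               std::size_t index) {
  return static_cast<const std::uint8_t*>(packed) + index * stride_bytes;
}
```

With a 2-byte element type, `h = 16`, and `d = 64`, the sketch gives a stride of 2048 bytes for the first layout and 128 bytes for the second. Applying a byte stride to a typed pointer would overshoot by a factor of the element size, which is the kind of mistake the updated comment helps prevent.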