"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "2712bb95cb4a4f7d1f2b8b473a2240ac3d6e7e58"
Unverified Commit 421084cf authored by Xiaowei Ren, committed by GitHub

cache sequence chunk ids for reordering (#1751)


Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ce0b46c4
@@ -44,6 +44,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
 )
 
 _cu_seqlens_info_with_cp_cache = {}
+_seq_chunk_ids_cache_for_reordering_before_attn = {}
+_seq_chunk_ids_cache_for_reordering_after_attn = {}
 
 
 def flash_attn_p2p_communicate(
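The two dicts added above memoize the chunk-id tensors built by the two functions changed below, keyed on (cp_size, device). A minimal standalone sketch of the same memoization pattern (the names _chunk_ids_cache and get_chunk_ids_cached are illustrative, not from the repo):

```python
import torch

# Illustrative module-level cache, keyed the same way as in this commit:
# one entry per (cp_size, device) pair.
_chunk_ids_cache = {}


def get_chunk_ids_cached(cp_size, device):
    # Build the tensor only on the first call for this key; later calls
    # return the already-allocated tensor instead of re-running the loop.
    if (cp_size, device) not in _chunk_ids_cache:
        _chunk_ids_cache[(cp_size, device)] = torch.arange(
            2 * cp_size, dtype=torch.int32, device=device
        )
    return _chunk_ids_cache[(cp_size, device)]
```

Keying on device as well as cp_size keeps one copy per GPU, so every attention call after the first reuses the same small tensor.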
@@ -193,11 +195,14 @@ def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device):
     be contiguous before attention compute. This function computes the sequence chunk ids
     for reordering.
     """
-    chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
-    for rank in range(cp_size):
-        chunk_ids[rank] = 2 * rank
-        chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
-    return chunk_ids
+    global _seq_chunk_ids_cache_for_reordering_before_attn
+    if (cp_size, device) not in _seq_chunk_ids_cache_for_reordering_before_attn:
+        chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
+        for rank in range(cp_size):
+            chunk_ids[rank] = 2 * rank
+            chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
+        _seq_chunk_ids_cache_for_reordering_before_attn[(cp_size, device)] = chunk_ids
+    return _seq_chunk_ids_cache_for_reordering_before_attn[(cp_size, device)]
 
 
 @jit_fuser
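For intuition, replaying the loop from get_seq_chunk_ids_for_reordering_before_attn at a small, arbitrarily chosen size (a standalone sketch; cp_size=4 is only an example):

```python
import torch

cp_size = 4
chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32)
for rank in range(cp_size):
    # First cp_size slots: even chunk ids in increasing order.
    chunk_ids[rank] = 2 * rank
    # Last cp_size slots: odd chunk ids in decreasing order.
    chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
print(chunk_ids)  # tensor([0, 2, 4, 6, 7, 5, 3, 1], dtype=torch.int32)
```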
@@ -207,11 +212,14 @@ def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device):
     We need to reorder sequence chunks back to their discontiguous layout after attention
     compute. This function computes the sequence chunk ids for reordering.
     """
-    chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
-    for rank in range(cp_size):
-        chunk_ids[2 * rank] = rank
-        chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
-    return chunk_ids
+    global _seq_chunk_ids_cache_for_reordering_after_attn
+    if (cp_size, device) not in _seq_chunk_ids_cache_for_reordering_after_attn:
+        chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
+        for rank in range(cp_size):
+            chunk_ids[2 * rank] = rank
+            chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
+        _seq_chunk_ids_cache_for_reordering_after_attn[(cp_size, device)] = chunk_ids
+    return _seq_chunk_ids_cache_for_reordering_after_attn[(cp_size, device)]
 
 
 @jit_fuser
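The two orderings are mutual inverses, which is what lets the post-attention reorder restore the original chunk layout. A standalone check (int64 ids here so plain tensor indexing works; the cached functions above return int32):

```python
import torch


def ids_before(cp_size):
    # Uncached version of get_seq_chunk_ids_for_reordering_before_attn.
    ids = torch.empty(2 * cp_size, dtype=torch.int64)
    for rank in range(cp_size):
        ids[rank] = 2 * rank
        ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
    return ids


def ids_after(cp_size):
    # Uncached version of get_seq_chunk_ids_for_reordering_after_attn.
    ids = torch.empty(2 * cp_size, dtype=torch.int64)
    for rank in range(cp_size):
        ids[2 * rank] = rank
        ids[2 * rank + 1] = 2 * cp_size - rank - 1
    return ids


cp_size = 4
chunks = torch.arange(2 * cp_size)        # stand-in for the 2*cp_size sequence chunks
reordered = chunks[ids_before(cp_size)]   # contiguous order for attention
restored = reordered[ids_after(cp_size)]  # reorder back after attention
assert torch.equal(restored, chunks)      # round-trips to the original layout
```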