"qa/vscode:/vscode.git/clone" did not exist on "26370b117169aec87df9e86f90814a4faabbcc09"
Unverified Commit 421084cf authored by Xiaowei Ren's avatar Xiaowei Ren Committed by GitHub
Browse files

cache sequence chunk ids for reordering (#1751)


Signed-off-by: default avatarXiaowei Ren <xren@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ce0b46c4
...@@ -44,6 +44,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import ( ...@@ -44,6 +44,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
) )
# Module-level memoization caches. The two chunk-id caches are keyed by
# (cp_size, device) — see get_seq_chunk_ids_for_reordering_before_attn /
# _after_attn below. (_cu_seqlens_info_with_cp_cache key layout is managed
# elsewhere in this file — presumably keyed similarly; verify at its use site.)
_cu_seqlens_info_with_cp_cache = {}
_seq_chunk_ids_cache_for_reordering_before_attn = {}
_seq_chunk_ids_cache_for_reordering_after_attn = {}
def flash_attn_p2p_communicate( def flash_attn_p2p_communicate(
def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device):
    """
    Compute sequence chunk ids for reordering before attention compute.

    With context parallelism, the chunks of a sequence assigned to each rank are
    discontiguous and need to be contiguous before attention compute. This function
    computes the sequence chunk ids used for that reordering.

    The result depends only on ``(cp_size, device)``, so it is built once and memoized
    in the module-level ``_seq_chunk_ids_cache_for_reordering_before_attn`` dict.

    Args:
        cp_size: context-parallel world size (number of CP ranks).
        device: torch device on which the index tensor is allocated.

    Returns:
        ``torch.Tensor`` of shape ``[2 * cp_size]``, dtype ``int32``, on ``device``.
    """
    # NOTE: no `global` statement is needed — the cache dict is mutated in place,
    # never rebound.
    if (cp_size, device) not in _seq_chunk_ids_cache_for_reordering_before_attn:
        chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
        for rank in range(cp_size):
            # Rank `rank` owns chunk 2*rank from the front half and the mirrored
            # chunk 2*cp_size - 2*rank - 1 from the back half.
            chunk_ids[rank] = 2 * rank
            chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1
        _seq_chunk_ids_cache_for_reordering_before_attn[(cp_size, device)] = chunk_ids
    return _seq_chunk_ids_cache_for_reordering_before_attn[(cp_size, device)]
@jit_fuser @jit_fuser
def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device):
    """
    Compute sequence chunk ids for reordering after attention compute.

    With context parallelism, sequence chunks need to be reordered back to the
    discontiguous per-rank layout after attention compute. This function computes
    the sequence chunk ids used for that reordering (the inverse of the
    before-attention reordering).

    The result depends only on ``(cp_size, device)``, so it is built once and memoized
    in the module-level ``_seq_chunk_ids_cache_for_reordering_after_attn`` dict.

    Args:
        cp_size: context-parallel world size (number of CP ranks).
        device: torch device on which the index tensor is allocated.

    Returns:
        ``torch.Tensor`` of shape ``[2 * cp_size]``, dtype ``int32``, on ``device``.
    """
    # NOTE: no `global` statement is needed — the cache dict is mutated in place,
    # never rebound.
    if (cp_size, device) not in _seq_chunk_ids_cache_for_reordering_after_attn:
        chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device)
        for rank in range(cp_size):
            # Positions 2*rank and 2*rank + 1 receive rank's front-half chunk id
            # and its mirrored back-half chunk id, respectively.
            chunk_ids[2 * rank] = rank
            chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1
        _seq_chunk_ids_cache_for_reordering_after_attn[(cp_size, device)] = chunk_ids
    return _seq_chunk_ids_cache_for_reordering_after_attn[(cp_size, device)]
@jit_fuser @jit_fuser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment