Unverified Commit 968060c1 authored by Qiu's avatar Qiu Committed by GitHub
Browse files

[bugfix] correct local_chunk_len for DCP in reorg_kvcache with long context (#28526)


Signed-off-by: default avatarQiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 5d6ce2b9
...@@ -337,6 +337,7 @@ class MLACommonPrefillMetadata: ...@@ -337,6 +337,7 @@ class MLACommonPrefillMetadata:
local_context_lens_allranks: list[list[int]] | None = None local_context_lens_allranks: list[list[int]] | None = None
padded_local_cu_seq_lens: torch.Tensor | None = None padded_local_cu_seq_lens: torch.Tensor | None = None
cu_seq_lens_lst: list[list[int]] | None = None cu_seq_lens_lst: list[list[int]] | None = None
chunk_size: int | None = None
block_table: torch.Tensor block_table: torch.Tensor
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
...@@ -902,6 +903,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -902,6 +903,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device, non_blocking=True device, non_blocking=True
), ),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
) )
else: else:
chunked_context_metadata = chunked_context_metadata_cls( chunked_context_metadata = chunked_context_metadata_cls(
...@@ -986,6 +988,8 @@ def reorg_kvcache( ...@@ -986,6 +988,8 @@ def reorg_kvcache(
local_context_lens_allranks: list[list[int]], local_context_lens_allranks: list[list[int]],
sum_seq_len: int, sum_seq_len: int,
max_seq_len: int, max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int, toks: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
...@@ -1001,6 +1005,9 @@ def reorg_kvcache( ...@@ -1001,6 +1005,9 @@ def reorg_kvcache(
local_context_lens_allranks: local context lengths on each CP rank. local_context_lens_allranks: local context lengths on each CP rank.
sum_seq_len: the sum of cp_chunk_seq_lens_lst. sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst. max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: the padded local max context chunk size across ranks,
computed when building chunked_context_metadata.
chunk_idx: index of the current chunk in chunked prefill.
toks: the number of tokens for local gather cache. toks: the number of tokens for local gather cache.
""" """
kv_c_segments = [] kv_c_segments = []
...@@ -1012,20 +1019,31 @@ def reorg_kvcache( ...@@ -1012,20 +1019,31 @@ def reorg_kvcache(
): ):
cur_seq_len = 0 cur_seq_len = 0
for rank, local_context_len in enumerate(local_context_lens): for rank, local_context_len in enumerate(local_context_lens):
if local_context_len != 0: # Note(qcs): We split the context into multiple chunks,
# depending on the size of the workspace.
# local_context in dcp0: |-----------------|
# local_context in dcp1: |--------------|
# n*padded_local_chunk: |-----|-----|-----|
# local_chunk_len in dcp1: |-----|-----|--|
# so we need to update the last chunk length in dcp1.
local_chunk_len = min(
max(0, local_context_len - chunk_idx * chunk_size),
padded_local_chunk_seq_len,
)
if local_chunk_len != 0:
kv_c_segment = allgatered_kv_c_normed[ kv_c_segment = allgatered_kv_c_normed[
rank * toks + src_token_idx : rank * toks rank * toks + src_token_idx : rank * toks
+ src_token_idx + src_token_idx
+ local_context_len + local_chunk_len
] ]
k_pe_segment = allgatered_k_pe[ k_pe_segment = allgatered_k_pe[
rank * toks + src_token_idx : rank * toks rank * toks + src_token_idx : rank * toks
+ src_token_idx + src_token_idx
+ local_context_len + local_chunk_len
] ]
kv_c_segments.append(kv_c_segment) kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment) k_pe_segments.append(k_pe_segment)
cur_seq_len += local_context_len cur_seq_len += local_chunk_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len) max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += padded_local_chunk_seq_len src_token_idx += padded_local_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
...@@ -1676,6 +1694,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1676,6 +1694,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert prefill_metadata.chunked_context.local_context_lens_allranks is not None assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
assert prefill_metadata.chunked_context.chunk_size is not None
output = None output = None
iters = len(prefill_metadata.chunked_context.seq_tot) iters = len(prefill_metadata.chunked_context.seq_tot)
...@@ -1725,6 +1744,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1725,6 +1744,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks, local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks, toks=toks,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment