Commit a441a5d9 authored by zhuwenwen's avatar zhuwenwen
Browse files

update common.py

parent 109c414a
......@@ -860,7 +860,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata.cudnn_workspace = self.cudnn_workspace
# TODO @ wangming
# decode_metadata = None
decode_metadata = None
# if num_decodes > 0:
# if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
# query_lens = self.num_scheduled_tokens_np[:num_decodes]
......@@ -920,14 +920,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# seq_lens=seq_lens[:self._num_decode_tokens],
# )
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
query_start_loc_device=query_start_loc[:num_decodes + 1],
num_decode_tokens=num_decode_tokens,
)
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
query_start_loc_device=query_start_loc[:num_decodes + 1],
num_decode_tokens=num_decode_tokens,
)
attn_metadata = self.metadata_cls(
num_reqs=common_attn_metadata.num_reqs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment