Commit 2cbda743 authored by zhuwenwen's avatar zhuwenwen
Browse files

update DeepseekV32IndexerCache

parent 734f52d8
...@@ -702,7 +702,7 @@ def sparse_attn_indexer( ...@@ -702,7 +702,7 @@ def sparse_attn_indexer(
else: else:
logits = gemmopt.paged_mqa_logits( logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens, padded_q_fp8_decode_tokens,
kv_cache if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else kv_cache.to(torch.bfloat16), kv_cache,
weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32), weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32),
decode_metadata.seq_lens, decode_metadata.seq_lens,
decode_metadata.block_table, decode_metadata.block_table,
...@@ -829,8 +829,8 @@ class Indexer(nn.Module): ...@@ -829,8 +829,8 @@ class Indexer(nn.Module):
# per self.quant_block_size element # per self.quant_block_size element
self.k_cache = DeepseekV32IndexerCache( self.k_cache = DeepseekV32IndexerCache(
head_dim=self.head_dim + head_dim=self.head_dim +
self.head_dim // self.quant_block_size * 4, self.head_dim // self.quant_block_size * 4 if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else self.head_dim,
dtype=torch.uint8, dtype=torch.uint8 if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else torch.bfloat16,
prefix=f"{prefix}.k_cache", prefix=f"{prefix}.k_cache",
cache_config=cache_config) cache_config=cache_config)
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment