Commit 734f52d8 authored by zhuwenwen's avatar zhuwenwen
Browse files

update sparse_attn_indexer

parent f441aca2
...@@ -872,28 +872,11 @@ class Indexer(nn.Module): ...@@ -872,28 +872,11 @@ class Indexer(nn.Module):
-1) * q_scale * self.softmax_scale * self.n_head**-0.5 -1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights = weights.squeeze(-1) weights = weights.squeeze(-1)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
else:
return torch.ops.vllm.sparse_attn_indexer( return torch.ops.vllm.sparse_attn_indexer(
hidden_states, hidden_states,
self.k_cache.prefix, self.k_cache.prefix,
self.k_cache.kv_cache[0], self.k_cache.kv_cache[0],
q, q_fp8 if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else q,
k, k,
weights, weights,
self.quant_block_size, self.quant_block_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment