Commit 7cec75a7 authored by liuchy5
Browse files

修改sparse_attn hip后端

parent ce52b8a8
......@@ -184,7 +184,7 @@ class CustomOp(nn.Module):
return self.maybe_compile(self.forward_native, enable=compile_native)
if current_platform.is_rocm():
return self.forward_cuda
return self.forward_hip
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_tpu():
......
......@@ -296,8 +296,7 @@ class SparseAttnIndexer(CustomOp):
if current_platform.is_cuda():
return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm():
# return self.forward_hip(hidden_states, q_fp8, k, weights)
return self.forward_cuda(hidden_states, q_fp8, k, weights)
return self.forward_hip(hidden_states, q_fp8, k, weights)
else:
raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for "
......@@ -349,9 +348,22 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
)
else:
raise RuntimeError(
"Sparse attention indexer ROCm custom op requires ROCm "
"Aiter ops to be enabled."
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment