Commit 7cec75a7 authored by liuchy5
Browse files

修改sparse_attn hip后端

parent ce52b8a8
...@@ -184,7 +184,7 @@ class CustomOp(nn.Module): ...@@ -184,7 +184,7 @@ class CustomOp(nn.Module):
return self.maybe_compile(self.forward_native, enable=compile_native) return self.maybe_compile(self.forward_native, enable=compile_native)
if current_platform.is_rocm(): if current_platform.is_rocm():
return self.forward_cuda return self.forward_hip
elif current_platform.is_cpu(): elif current_platform.is_cpu():
return self.forward_cpu return self.forward_cpu
elif current_platform.is_tpu(): elif current_platform.is_tpu():
......
...@@ -296,8 +296,7 @@ class SparseAttnIndexer(CustomOp): ...@@ -296,8 +296,7 @@ class SparseAttnIndexer(CustomOp):
if current_platform.is_cuda(): if current_platform.is_cuda():
return self.forward_cuda(hidden_states, q_fp8, k, weights) return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm(): elif current_platform.is_rocm():
# return self.forward_hip(hidden_states, q_fp8, k, weights) return self.forward_hip(hidden_states, q_fp8, k, weights)
return self.forward_cuda(hidden_states, q_fp8, k, weights)
else: else:
raise NotImplementedError( raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for " "SparseAttnIndexer native forward is only implemented for "
...@@ -349,9 +348,22 @@ class SparseAttnIndexer(CustomOp): ...@@ -349,9 +348,22 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len, self.max_model_len,
self.max_total_seq_len, self.max_total_seq_len,
self.topk_indices_buffer, self.topk_indices_buffer,
) )
else: else:
raise RuntimeError( return torch.ops.vllm.sparse_attn_indexer(
"Sparse attention indexer ROCm custom op requires ROCm " hidden_states,
"Aiter ops to be enabled." self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment