Commit 8001970c authored by wangmin6
Browse files

Merge branch 'v0.15.1-fix_custom_op' into 'v0.15.1-dev'

修改sparse_attn hip后端

See merge request dcutoolkit/deeplearing/vllm!498
parents ce52b8a8 7cec75a7
...@@ -184,7 +184,7 @@ class CustomOp(nn.Module): ...@@ -184,7 +184,7 @@ class CustomOp(nn.Module):
return self.maybe_compile(self.forward_native, enable=compile_native) return self.maybe_compile(self.forward_native, enable=compile_native)
if current_platform.is_rocm(): if current_platform.is_rocm():
return self.forward_cuda return self.forward_hip
elif current_platform.is_cpu(): elif current_platform.is_cpu():
return self.forward_cpu return self.forward_cpu
elif current_platform.is_tpu(): elif current_platform.is_tpu():
......
...@@ -296,8 +296,7 @@ class SparseAttnIndexer(CustomOp): ...@@ -296,8 +296,7 @@ class SparseAttnIndexer(CustomOp):
if current_platform.is_cuda(): if current_platform.is_cuda():
return self.forward_cuda(hidden_states, q_fp8, k, weights) return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm(): elif current_platform.is_rocm():
# return self.forward_hip(hidden_states, q_fp8, k, weights) return self.forward_hip(hidden_states, q_fp8, k, weights)
return self.forward_cuda(hidden_states, q_fp8, k, weights)
else: else:
raise NotImplementedError( raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for " "SparseAttnIndexer native forward is only implemented for "
...@@ -349,9 +348,22 @@ class SparseAttnIndexer(CustomOp): ...@@ -349,9 +348,22 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len, self.max_model_len,
self.max_total_seq_len, self.max_total_seq_len,
self.topk_indices_buffer, self.topk_indices_buffer,
) )
else: else:
raise RuntimeError( return torch.ops.vllm.sparse_attn_indexer(
"Sparse attention indexer ROCm custom op requires ROCm " hidden_states,
"Aiter ops to be enabled." self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment