Commit 2b8d795b authored by zhuwenwen's avatar zhuwenwen
Browse files

add rocm merge_attn_states

parent f572ca96
...@@ -172,7 +172,6 @@ void paged_attention_v2_opt_tc_with_mask( ...@@ -172,7 +172,6 @@ void paged_attention_v2_opt_tc_with_mask(
const int64_t attn_masks_stride=0); const int64_t attn_masks_stride=0);
#ifndef USE_ROCM
void merge_attn_states(torch::Tensor& output, void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_output,
...@@ -180,6 +179,8 @@ void merge_attn_states(torch::Tensor& output, ...@@ -180,6 +179,8 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse); const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes( void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
......
...@@ -220,7 +220,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -220,7 +220,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int attn_masks_stride) -> ()"); " int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_opt_tc_with_mask", torch::kCUDA, &paged_attention_v2_opt_tc_with_mask); ops.impl("paged_attention_v2_opt_tc_with_mask", torch::kCUDA, &paged_attention_v2_opt_tc_with_mask);
#ifndef USE_ROCM
// Merge attn states // Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case) // can be used to combine partial attention results (in the split-KV case)
...@@ -234,6 +233,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -234,6 +233,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()"); " Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def( ops.def(
"convert_vertical_slash_indexes(" "convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, " " Tensor! block_count, Tensor! block_offset, "
......
...@@ -31,7 +31,7 @@ def merge_attn_states( ...@@ -31,7 +31,7 @@ def merge_attn_states(
return headdim % 4 == 0 return headdim % 4 == 0
return headdim % 8 == 0 return headdim % 8 == 0
if (current_platform.is_cuda() and supported_dtypes(output) if (current_platform.is_cuda() or current_platform.is_rocm() and supported_dtypes(output)
and supported_headdim(output)): and supported_headdim(output)):
from vllm._custom_ops import merge_attn_states from vllm._custom_ops import merge_attn_states
return merge_attn_states(output, prefix_output, prefix_lse, return merge_attn_states(output, prefix_output, prefix_lse,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment