Commit 2b8d795b authored by zhuwenwen
Browse files

add rocm merge_attn_states

parent f572ca96
......@@ -172,7 +172,6 @@ void paged_attention_v2_opt_tc_with_mask(
const int64_t attn_masks_stride=0);
#ifndef USE_ROCM
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
......@@ -180,6 +179,8 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
......
......@@ -220,7 +220,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_opt_tc_with_mask", torch::kCUDA, &paged_attention_v2_opt_tc_with_mask);
#ifndef USE_ROCM
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
......@@ -234,6 +233,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
......
......@@ -31,7 +31,7 @@ def merge_attn_states(
return headdim % 4 == 0
return headdim % 8 == 0
if (current_platform.is_cuda() and supported_dtypes(output)
if (current_platform.is_cuda() or current_platform.is_rocm() and supported_dtypes(output)
and supported_headdim(output)):
from vllm._custom_ops import merge_attn_states
return merge_attn_states(output, prefix_output, prefix_lse,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment