Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1c859a13
Unverified
Commit
1c859a13
authored
Aug 15, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 15, 2025
Browse files
[V0 Deprecation] Remove advance_step (#22969)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
74f441f4
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
9 additions
and
892 deletions
+9
-892
CMakeLists.txt
CMakeLists.txt
+0
-1
csrc/ops.h
csrc/ops.h
+0
-16
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+0
-336
csrc/prepare_inputs/advance_step.cuh
csrc/prepare_inputs/advance_step.cuh
+0
-19
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-19
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-32
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+0
-5
vllm/attention/backends/differential_flash_attn.py
vllm/attention/backends/differential_flash_attn.py
+1
-75
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+1
-75
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+2
-63
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+1
-14
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-86
vllm/attention/backends/placeholder_attn.py
vllm/attention/backends/placeholder_attn.py
+1
-61
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/backends/rocm_aiter_mla.py
+0
-21
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-67
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-2
No files found.
CMakeLists.txt
View file @
1c859a13
...
...
@@ -249,7 +249,6 @@ set(VLLM_EXT_SRC
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/custom_all_reduce.cu"
"csrc/torch_bindings.cpp"
)
...
...
csrc/ops.h
View file @
1c859a13
...
...
@@ -145,22 +145,6 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void
gelu_quick
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
advance_step_flashattn
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
);
void
advance_step_flashinfer
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
void
cutlass_mla_decode
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
torch
::
Tensor
const
&
q_pe
,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
...
...
csrc/prepare_inputs/advance_step.cu
deleted
100644 → 0
View file @
74f441f4
/*
* The goal of this GPU kernel is to advance input tensors on the GPU directly
* PR: https://github.com/vllm-project/vllm/pull/6338
* Current restrictions:
* 1. Specialized for DraftModelRunner
* 2. Supports flash_attn only
*/
#include "advance_step.cuh"
namespace
prepare_inputs
{
//
/// Advances the decode inputs of every live query by one step (flash-attn
/// variant).
///
/// For each query q in [0, num_queries) the kernel
///   * copies sampled_token_ids_ptr[q] into input_tokens_ptr[q],
///   * increments seq_lens_ptr[q],
///   * writes the position of the new token (the previous sequence length)
///     to input_positions_ptr[q],
///   * looks that position up in the query's block-table row and writes the
///     resulting slot to slot_mapping_ptr[q].
/// Entries in [num_queries, num_seqs) are cuda graph padding; block 0 resets
/// them to token 0, position 0 and slot -1.
///
/// Queries are visited with a stride of gridDim.x * num_threads, so no query
/// is skipped when the grid provides fewer threads than num_queries.
///
/// @tparam num_threads         threads per block the kernel is launched with.
/// @param block_tables_stride  element distance between consecutive rows of
///                             block_tables_ptr.
template <int const num_threads>
__global__ void advance_step_flashattn_kernel(
    int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
    long const* sampled_token_ids_ptr, long* input_positions_ptr,
    int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
    int64_t const block_tables_stride) {
  int const n_pad = num_seqs - num_queries;
  if (n_pad && blockIdx.x == 0) {
    // Handle cuda graph padding
    int const offset = num_queries;
    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
      input_tokens_ptr[offset + i] = 0;
      input_positions_ptr[offset + i] = 0;
      slot_mapping_ptr[offset + i] = -1;
    }
  }

  // 64-bit indices keep the stride arithmetic free of int overflow.
  int64_t const first_query =
      static_cast<int64_t>(blockIdx.x) * num_threads + threadIdx.x;
  int64_t const query_stride = static_cast<int64_t>(gridDim.x) * num_threads;

  for (int64_t cur_query_id = first_query; cur_query_id < num_queries;
       cur_query_id += query_stride) {
    // Update input_tokens
    input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

    int const seq_len = seq_lens_ptr[cur_query_id];
    int const next_seq_len = seq_len + 1;
    int const next_input_pos = next_seq_len - 1;

    // Update seq_lens
    seq_lens_ptr[cur_query_id] = next_seq_len;
    // Update input_positions
    input_positions_ptr[cur_query_id] = next_input_pos;

    int const* seq_block_tables_ptr =
        block_tables_ptr + block_tables_stride * cur_query_id;

    int const block_index = next_input_pos / block_size;
    int const block_offset = next_input_pos % block_size;

    // Widen before multiplying: the destination is 64-bit, so the product
    // must not be truncated to 32 bits first.
    long const slot_num =
        static_cast<long>(seq_block_tables_ptr[block_index]) * block_size +
        block_offset;
    // Update slot_mapping
    slot_mapping_ptr[cur_query_id] = slot_num;
  }
}
/// Checks that a tensor argument has the expected leading sizes, is
/// contiguous and has the expected scalar type; raises through TORCH_CHECK
/// with a description of the actual and expected layout otherwise.
///
/// @param name    label printed in the error message.
/// @param t       tensor to validate.
/// @param size_0  required size of dimension 0, or -1 to skip that check.
/// @param size_1  required size of dimension 1, or -1 to skip that check.
/// @param type    required scalar type.
inline void verify_tensor(std::string const& name, torch::Tensor const& t,
                          int64_t const size_0, int64_t const size_1,
                          c10::ScalarType const type) {
  // A dimension is only inspected when a size was requested for it. If the
  // tensor does not have that dimension at all, treat it as a mismatch so the
  // caller gets the descriptive message below rather than an index error
  // from t.size().
  bool const size_0_cond =
      (size_0 == -1) || (t.dim() > 0 && t.size(0) == size_0);
  bool const size_1_cond =
      (size_1 == -1) || (t.dim() > 1 && t.size(1) == size_1);

  bool const is_contiguous = t.is_contiguous();
  bool const same_type = t.dtype() == type;

  bool const pass = size_0_cond && size_1_cond && is_contiguous && same_type;
  if (!pass) {
    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
                " is not as expected: shape = [", size_0, ", ", size_1,
                "], type = ", type);
  }
}
/// Advances the decode inputs of every live query by one step (flashinfer
/// variant). Each thread handles at most one query:
/// query id = blockIdx.x * num_threads + threadIdx.x.
///
/// For each query q in [0, num_queries) the kernel
///   * copies sampled_token_ids_ptr[q] into input_tokens_ptr[q],
///   * increments seq_lens_ptr[q],
///   * writes the position of the new token to input_positions_ptr[q],
///   * writes (offset within the last block + 1) to
///     paged_kv_last_page_len_ptr[q],
///   * resolves the position through the query's block-table row into
///     slot_mapping_ptr[q],
///   * writes ceil(new_seq_len / block_size) to block_table_bound_ptr[q].
/// Entries in [num_queries, num_seqs) are cuda graph padding; block 0 resets
/// them to token 0, position 0 and slot -1.
///
/// @param num_threads          threads per block used for the launch.
/// @param block_tables_stride  element distance between consecutive rows of
///                             block_tables_ptr.
__global__ void advance_step_flashinfer_kernel(
    int num_threads, int num_seqs, int num_queries, int block_size,
    long* input_tokens_ptr, long const* sampled_token_ids_ptr,
    long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
    int const* block_tables_ptr, int64_t const block_tables_stride,
    int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
  int const n_pad = num_seqs - num_queries;
  if (n_pad && blockIdx.x == 0) {
    // Handle cuda graph padding
    int const offset = num_queries;
    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
      input_tokens_ptr[offset + i] = 0;
      input_positions_ptr[offset + i] = 0;
      slot_mapping_ptr[offset + i] = -1;
    }
  }

  // Blocks past the last query-carrying block have nothing left to do.
  int const num_query_blocks = div_ceil(num_queries, num_threads);
  if (blockIdx.x >= num_query_blocks) {
    return;
  }

  int const cur_query_id = blockIdx.x * num_threads + threadIdx.x;
  if (cur_query_id >= num_queries) {
    return;
  }

  // Update input_tokens
  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

  int const seq_len = seq_lens_ptr[cur_query_id];
  int const next_seq_len = seq_len + 1;
  int const next_input_pos = next_seq_len - 1;

  // Update seq_lens
  seq_lens_ptr[cur_query_id] = next_seq_len;
  // Update input_positions
  input_positions_ptr[cur_query_id] = next_input_pos;

  int const* seq_block_tables_ptr =
      block_tables_ptr + block_tables_stride * cur_query_id;

  int const block_index = next_input_pos / block_size;
  int const block_offset = next_input_pos % block_size;

  // Update paged_kv_last_page_len
  paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;

  // Widen before multiplying: the destination is 64-bit, so the product
  // must not be truncated to 32 bits first.
  long const slot_num =
      static_cast<long>(seq_block_tables_ptr[block_index]) * block_size +
      block_offset;
  // Update slot_mapping
  slot_mapping_ptr[cur_query_id] = slot_num;

  block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
}
/// Rebuilds paged_kv_indptr_ptr from block_table_bound_ptr.
///
/// Entry 0 is set to 0 and, for every query q in [0, num_queries), entry
/// q + 1 becomes the sum of block_table_bound_ptr[0..q]. Each thread
/// computes the sum for its own query independently.
///
/// @param num_threads  threads per block used for the launch.
/// @param num_seqs     not read by this kernel.
__global__ void advance_step_flashinfer_indptr_kernel(
    int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
    int* block_table_bound_ptr) {
  int const query = blockIdx.x * num_threads + threadIdx.x;

  // Update paged_kv_indptr
  if (query == 0) {
    paged_kv_indptr_ptr[0] = 0;
  }

  if (query >= num_queries) {
    return;
  }

  // Inclusive sum of the bounds of queries 0..query.
  int running_total = 0;
  int const last = query + 1;
  for (int k = 0; k < last; ++k) {
    running_total += block_table_bound_ptr[k];
  }
  paged_kv_indptr_ptr[last] = running_total;
}
/// Fills paged_kv_indices_ptr with the used entries of every query's
/// block-table row and extends paged_kv_indptr_ptr over the padded queries.
///
/// * For every padded query q in (num_queries, num_seqs],
///   paged_kv_indptr_ptr[q] is set to paged_kv_indptr_ptr[num_queries].
/// * For every query q < num_queries and column c < block_table_bound_ptr[q],
///   block_tables_ptr[q * max_num_blocks_per_seq + c] is copied to
///   paged_kv_indices_ptr[paged_kv_indptr_ptr[q] + c].
///
/// Both loops stride over the whole grid and use 64-bit indices, so the
/// result does not depend on the grid being at least num_seqs + 1 threads
/// wide and large block tables do not overflow the loop index.
///
/// @param max_num_blocks_per_seq  row stride of block_tables_ptr
///                                (block_tables.stride(0)).
__global__ void advance_step_flashinfer_indices_kernel(
    int num_seqs, int num_queries, int const* block_tables_ptr,
    int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
    int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
  // note: max_num_blocks_per_seq = block_tables.stride(0)
  int64_t const tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t const grid_threads = static_cast<int64_t>(gridDim.x) * blockDim.x;

  // when cuda graphs are enabled, paged_kv_indptr tensor
  // has to be updated for the padded queries
  // The source entry [num_queries] is only read here; all writes go to
  // strictly larger indices.
  for (int64_t padded_query = static_cast<int64_t>(num_queries) + 1 + tid;
       padded_query <= num_seqs; padded_query += grid_threads) {
    paged_kv_indptr_ptr[padded_query] = paged_kv_indptr_ptr[num_queries];
  }

  // each thread processes a block_ptr in block_tables
  // block_tables shape: [num_queries, max_num_blocks_per_seq]
  // paged_kv_indices is flattened block_tables.
  int64_t const total_entries =
      static_cast<int64_t>(num_seqs) * max_num_blocks_per_seq;
  for (int64_t idx = tid; idx < total_entries; idx += grid_threads) {
    // block_tables-row = paged_kv_indptr[queryNum]
    int64_t const queryNum = idx / max_num_blocks_per_seq;
    int64_t const col = idx % max_num_blocks_per_seq;
    if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
      int64_t const indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
      int64_t const block_tables_idx = queryNum * max_num_blocks_per_seq + col;
      paged_kv_indices_ptr[indices_arr_idx] =
          block_tables_ptr[block_tables_idx];
    }
  }
}
void
advance_step_flashattn
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
torch
::
Tensor
&
input_tokens
,
// type: long
torch
::
Tensor
&
sampled_token_ids
,
// type: long
torch
::
Tensor
&
input_positions
,
// type: long
torch
::
Tensor
&
seq_lens
,
// type: int
torch
::
Tensor
&
slot_mapping
,
// type: long
torch
::
Tensor
&
block_tables
)
{
// type: int
if
(
logging
)
{
printf
(
"advance_step_flashattn:
\n
"
);
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
}
// Verify all tensors
verify_tensor
(
"input_tokens"
,
input_tokens
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"sampled_token_ids"
,
sampled_token_ids
,
num_queries
,
1
,
at
::
kLong
);
verify_tensor
(
"input_positions"
,
input_positions
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"seq_lens"
,
seq_lens
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"slot_mapping"
,
slot_mapping
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"block_tables"
,
block_tables
,
num_seqs
,
-
1
,
at
::
kInt
);
int
dev
=
sampled_token_ids
.
get_device
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
int
blocks
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
advance_step_flashattn_kernel
<
max_threads
>
<<<
blocks
,
max_threads
,
0
,
stream
>>>
(
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
reinterpret_cast
<
long
const
*>
(
sampled_token_ids
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
input_positions
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
seq_lens
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
slot_mapping
.
data_ptr
()),
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
));
}
void
advance_step_flashinfer
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
torch
::
Tensor
&
input_tokens
,
// type: long
torch
::
Tensor
&
sampled_token_ids
,
// type: long
torch
::
Tensor
&
input_positions
,
// type: long
torch
::
Tensor
&
seq_lens
,
// type: int
torch
::
Tensor
&
slot_mapping
,
// type: long
torch
::
Tensor
&
block_tables
,
// type: int
torch
::
Tensor
&
paged_kv_indices
,
// type: int
torch
::
Tensor
&
paged_kv_indptr
,
// type: int
torch
::
Tensor
&
paged_kv_last_page_len
,
// type: int
torch
::
Tensor
&
block_table_bound
)
{
// type: int
if
(
logging
)
{
printf
(
"advance_step_flashinfer:
\n
"
);
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
printf
(
" block_tables.stride(0) = %zu
\n
"
,
block_tables
.
stride
(
0
));
}
// Verify all tensors
verify_tensor
(
"input_tokens"
,
input_tokens
,
num_seqs
,
-
1
,
at
::
kLong
);
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
// at::kLong);
verify_tensor
(
"input_positions"
,
input_positions
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"seq_lens"
,
seq_lens
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"slot_mapping"
,
slot_mapping
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"block_tables"
,
block_tables
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"paged_kv_indices"
,
paged_kv_indices
,
-
1
,
-
1
,
at
::
kInt
);
verify_tensor
(
"paged_kv_indptr"
,
paged_kv_indptr
,
num_seqs
+
1
,
-
1
,
at
::
kInt
);
verify_tensor
(
"paged_kv_last_page_len"
,
paged_kv_last_page_len
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"block_table_bound"
,
block_table_bound
,
num_seqs
,
-
1
,
at
::
kInt
);
int
dev
=
sampled_token_ids
.
get_device
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
int
blocks
;
int
threads
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
cudaDeviceGetAttribute
(
&
threads
,
cudaDevAttrMaxThreadsPerBlock
,
dev
);
TORCH_CHECK
((
blocks
*
threads
>
num_queries
),
"multi-step: not enough threads to map to num_queries = "
,
num_queries
,
" block_tables.stride(0) = "
,
block_tables
.
stride
(
0
),
" blocks = "
,
blocks
,
" max_threads = "
,
threads
);
if
(
logging
)
{
printf
(
"launching kernels with %d blocks and %d threads
\n
"
,
blocks
,
threads
);
}
advance_step_flashinfer_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
reinterpret_cast
<
long
const
*>
(
sampled_token_ids
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
input_positions
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
seq_lens
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
slot_mapping
.
data_ptr
()),
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
),
reinterpret_cast
<
int
*>
(
paged_kv_last_page_len
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
advance_step_flashinfer_indptr_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
reinterpret_cast
<
int
*>
(
paged_kv_indptr
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
advance_step_flashinfer_indices_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
num_seqs
,
num_queries
,
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
),
reinterpret_cast
<
int
*>
(
paged_kv_indices
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
paged_kv_indptr
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
}
}
// namespace prepare_inputs
void
advance_step_flashattn
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
)
{
prepare_inputs
::
advance_step_flashattn
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
);
}
void
advance_step_flashinfer
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bound
)
{
prepare_inputs
::
advance_step_flashinfer
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
,
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_len
,
block_table_bound
);
}
csrc/prepare_inputs/advance_step.cuh
deleted
100644 → 0
View file @
74f441f4
#pragma once
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
namespace prepare_inputs {

// Number of threads per block (a compile-time constant).
static constexpr int max_threads{256};

// Compile-time switch for diagnostic output; disabled.
static constexpr bool logging{false};

// Integer division of a by b, rounded up, computed as (a + b - 1) / b.
constexpr int div_ceil(int a, int b) {
  const int biased_numerator = a + b - 1;
  return biased_numerator / b;
}

}  // namespace prepare_inputs
csrc/torch_bindings.cpp
View file @
1c859a13
...
...
@@ -142,25 +142,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gelu_quick(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_quick"
,
torch
::
kCUDA
,
&
gelu_quick
);
// prepare_inputs advance_step
ops
.
def
(
"advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
"Tensor! input_tokens, Tensor sampled_token_ids, "
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
"Tensor block_tables) -> ()"
);
ops
.
impl
(
"advance_step_flashattn"
,
torch
::
kCUDA
,
&
advance_step_flashattn
);
ops
.
def
(
"advance_step_flashinfer("
" int num_seqs, int num_queries, int block_size,"
" Tensor! input_tokens, Tensor sampled_token_ids,"
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
" Tensor block_tables, Tensor! paged_kv_indices,"
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
" Tensor! block_table_bounds"
") -> ()"
);
ops
.
impl
(
"advance_step_flashinfer"
,
torch
::
kCUDA
,
&
advance_step_flashinfer
);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
...
...
vllm/_custom_ops.py
View file @
1c859a13
...
...
@@ -319,38 +319,6 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
repetition_penalties
)
def
advance_step_flashattn
(
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
)
->
None
:
"""Advance a step on GPU for existing inputs for a multi-step runner"""
return
torch
.
ops
.
_C
.
advance_step_flashattn
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
)
def
advance_step_flashinfer
(
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
paged_kv_indices
:
torch
.
Tensor
,
paged_kv_indptr
:
torch
.
Tensor
,
paged_kv_last_page_len
:
torch
.
Tensor
,
block_table_bound
:
torch
.
Tensor
)
->
None
:
return
torch
.
ops
.
_C
.
advance_step_flashinfer
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
,
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_len
,
block_table_bound
)
# fused quant layer norm ops
def
rms_norm_dynamic_per_token_quant
(
input
:
torch
.
Tensor
,
...
...
vllm/attention/backends/abstract.py
View file @
1c859a13
...
...
@@ -101,11 +101,6 @@ class AttentionBackend(ABC):
)
->
None
:
raise
NotImplementedError
def
advance_step
(
self
,
model_input
:
"ModelRunnerInputBase"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
)
->
None
:
raise
NotImplementedError
@
classmethod
def
full_cls_name
(
cls
)
->
tuple
[
str
,
str
]:
return
(
cls
.
__module__
,
cls
.
__qualname__
)
...
...
vllm/attention/backends/differential_flash_attn.py
View file @
1c859a13
...
...
@@ -35,8 +35,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
logger
=
init_logger
(
__name__
)
...
...
@@ -326,79 +325,6 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata):
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
,
turn_prefills_into_decodes
:
bool
=
False
):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if
num_seqs
!=
num_queries
:
assert
num_seqs
>
num_queries
assert
self
.
use_cuda_graph
if
turn_prefills_into_decodes
:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert
self
.
num_decode_tokens
+
self
.
num_prefills
==
num_seqs
self
.
num_decode_tokens
+=
self
.
num_prefills
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
max_prefill_seq_len
=
0
self
.
max_query_len
=
1
self
.
slot_mapping
=
self
.
slot_mapping
[:
num_seqs
]
else
:
assert
self
.
seq_lens
is
not
None
assert
self
.
max_decode_seq_len
==
max
(
self
.
seq_lens
)
assert
self
.
num_prefills
==
0
assert
self
.
num_prefill_tokens
==
0
assert
self
.
num_decode_tokens
==
num_seqs
assert
self
.
slot_mapping
.
shape
==
(
num_seqs
,
)
assert
self
.
seq_lens
is
not
None
assert
len
(
self
.
seq_lens
)
==
num_seqs
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
seq_lens_tensor
.
shape
==
(
num_seqs
,
)
assert
self
.
max_query_len
==
1
assert
self
.
max_prefill_seq_len
==
0
assert
self
.
query_start_loc
is
not
None
assert
self
.
query_start_loc
.
shape
==
(
num_queries
+
1
,
)
assert
self
.
seq_start_loc
is
not
None
assert
self
.
seq_start_loc
.
shape
==
(
num_seqs
+
1
,
)
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
context_lens_tensor
.
shape
==
(
num_queries
,
)
assert
self
.
block_tables
is
not
None
assert
self
.
block_tables
.
shape
[
0
]
==
num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for
i
in
range
(
num_queries
):
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
ops
.
advance_step_flashattn
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
)
class
DifferentialFlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
DifferentialFlashAttentionMetadata
]):
...
...
vllm/attention/backends/flash_attn.py
View file @
1c859a13
...
...
@@ -32,8 +32,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
logger
=
init_logger
(
__name__
)
...
...
@@ -309,79 +308,6 @@ class FlashAttentionMetadata(AttentionMetadata):
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
,
turn_prefills_into_decodes
:
bool
=
False
):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if
num_seqs
!=
num_queries
:
assert
num_seqs
>
num_queries
assert
self
.
use_cuda_graph
if
turn_prefills_into_decodes
:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert
self
.
num_decode_tokens
+
self
.
num_prefills
==
num_seqs
self
.
num_decode_tokens
+=
self
.
num_prefills
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
max_prefill_seq_len
=
0
self
.
max_query_len
=
1
self
.
slot_mapping
=
self
.
slot_mapping
[:
num_seqs
]
else
:
assert
self
.
seq_lens
is
not
None
assert
self
.
max_decode_seq_len
==
max
(
self
.
seq_lens
)
assert
self
.
num_prefills
==
0
assert
self
.
num_prefill_tokens
==
0
assert
self
.
num_decode_tokens
==
num_seqs
assert
self
.
slot_mapping
.
shape
==
(
num_seqs
,
)
assert
self
.
seq_lens
is
not
None
assert
len
(
self
.
seq_lens
)
==
num_seqs
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
seq_lens_tensor
.
shape
==
(
num_seqs
,
)
assert
self
.
max_query_len
==
1
assert
self
.
max_prefill_seq_len
==
0
assert
self
.
query_start_loc
is
not
None
assert
self
.
query_start_loc
.
shape
==
(
num_queries
+
1
,
)
assert
self
.
seq_start_loc
is
not
None
assert
self
.
seq_start_loc
.
shape
==
(
num_seqs
+
1
,
)
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
context_lens_tensor
.
shape
==
(
num_queries
,
)
assert
self
.
block_tables
is
not
None
assert
self
.
block_tables
.
shape
[
0
]
==
num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for
i
in
range
(
num_queries
):
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
ops
.
advance_step_flashattn
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
)
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
...
...
vllm/attention/backends/flashinfer.py
View file @
1c859a13
...
...
@@ -51,8 +51,7 @@ from vllm.utils.flashinfer import use_trtllm_attention
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
class
FlashInferBackend
(
AttentionBackend
):
...
...
@@ -428,7 +427,7 @@ class FlashInferMetadata(AttentionMetadata):
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# used for GPU
in-place advance_step
# used for GPU
operations
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_table_bound
:
Optional
[
torch
.
Tensor
]
=
None
...
...
@@ -587,66 +586,6 @@ class FlashInferMetadata(AttentionMetadata):
return
None
return
self
def
advance_step
(
self
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
,
turn_prefills_into_decodes
:
bool
=
False
):
"""
Update metadata in-place to advance one decode step.
"""
if
turn_prefills_into_decodes
:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert
self
.
num_decode_tokens
+
self
.
num_prefills
==
num_seqs
# Flashinfer doesn't support speculative decoding + chunked-prefill
# + multi-step scheduling yet.
assert
self
.
decode_query_len
==
1
self
.
num_decode_tokens
+=
self
.
num_prefills
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
max_prefill_seq_len
=
0
self
.
max_query_len
=
1
self
.
slot_mapping
=
self
.
slot_mapping
[:
num_seqs
]
else
:
assert
self
.
seq_lens_tensor
is
not
None
assert
num_seqs
>
0
assert
num_queries
>
0
assert
model_input
.
attn_metadata
is
not
None
assert
sampled_token_ids
is
not
None
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if
num_seqs
!=
num_queries
:
assert
num_seqs
>
num_queries
assert
self
.
use_cuda_graph
model_input
.
input_tokens
[:
num_queries
]
=
sampled_token_ids
.
flatten
()
# Update GPU tensors
ops
.
advance_step_flashinfer
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
model_input
.
input_tokens
,
input_positions
=
model_input
.
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
,
paged_kv_indices
=
self
.
paged_kv_indices
,
paged_kv_indptr
=
self
.
paged_kv_indptr
,
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
,
block_table_bound
=
self
.
block_table_bound
)
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
...
...
vllm/attention/backends/flashmla.py
View file @
1c859a13
...
...
@@ -3,7 +3,7 @@
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
...
...
@@ -18,9 +18,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata
,
is_flashmla_supported
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
class
FlashMLABackend
(
MLACommonBackend
):
...
...
@@ -62,16 +59,6 @@ class FlashMLAMetadata(MLACommonMetadata):
self
.
decode_num_splits
return
decode_metadata
def
advance_step
(
self
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
,
turn_prefills_into_decodes
:
bool
=
False
):
raise
NotImplementedError
(
"advance_step is not implemented for FlashMLA"
)
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
...
...
vllm/attention/backends/mla/common.py
View file @
1c859a13
...
...
@@ -234,8 +234,7 @@ except ImportError:
flash_attn_varlen_func
=
None
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
is_hip
=
current_platform
.
is_rocm
()
...
...
@@ -631,90 +630,6 @@ class MLACommonMetadata(AttentionMetadata):
is_profile_run
=
self
.
is_profile_run
)
return
self
.
_cached_decode_metadata
def advance_step(self,
                 model_input: "ModelInputForGPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """
    Advance this metadata by one decode step, mutating it in place.

    The method validates that the metadata describes a decode-only batch,
    bumps the first ``num_queries`` entries of ``self.seq_lens`` by one,
    refreshes ``self.max_decode_seq_len`` and finally delegates to
    ``self._ops_advance_step``.

    Args:
        model_input: object whose ``input_tokens`` and ``input_positions``
            are forwarded to ``_ops_advance_step``.
        sampled_token_ids: forwarded unchanged to ``_ops_advance_step``.
        block_size: forwarded unchanged to ``_ops_advance_step``.
        num_seqs: batch size, possibly padded (see below).
        num_queries: actual number of requests in the batch.
        turn_prefills_into_decodes: when true, prefills recorded in this
            metadata are first converted into decodes.
    """
    # With cudagraph, num_seqs is padded up to the next captured batch
    # size while num_queries is the real request count; in
    # --enforce-eager mode the two are equal. Padding can only grow it.
    assert num_seqs == num_queries or num_seqs > num_queries

    if not turn_prefills_into_decodes:
        assert self.seq_lens is not None
        assert self.max_decode_seq_len == max(self.seq_lens)
    else:
        # Multi-Step combined with Chunked-Prefill schedules prefills and
        # decodes together; on the first step every prefill becomes a
        # decode, which is what the bookkeeping below records.
        assert self.num_decode_tokens + self.num_prefills == num_seqs
        self.num_decode_tokens += self.num_prefills
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.max_prefill_seq_len = 0
        self.max_query_len = 1

        self.slot_mapping = self.slot_mapping[:num_seqs]

    # From here on the metadata must describe a pure decode batch.
    assert self.num_prefills == 0
    assert self.num_prefill_tokens == 0
    assert self.num_decode_tokens == num_seqs
    assert self.slot_mapping.shape == (num_seqs, )

    assert self.seq_lens is not None
    assert len(self.seq_lens) == num_seqs
    assert self.seq_lens_tensor is not None
    assert self.seq_lens_tensor.shape == (num_seqs, )
    assert self.max_query_len == 1
    assert self.max_prefill_seq_len == 0

    assert self.query_start_loc is not None
    assert self.query_start_loc.shape == (num_queries + 1, )
    assert self.seq_start_loc is not None
    assert self.seq_start_loc.shape == (num_seqs + 1, )

    assert self.context_lens_tensor is not None
    assert self.context_lens_tensor.shape == (num_queries, )

    assert self.block_tables is not None
    assert self.block_tables.shape[0] == num_seqs

    # Only the real queries grow by one token; trailing entries may be
    # padding introduced by the captured cuda graph batch size.
    for query_idx in range(num_queries):
        self.seq_lens[query_idx] += 1
    self.max_decode_seq_len = max(self.seq_lens)

    tokens = model_input.input_tokens
    positions = model_input.input_positions
    self._ops_advance_step(num_seqs=num_seqs,
                           num_queries=num_queries,
                           block_size=block_size,
                           input_tokens=tokens,
                           sampled_token_ids=sampled_token_ids,
                           input_positions=positions)
def _ops_advance_step(self, num_seqs: int, num_queries: int,
                      block_size: int, input_tokens: torch.Tensor,
                      sampled_token_ids: torch.Tensor,
                      input_positions: torch.Tensor) -> None:
    """Invoke ``ops.advance_step_flashattn`` for this metadata.

    The caller-supplied arguments are forwarded unchanged; the sequence
    lengths, slot mapping and block tables are taken from this metadata
    object (``seq_lens_tensor``, ``slot_mapping``, ``block_tables``).
    """
    # Arguments supplied by the caller.
    step_kwargs = dict(
        num_seqs=num_seqs,
        num_queries=num_queries,
        block_size=block_size,
        input_tokens=input_tokens,
        sampled_token_ids=sampled_token_ids,
        input_positions=input_positions,
    )
    # Tensors owned by this metadata object.
    step_kwargs.update(
        seq_lens=self.seq_lens_tensor,
        slot_mapping=self.slot_mapping,
        block_tables=self.block_tables,
    )
    ops.advance_step_flashattn(**step_kwargs)
class
MLACommonMetadataBuilder
(
AttentionMetadataBuilder
[
T
],
Generic
[
T
]):
"""
...
...
vllm/attention/backends/placeholder_attn.py
View file @
1c859a13
...
...
@@ -15,8 +15,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from
vllm.multimodal
import
MultiModalPlaceholderMap
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
)
from
vllm.utils
import
async_tensor_h2d
# Placeholder attention backend for models like Mamba and pooling models that
...
...
@@ -201,65 +200,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
)
return
self
.
_cached_decode_metadata
def advance_step(self,
                 model_input: "ModelInputForGPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """
    Update metadata in-place to advance one decode step.

    After validating that the metadata describes a decode-only batch, the
    first ``num_queries`` entries of ``self.seq_lens`` and of
    ``self.seq_lens_tensor`` are incremented by one, and, when
    ``sampled_token_ids`` is given, its first ``num_queries`` rows are
    scattered into ``model_input.input_tokens``.

    Args:
        model_input: object whose ``input_tokens`` tensor is updated.
        sampled_token_ids: newly sampled tokens, or ``None`` to leave
            ``input_tokens`` untouched.
        block_size: accepted for interface compatibility; unused here.
        num_seqs: batch size, possibly padded (see below).
        num_queries: actual number of requests in the batch.
        turn_prefills_into_decodes: must be false for this backend.

    Raises:
        AssertionError: if any precondition on the metadata fails.
    """
    # When using cudagraph, num_seqs is padded to the next captured batch
    # size, but num_queries tracks the actual number of requests in the
    # batch. For --enforce-eager mode, num_seqs == num_queries.
    if num_seqs != num_queries:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    # Fix: the first fragment previously ended without a trailing space,
    # so the message read "attention-freemodels".
    assert not turn_prefills_into_decodes, (
        "Multi-Step + Chunked-Prefill is not supported for attention-free "
        "models. turn_prefills_into_decodes is a "
        "Multi-Step + Chunked-Prefill specific parameter.")

    assert self.seq_lens is not None
    assert self.max_decode_seq_len == max(self.seq_lens)

    assert self.num_prefills == 0
    assert self.num_prefill_tokens == 0
    assert self.num_decode_tokens == num_seqs

    assert self.seq_lens is not None
    assert len(self.seq_lens) == num_seqs
    assert self.seq_lens_tensor is not None
    assert self.seq_lens_tensor.shape == (num_seqs, )
    assert self.max_query_len == 1
    assert self.max_prefill_seq_len == 0

    assert self.query_start_loc is not None
    assert self.query_start_loc.shape == (num_queries + 1, )
    assert self.seq_start_loc is not None
    assert self.seq_start_loc.shape == (num_seqs + 1, )

    assert self.context_lens_tensor is not None
    assert self.context_lens_tensor.shape == (num_queries, )

    # Update query lengths. Only queries are updated, not seqs, since
    # tensors may be padded due to the captured cuda graph batch size.
    for i in range(num_queries):
        self.seq_lens[i] += 1
    self.max_decode_seq_len = max(self.seq_lens)

    # Update sequences, masking off entries at index >= num_queries.
    device = self.seq_lens_tensor.device
    mask = torch.arange(self.seq_lens_tensor.size(0),
                        device=device) < num_queries
    self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
    if sampled_token_ids is not None:
        model_input.input_tokens.masked_scatter_(
            mask, sampled_token_ids[:num_queries])
class
PlaceholderAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
PlaceholderAttentionMetadata
]):
...
...
vllm/attention/backends/rocm_aiter_mla.py
View file @
1c859a13
...
...
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional, Type, Union
import
torch
import
vllm._custom_ops
as
ops
import
vllm.envs
as
envs
from
vllm.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
...
...
@@ -107,26 +106,6 @@ class AiterMLAMetadata(MLACommonMetadata):
return
self
.
_cached_decode_metadata
def _ops_advance_step(self, num_seqs: int, num_queries: int,
                      block_size: int, input_tokens: torch.Tensor,
                      sampled_token_ids: torch.Tensor,
                      input_positions: torch.Tensor) -> None:
    """Invoke ``ops.advance_step_flashinfer`` for this metadata.

    The caller-supplied arguments are forwarded unchanged. Sequence
    lengths, slot mapping, block tables and the ``paged_kv_*`` /
    ``block_table_bound`` tensors are taken from this metadata object.
    """
    ops.advance_step_flashinfer(
        num_seqs=num_seqs,
        num_queries=num_queries,
        block_size=block_size,
        input_tokens=input_tokens,
        sampled_token_ids=sampled_token_ids,
        input_positions=input_positions,
        seq_lens=self.seq_lens_tensor,
        slot_mapping=self.slot_mapping,
        block_tables=self.block_tables,
        paged_kv_indices=self.paged_kv_indices,
        paged_kv_indptr=self.paged_kv_indptr,
        # Fix: the keyword is spelled in the singular
        # (`paged_kv_last_page_len`) to match how other callers pass this
        # argument; the plural spelling would be rejected as an unexpected
        # keyword. This class stores the tensor under the plural name.
        # NOTE(review): confirm against the op's Python signature.
        paged_kv_last_page_len=self.paged_kv_last_page_lens,
        block_table_bound=self.block_table_bound)
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
BLOCK_TABLE_EXTENDER
:
list
[
list
[
int
]]
=
[[]]
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
1c859a13
...
...
@@ -4,7 +4,7 @@
import
itertools
from
dataclasses
import
dataclass
from
functools
import
cache
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
...
...
@@ -23,9 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape
)
from
vllm.platforms
import
current_platform
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
logger
=
init_logger
(
__name__
)
_PARTITION_SIZE_ROCM
=
256
...
...
@@ -261,69 +258,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
self
.
_cached_decode_metadata
.
query_start_loc
=
qs
-
qs
[
0
]
return
self
.
_cached_decode_metadata
def advance_step(self,
                 model_input: "ModelInputForGPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """
    Update metadata in-place to advance one decode step.

    After validating that the metadata describes a decode-only batch, the
    first ``num_queries`` entries of ``self.seq_lens`` are incremented,
    ``self.max_decode_seq_len`` is refreshed and
    ``ops.advance_step_flashattn`` is invoked with the model inputs and
    this object's ``seq_lens_tensor``, ``slot_mapping`` and
    ``block_tables``.

    Args:
        model_input: object providing ``input_tokens`` and
            ``input_positions``.
        sampled_token_ids: forwarded unchanged to the op.
        block_size: forwarded unchanged to the op.
        num_seqs: batch size, possibly padded (see below).
        num_queries: actual number of requests in the batch.
        turn_prefills_into_decodes: must be false for this backend.

    Raises:
        AssertionError: if any precondition on the metadata fails.
    """
    # Fix: the first fragment previously ended without a trailing space,
    # so the message read "...yet.turn_prefills_into_decodes ...".
    assert not turn_prefills_into_decodes, (
        "Chunked prefill is not supported with rocm_flash_attn yet. "
        "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
        "specific parameter.")

    # When using cudagraph, num_seqs is padded to the next captured batch
    # size, but num_queries tracks the actual number of requests in the
    # batch. For --enforce-eager mode, num_seqs == num_queries.
    if num_seqs != num_queries:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    assert self.num_prefills == 0
    assert self.num_prefill_tokens == 0
    assert self.num_decode_tokens == num_seqs
    assert self.slot_mapping.shape == (num_seqs, )

    assert self.seq_lens is not None
    assert len(self.seq_lens) == num_seqs
    assert self.seq_lens_tensor is not None
    assert self.seq_lens_tensor.shape == (num_seqs, )
    assert self.max_query_len == 1
    assert self.max_prefill_seq_len == 0
    assert self.max_decode_seq_len == max(self.seq_lens)

    assert self.query_start_loc is not None
    assert self.query_start_loc.shape == (num_queries + 1, )
    assert self.seq_start_loc is not None
    assert self.seq_start_loc.shape == (num_seqs + 1, )

    assert self.context_lens_tensor is not None
    assert self.context_lens_tensor.shape == (num_queries, )

    assert self.block_tables is not None
    assert self.block_tables.shape[0] == num_seqs

    # Update query lengths. Only queries are updated, not seqs, since
    # tensors may be padded due to the captured cuda graph batch size.
    for i in range(num_queries):
        self.seq_lens[i] += 1
    self.max_decode_seq_len = max(self.seq_lens)

    ops.advance_step_flashattn(num_seqs=num_seqs,
                               num_queries=num_queries,
                               block_size=block_size,
                               input_tokens=model_input.input_tokens,
                               sampled_token_ids=sampled_token_ids,
                               input_positions=model_input.input_positions,
                               seq_lens=self.seq_lens_tensor,
                               slot_mapping=self.slot_mapping,
                               block_tables=self.block_tables)
class
ROCmFlashAttentionMetadataBuilder
(
CommonMetadataBuilder
[
ROCmFlashAttentionMetadata
]):
...
...
vllm/worker/model_runner.py
View file @
1c859a13
...
...
@@ -762,8 +762,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
has Prefills (if any). The rest of the steps are guaranteed to be all
decodes. In this case, we set up the padding as if all the sequences
are decodes so we may run all steps except the first step in CUDA graph
mode. The padding is accounted for in the multi-step `advance_step`
family of functions.
mode.
Args:
num_seqs (int): Number of sequences scheduled to run.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment