Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a6c0f365
Unverified
Commit
a6c0f365
authored
Sep 12, 2024
by
William Lin
Committed by
GitHub
Sep 12, 2024
Browse files
[multi-step] add flashinfer backend (#7928)
parent
f2e263b8
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
371 additions
and
84 deletions
+371
-84
csrc/ops.h
csrc/ops.h
+15
-4
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+200
-25
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+13
-2
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+9
-3
vllm/_custom_ops.py
vllm/_custom_ops.py
+29
-9
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+3
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+9
-9
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+77
-10
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+16
-21
No files found.
csrc/ops.h
View file @
a6c0f365
...
...
@@ -54,10 +54,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void
gelu_quick
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
advance_step
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
void
advance_step_flashattn
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
);
void
advance_step_flashinfer
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
);
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bounds
);
#ifndef USE_ROCM
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
...
...
csrc/prepare_inputs/advance_step.cu
View file @
a6c0f365
...
...
@@ -12,12 +12,10 @@ namespace prepare_inputs {
//
template
<
int
const
num_threads
>
__global__
void
advance_step_kernel
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
__global__
void
advance_step_flashattn_kernel
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
)
{
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
...
...
@@ -79,7 +77,82 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
}
}
void
advance_step
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
__global__
void
advance_step_flashinfer_kernel
(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
block_size
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
,
int
*
paged_kv_last_page_len_ptr
,
int
*
block_table_bound_ptr
)
{
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
if
(
blockIdx
.
x
<
num_query_blocks
)
{
int
cur_query_id
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
if
(
cur_query_id
<
num_queries
)
{
// Update input_tokens
input_tokens_ptr
[
cur_query_id
]
=
sampled_token_ids_ptr
[
cur_query_id
];
int
seq_len
=
seq_lens_ptr
[
cur_query_id
];
int
next_seq_len
=
seq_len
+
1
;
int
next_input_pos
=
next_seq_len
-
1
;
// Update seq_lens
seq_lens_ptr
[
cur_query_id
]
=
next_seq_len
;
// Update input_positions
input_positions_ptr
[
cur_query_id
]
=
next_input_pos
;
int
const
*
seq_block_tables_ptr
=
block_tables_ptr
+
block_tables_stride
*
cur_query_id
;
int
block_index
=
next_input_pos
/
block_size
;
int
block_offset
=
next_input_pos
%
block_size
;
// Update paged_kv_last_page_len
paged_kv_last_page_len_ptr
[
cur_query_id
]
=
block_offset
+
1
;
int
slot_num
=
seq_block_tables_ptr
[
block_index
]
*
block_size
+
block_offset
;
// Update slot_mapping
slot_mapping_ptr
[
cur_query_id
]
=
slot_num
;
block_table_bound_ptr
[
cur_query_id
]
=
div_ceil
(
next_seq_len
,
block_size
);
}
}
}
__global__
void
advance_step_flashinfer_indptr_kernel
(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
*
paged_kv_indptr_ptr
,
int
*
block_table_bound_ptr
)
{
int
idx
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
// Update paged_kv_indptr
if
(
idx
<
num_queries
)
{
int
sum
=
0
;
for
(
int
i
=
0
;
i
<=
idx
;
++
i
)
{
sum
+=
block_table_bound_ptr
[
i
];
}
paged_kv_indptr_ptr
[
idx
+
1
]
=
sum
;
}
}
__global__
void
advance_step_flashinfer_indices_kernel
(
int
num_threads
,
int
num_seqs
,
int
num_queries
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
,
int
*
paged_kv_indices_ptr
,
int
*
paged_kv_indptr_ptr
,
int
*
block_table_bound_ptr
)
{
int
idx
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
int
row
=
idx
/
block_tables_stride
;
int
col
=
idx
%
block_tables_stride
;
if
(
row
<
num_queries
&&
col
<
block_table_bound_ptr
[
row
])
{
paged_kv_indices_ptr
[
paged_kv_indptr_ptr
[
row
]
+
col
]
=
block_tables_ptr
[
row
*
block_tables_stride
+
col
];
}
// if cudagraph, fill padded seqs with the last valid seq's indptr
if
(
num_queries
<
row
&&
row
<=
num_seqs
)
{
paged_kv_indptr_ptr
[
row
]
=
paged_kv_indptr_ptr
[
num_queries
];
}
}
void
advance_step_flashattn
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
torch
::
Tensor
&
input_tokens
,
// type: long
torch
::
Tensor
&
sampled_token_ids
,
// type: long
torch
::
Tensor
&
input_positions
,
// type: long
...
...
@@ -88,7 +161,7 @@ void advance_step(int num_seqs, int num_queries, int block_size,
torch
::
Tensor
&
block_tables
)
{
// type: int
if
(
logging
)
{
printf
(
"advance_step:
\n
"
);
printf
(
"advance_step
_flashattn
:
\n
"
);
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
...
...
@@ -108,7 +181,8 @@ void advance_step(int num_seqs, int num_queries, int block_size,
int
blocks
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
advance_step_kernel
<
max_threads
><<<
blocks
,
max_threads
,
0
,
stream
>>>
(
advance_step_flashattn_kernel
<
max_threads
>
<<<
blocks
,
max_threads
,
0
,
stream
>>>
(
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
reinterpret_cast
<
long
const
*>
(
sampled_token_ids
.
data_ptr
()),
...
...
@@ -119,13 +193,114 @@ void advance_step(int num_seqs, int num_queries, int block_size,
block_tables
.
stride
(
0
));
}
void
advance_step_flashinfer
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
torch
::
Tensor
&
input_tokens
,
// type: long
torch
::
Tensor
&
sampled_token_ids
,
// type: long
torch
::
Tensor
&
input_positions
,
// type: long
torch
::
Tensor
&
seq_lens
,
// type: int
torch
::
Tensor
&
slot_mapping
,
// type: long
torch
::
Tensor
&
block_tables
,
// type: int
torch
::
Tensor
&
paged_kv_indices
,
// type: int
torch
::
Tensor
&
paged_kv_indptr
,
// type: int
torch
::
Tensor
&
paged_kv_last_page_len
,
// type: int
torch
::
Tensor
&
block_table_bound
)
{
// type: int
if
(
logging
)
{
printf
(
"advance_step_flashinfer:
\n
"
);
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
printf
(
" block_tables.stride(0) = %d
\n
"
,
block_tables
.
stride
(
0
));
}
// Verify all tensors
verify_tensor
(
"input_tokens"
,
input_tokens
,
num_seqs
,
-
1
,
at
::
kLong
);
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
// at::kLong);
verify_tensor
(
"input_positions"
,
input_positions
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"seq_lens"
,
seq_lens
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"slot_mapping"
,
slot_mapping
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"block_tables"
,
block_tables
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"paged_kv_indices"
,
paged_kv_indices
,
-
1
,
-
1
,
at
::
kInt
);
verify_tensor
(
"paged_kv_indptr"
,
paged_kv_indptr
,
num_seqs
+
1
,
-
1
,
at
::
kInt
);
verify_tensor
(
"paged_kv_last_page_len"
,
paged_kv_last_page_len
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"block_table_bound"
,
block_table_bound
,
num_seqs
,
-
1
,
at
::
kInt
);
int
dev
=
sampled_token_ids
.
get_device
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
int
blocks
;
int
threads
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
cudaDeviceGetAttribute
(
&
threads
,
cudaDevAttrMaxThreadsPerBlock
,
dev
);
if
(
logging
)
{
printf
(
"launching kernel with %d blocks
\n
"
,
blocks
);
}
// TODO(will): support arbitrary block_tables stride
if
((
blocks
*
threads
)
/
block_tables
.
stride
(
0
)
<
num_queries
)
{
TORCH_CHECK
(
false
,
"multi-step: not enough threads to map block_table to"
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
"of seqs,"
,
" increasing the block size or take smaller steps."
,
" num_queries = "
,
num_queries
,
" block_tables.stride(0) = "
,
block_tables
.
stride
(
0
),
" blocks = "
,
blocks
,
" max_threads = "
,
threads
);
}
advance_step_flashinfer_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
reinterpret_cast
<
long
const
*>
(
sampled_token_ids
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
input_positions
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
seq_lens
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
slot_mapping
.
data_ptr
()),
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
),
reinterpret_cast
<
int
*>
(
paged_kv_last_page_len
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
advance_step_flashinfer_indptr_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
reinterpret_cast
<
int
*>
(
paged_kv_indptr
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
advance_step_flashinfer_indices_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
threads
,
num_seqs
,
num_queries
,
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
),
reinterpret_cast
<
int
*>
(
paged_kv_indices
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
paged_kv_indptr
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
block_table_bound
.
data_ptr
()));
}
}
// namespace prepare_inputs
void
advance_step
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
void
advance_step_flashattn
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
)
{
prepare_inputs
::
advance_step_flashattn
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
);
}
void
advance_step_flashinfer
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
)
{
prepare_inputs
::
advance_step
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
);
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
paged_kv_indices
,
torch
::
Tensor
&
paged_kv_indptr
,
torch
::
Tensor
&
paged_kv_last_page_len
,
torch
::
Tensor
&
block_table_bound
)
{
prepare_inputs
::
advance_step_flashinfer
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
,
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_len
,
block_table_bound
);
}
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
a6c0f365
...
...
@@ -74,11 +74,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// prepare_inputs advance_step
ops
.
def
(
"advance_step(int num_seqs, int num_queries, int block_size, "
"advance_step
_flashattn
(int num_seqs, int num_queries, int block_size, "
"Tensor! input_tokens, Tensor sampled_token_ids, "
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
"Tensor block_tables) -> ()"
);
ops
.
impl
(
"advance_step"
,
torch
::
kCUDA
,
&
advance_step
);
ops
.
impl
(
"advance_step_flashattn"
,
torch
::
kCUDA
,
&
advance_step_flashattn
);
ops
.
def
(
"advance_step_flashinfer("
" int num_seqs, int num_queries, int block_size,"
" Tensor! input_tokens, Tensor sampled_token_ids,"
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
" Tensor block_tables, Tensor! paged_kv_indices,"
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
" Tensor! block_table_bounds"
") -> ()"
);
ops
.
impl
(
"advance_step_flashinfer"
,
torch
::
kCUDA
,
&
advance_step_flashinfer
);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
...
...
tests/multi_step/test_correctness_async_llm.py
View file @
a6c0f365
# Test the AsyncLLMEngine with multi-step-decoding
from
typing
import
List
,
Optional
import
pytest
from
tests.kernels.utils
import
override_backend_env_variable
from
..models.utils
import
check_logprobs_close
from
..utils
import
(
completions_with_server_args
,
get_client_text_generations
,
get_client_text_logprob_generations
)
...
...
@@ -33,8 +34,9 @@ DEFAULT_SERVER_ARGS: List[str] = [
@
pytest
.
mark
.
parametrize
(
"eager_mode"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"is_async"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"is_async"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"FLASHINFER"
,
"FLASH_ATTN"
])
@
pytest
.
mark
.
asyncio
async
def
test_multi_step
(
example_prompts
,
...
...
@@ -46,6 +48,8 @@ async def test_multi_step(
num_prompts
:
int
,
is_async
:
bool
,
num_logprobs
:
Optional
[
int
],
attention_backend
:
str
,
monkeypatch
,
)
->
None
:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
client/server environment.
...
...
@@ -71,6 +75,8 @@ async def test_multi_step(
completions endpoint; `None` -> no logprobs
"""
override_backend_env_variable
(
monkeypatch
,
attention_backend
)
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
prompts
=
prompts
*
((
num_prompts
//
len
(
prompts
))
+
1
)
...
...
vllm/_custom_ops.py
View file @
a6c0f365
...
...
@@ -161,16 +161,36 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
def
advance_step
(
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
def
advance_step_flashattn
(
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
)
->
None
:
"""Advance a step on GPU for existing inputs for a multi-step runner"""
return
torch
.
ops
.
_C
.
advance_step
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
)
return
torch
.
ops
.
_C
.
advance_step_flashattn
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
)
def
advance_step_flashinfer
(
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
paged_kv_indices
:
torch
.
Tensor
,
paged_kv_indptr
:
torch
.
Tensor
,
paged_kv_last_page_len
:
torch
.
Tensor
,
block_table_bound
:
torch
.
Tensor
)
->
None
:
return
torch
.
ops
.
_C
.
advance_step_flashinfer
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
,
paged_kv_indices
,
paged_kv_indptr
,
paged_kv_last_page_len
,
block_table_bound
)
# quantization ops
...
...
vllm/attention/backends/abstract.py
View file @
a6c0f365
...
...
@@ -83,7 +83,9 @@ class AttentionBackend(ABC):
)
->
None
:
raise
NotImplementedError
def
advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
):
def
advance_step
(
self
,
model_input
:
"ModelRunnerInputBase"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
)
->
None
:
raise
NotImplementedError
...
...
vllm/attention/backends/flash_attn.py
View file @
a6c0f365
...
...
@@ -380,7 +380,7 @@ class FlashAttentionMetadata(AttentionMetadata):
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
ops
.
advance_step
(
num_seqs
=
num_seqs
,
ops
.
advance_step
_flashattn
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
...
...
vllm/attention/backends/flashinfer.py
View file @
a6c0f365
...
...
@@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
class
FlashInferBackend
(
AttentionBackend
):
...
...
@@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata):
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# used for GPU in-place advance_step
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
block_table_bound
:
Optional
[
torch
.
Tensor
]
=
None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
...
...
@@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata):
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
assert
self
.
block_table_bound
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
batch_size
=
self
.
query_start_loc
.
shape
[
0
]
-
1
assert
batch_size
>=
0
# We will use flash attention for profiling to
...
...
@@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata):
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
self
.
block_table_bound
=
self
.
block_table_bound
.
to
(
self
.
device
)
self
.
seq_lens_tensor
=
self
.
seq_lens_tensor
.
to
(
self
.
device
)
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
prefill_wrapper
.
end_forward
()
self
.
prefill_wrapper
.
begin_forward
(
...
...
@@ -335,7 +344,6 @@ class FlashInferMetadata(AttentionMetadata):
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
)
else
:
if
not
self
.
use_cuda_graph
:
assert
self
.
paged_kv_indices
is
not
None
assert
self
.
paged_kv_indptr
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
...
...
@@ -343,6 +351,11 @@ class FlashInferMetadata(AttentionMetadata):
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
# handle model warmup path
if
self
.
block_table_bound
is
not
None
:
self
.
block_table_bound
=
self
.
block_table_bound
.
to
(
self
.
device
)
if
self
.
seq_lens_tensor
is
not
None
:
self
.
seq_lens_tensor
=
self
.
seq_lens_tensor
.
to
(
self
.
device
)
assert
self
.
decode_wrapper
is
not
None
self
.
decode_wrapper
.
end_forward
()
...
...
@@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata):
return
self
def
advance_step
(
self
,
model_input
:
"ModelInputForGPUWithSamplingMetadata"
,
sampled_token_ids
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
num_seqs
:
int
,
num_queries
:
int
,
):
"""
Update metadata in-place to advance one decode step.
"""
assert
num_seqs
>
0
assert
num_queries
>
0
assert
model_input
.
attn_metadata
is
not
None
assert
sampled_token_ids
is
not
None
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if
num_seqs
!=
num_queries
:
assert
num_seqs
>
num_queries
assert
self
.
use_cuda_graph
model_input
.
input_tokens
[:
num_queries
]
=
sampled_token_ids
.
flatten
()
# Update GPU tensors
ops
.
advance_step_flashinfer
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
model_input
.
input_tokens
,
input_positions
=
model_input
.
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
,
paged_kv_indices
=
self
.
paged_kv_indices
,
paged_kv_indptr
=
self
.
paged_kv_indptr
,
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
,
block_table_bound
=
self
.
block_table_bound
)
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
...
...
@@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
paged_kv_indptr
:
List
[
int
]
=
[
0
]
# paged_kv_last_page_len is the length of the last page of each request
self
.
paged_kv_last_page_len
:
List
[
int
]
=
[]
self
.
total_blocks
=
0
self
.
is_profile_run
:
bool
=
False
def
_add_seq_group
(
...
...
@@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
self
.
total_blocks
+=
len
(
block_table
)
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
...
...
@@ -583,6 +639,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
out
=
query_start_loc
[
1
:])
if
len
(
self
.
paged_kv_indptr
)
>
0
:
# extend to the maximum number of blocks as returned by the
# scheduler
self
.
paged_kv_indices
.
extend
(
[
0
]
*
(
self
.
total_blocks
-
len
(
self
.
paged_kv_indices
)))
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
device
=
"cpu"
,
dtype
=
torch
.
int
)
...
...
@@ -591,10 +651,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
torch
.
int
)
paged_kv_last_page_len_tensor
=
torch
.
tensor
(
self
.
paged_kv_last_page_len
,
device
=
"cpu"
,
dtype
=
torch
.
int
)
block_table_bound_tensor
=
torch
.
zeros
(
len
(
self
.
paged_kv_indptr
)
-
1
,
device
=
"cpu"
,
dtype
=
torch
.
int
)
else
:
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_len_tensor
=
None
block_table_bound_tensor
=
None
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
...
...
@@ -613,6 +678,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor
,
block_table_bound
=
block_table_bound_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
num_qo_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
),
num_kv_heads
=
self
.
runner
.
model_config
.
get_num_kv_heads
(
...
...
vllm/worker/multi_step_model_runner.py
View file @
a6c0f365
...
...
@@ -4,13 +4,6 @@ from dataclasses import dataclass, field
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
)
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
except
ModuleNotFoundError
:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from
vllm.attention.backends.rocm_flash_attn
import
(
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
import
torch
from
vllm.distributed
import
get_pp_group
...
...
@@ -36,6 +29,8 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
MULTI_STEP_ATTENTION_BACKENDS
=
[
"flash-attn"
,
"flashinfer"
]
def
seq_output_builder
():
return
SequenceOutput
(
...
...
@@ -489,27 +484,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
def
_advance_step
(
self
,
model_input
:
StatefulModelInput
,
out
:
SamplerOutput
)
->
StatefulModelInput
:
frozen_model_input
=
model_input
.
frozen_model_input
assert
frozen_model_input
is
not
None
assert
frozen_model_input
.
attn_metadata
is
not
None
if
self
.
attn_backend
.
get_name
()
not
in
MULTI_STEP_ATTENTION_BACKENDS
:
raise
ValueError
(
f
"Multi-step not supported for attention backend: "
f
"
{
self
.
attn_backend
.
get_name
()
}
. Set VLLM_ATTENTION_BACKEND "
f
"to a value from
{
MULTI_STEP_ATTENTION_BACKENDS
}
."
)
sampled_token_ids
=
model_input
.
cached_outputs
[
-
1
].
sampled_token_ids
num_seqs
=
model_input
.
num_seqs
num_queries
=
model_input
.
num_queries
assert
num_seqs
>
0
assert
num_queries
>
0
assert
num_seqs
>=
num_queries
frozen_model_input
=
model_input
.
frozen_model_input
assert
frozen_model_input
is
not
None
attn_metadata
=
frozen_model_input
.
attn_metadata
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
assert
attn_metadata
is
not
None
attn_metadata
.
advance_step
(
frozen_model_input
,
model_input
.
cached_outputs
[
-
1
].
sampled_token_ids
,
self
.
block_size
,
num_seqs
,
num_queries
)
if
frozen_model_input
.
seq_lens
is
not
None
:
for
i
in
range
(
num_queries
):
frozen_model_input
.
seq_lens
[
i
]
=
attn_metadata
.
seq_lens
[
i
]
sampled_token_ids
,
self
.
block_size
,
num_seqs
,
num_queries
,
)
return
model_input
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment