Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e76466dd
Unverified
Commit
e76466dd
authored
Jul 17, 2024
by
Alexander Matveev
Committed by
GitHub
Jul 17, 2024
Browse files
[Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step (#6338)
parent
5f0b9933
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
568 additions
and
130 deletions
+568
-130
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/ops.h
csrc/ops.h
+5
-0
csrc/prepare_inputs/advance_step.cu
csrc/prepare_inputs/advance_step.cu
+131
-0
csrc/prepare_inputs/advance_step.cuh
csrc/prepare_inputs/advance_step.cuh
+19
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-0
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+1
-0
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+48
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+12
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+100
-47
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+10
-0
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+225
-80
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+12
-3
No files found.
CMakeLists.txt
View file @
e76466dd
...
...
@@ -151,6 +151,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
...
...
csrc/ops.h
View file @
e76466dd
...
...
@@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void
gelu_quick
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
advance_step
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
);
#ifndef USE_ROCM
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
...
...
csrc/prepare_inputs/advance_step.cu
0 → 100644
View file @
e76466dd
/*
* The goal of this GPU kernel is to advance input tensors on the GPU directly
* PR: https://github.com/vllm-project/vllm/pull/6338
* Current restrictions:
* 1. Specialized for DraftModelRunner
* 2. Supports flash_attn only
*/
#include "advance_step.cuh"
namespace
prepare_inputs
{
//
template
<
int
const
num_threads
>
__global__
void
advance_step_kernel
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
long
*
input_tokens_ptr
,
long
const
*
sampled_token_ids_ptr
,
long
*
input_positions_ptr
,
int
*
seq_lens_ptr
,
long
*
slot_mapping_ptr
,
int
const
*
block_tables_ptr
,
int64_t
const
block_tables_stride
)
{
int
num_query_blocks
=
div_ceil
(
num_queries
,
num_threads
);
if
(
blockIdx
.
x
>=
num_query_blocks
)
{
return
;
}
int
cur_query_id
=
blockIdx
.
x
*
num_threads
+
threadIdx
.
x
;
if
(
cur_query_id
>=
num_queries
)
{
return
;
}
// Update input_tokens
input_tokens_ptr
[
cur_query_id
]
=
sampled_token_ids_ptr
[
cur_query_id
];
int
seq_len
=
seq_lens_ptr
[
cur_query_id
];
int
next_seq_len
=
seq_len
+
1
;
int
next_input_pos
=
next_seq_len
-
1
;
// Update seq_lens
seq_lens_ptr
[
cur_query_id
]
=
next_seq_len
;
// Update input_positions
input_positions_ptr
[
cur_query_id
]
=
next_input_pos
;
int
const
*
seq_block_tables_ptr
=
block_tables_ptr
+
block_tables_stride
*
cur_query_id
;
int
block_index
=
next_input_pos
/
block_size
;
int
block_offset
=
next_input_pos
%
block_size
;
int
slot_num
=
seq_block_tables_ptr
[
block_index
]
*
block_size
+
block_offset
;
// Update slot_mapping
slot_mapping_ptr
[
cur_query_id
]
=
slot_num
;
}
inline
void
verify_tensor
(
std
::
string
const
&
name
,
torch
::
Tensor
&
t
,
int64_t
const
size_0
,
int64_t
const
size_1
,
c10
::
ScalarType
const
type
)
{
bool
size_0_cond
=
true
;
if
(
size_0
!=
-
1
)
{
size_0_cond
=
t
.
size
(
0
)
==
size_0
;
}
bool
size_1_cond
=
true
;
if
(
size_1
!=
-
1
)
{
size_1_cond
=
t
.
size
(
1
)
==
size_1
;
}
bool
is_contiguous
=
t
.
is_contiguous
();
bool
same_type
=
t
.
dtype
()
==
type
;
bool
pass
=
size_0_cond
&&
size_1_cond
&&
is_contiguous
&&
same_type
;
if
(
!
pass
)
{
TORCH_CHECK
(
false
,
"tensor: name = "
,
name
,
", shape = "
,
t
.
sizes
(),
" is_cont = "
,
t
.
is_contiguous
(),
", type = "
,
t
.
dtype
(),
" is not as expected: shape = ["
,
size_0
,
", "
,
size_1
,
"], type = "
,
type
);
}
}
void
advance_step
(
int
num_seqs
,
int
num_queries
,
int
block_size
,
torch
::
Tensor
&
input_tokens
,
// type: long
torch
::
Tensor
&
sampled_token_ids
,
// type: long
torch
::
Tensor
&
input_positions
,
// type: long
torch
::
Tensor
&
seq_lens
,
// type: int
torch
::
Tensor
&
slot_mapping
,
// type: long
torch
::
Tensor
&
block_tables
)
{
// type: int
if
(
logging
)
{
printf
(
"advance_step:
\n
"
);
printf
(
" num_seqs = %d
\n
"
,
num_seqs
);
printf
(
" num_queries = %d
\n
"
,
num_queries
);
printf
(
" block_size = %d
\n
"
,
block_size
);
}
// Verify all tensors
verify_tensor
(
"input_tokens"
,
input_tokens
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"sampled_token_ids"
,
sampled_token_ids
,
num_queries
,
1
,
at
::
kLong
);
verify_tensor
(
"input_positions"
,
input_positions
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"seq_lens"
,
seq_lens
,
num_seqs
,
-
1
,
at
::
kInt
);
verify_tensor
(
"slot_mapping"
,
slot_mapping
,
num_seqs
,
-
1
,
at
::
kLong
);
verify_tensor
(
"block_tables"
,
block_tables
,
num_seqs
,
-
1
,
at
::
kInt
);
int
dev
=
sampled_token_ids
.
get_device
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
dev
);
int
blocks
;
cudaDeviceGetAttribute
(
&
blocks
,
cudaDevAttrMultiProcessorCount
,
dev
);
advance_step_kernel
<
max_threads
><<<
blocks
,
max_threads
,
0
,
stream
>>>
(
num_seqs
,
num_queries
,
block_size
,
reinterpret_cast
<
long
*>
(
input_tokens
.
data_ptr
()),
reinterpret_cast
<
long
const
*>
(
sampled_token_ids
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
input_positions
.
data_ptr
()),
reinterpret_cast
<
int
*>
(
seq_lens
.
data_ptr
()),
reinterpret_cast
<
long
*>
(
slot_mapping
.
data_ptr
()),
reinterpret_cast
<
int
const
*>
(
block_tables
.
data_ptr
()),
block_tables
.
stride
(
0
));
}
}
// namespace prepare_inputs
void
advance_step
(
int64_t
num_seqs
,
int64_t
num_queries
,
int64_t
block_size
,
torch
::
Tensor
&
input_tokens
,
torch
::
Tensor
&
sampled_token_ids
,
torch
::
Tensor
&
input_positions
,
torch
::
Tensor
&
seq_lens
,
torch
::
Tensor
&
slot_mapping
,
torch
::
Tensor
&
block_tables
)
{
prepare_inputs
::
advance_step
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
);
}
\ No newline at end of file
csrc/prepare_inputs/advance_step.cuh
0 → 100644
View file @
e76466dd
#pragma once
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
namespace
prepare_inputs
{
static
constexpr
int
max_threads
=
256
;
static
constexpr
bool
logging
=
false
;
constexpr
int
div_ceil
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
}
// namespace prepare_inputs
csrc/torch_bindings.cpp
View file @
e76466dd
...
...
@@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gelu_quick(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"gelu_quick"
,
torch
::
kCUDA
,
&
gelu_quick
);
// prepare_inputs advance_step
ops
.
def
(
"advance_step"
,
&
advance_step
);
ops
.
impl
(
"advance_step"
,
torch
::
kCUDA
,
&
advance_step
);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
...
...
tests/spec_decode/e2e/conftest.py
View file @
e76466dd
...
...
@@ -227,6 +227,7 @@ def get_output_from_llm_generator(
maybe_assert_ngram_worker
(
llm
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
token_ids
=
[
output
.
outputs
[
0
].
token_ids
for
output
in
outputs
]
tokens
=
[
output
.
outputs
[
0
].
text
for
output
in
outputs
]
...
...
tests/spec_decode/test_multi_step_worker.py
View file @
e76466dd
...
...
@@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k():
assert
proposals
.
proposal_lens
.
tolist
()
==
[
k
for
_
in
range
(
expected_num_proposal_seqs
-
1
)
]
+
[
0
for
_
in
range
(
expected_num_no_proposal_seqs
)]
+
[
k
]
@
torch
.
inference_mode
()
def
test_use_draft_model_runner_advance_step
():
"""Verify that draft model runner triggers advance step
when applicable.
"""
seed
=
100
model_name
=
'JackFram/llama-68m'
k
=
5
batch_size
=
32
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
worker
=
create_worker
(
MultiStepWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
# Mock "_gpu_advance_step" to raise an exception when called.
exception_secret
=
"artificial stop"
worker
.
model_runner
.
_gpu_advance_step
=
MagicMock
()
worker
.
model_runner
.
_gpu_advance_step
.
side_effect
=
ValueError
(
exception_secret
)
seq_group_metadata_list
,
_
,
_
=
create_batch
(
batch_size
,
k
)
# Fallback (should not call) when num_steps=1.
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
,
num_steps
=
1
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# Expect exception if _gpu_advance_step is called.
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
num_lookahead_slots
=
k
,
num_steps
=
k
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
call_args_list
=
worker
.
model_runner
.
_gpu_advance_step
.
call_args_list
assert
len
(
call_args_list
)
==
1
vllm/_custom_ops.py
View file @
e76466dd
...
...
@@ -166,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
def
advance_step
(
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
)
->
None
:
"""Advance a step on GPU for existing inputs for a multi-step runner"""
return
torch
.
ops
.
_C
.
advance_step
(
num_seqs
,
num_queries
,
block_size
,
input_tokens
,
sampled_token_ids
,
input_positions
,
seq_lens
,
slot_mapping
,
block_tables
)
# quantization ops
# awq
def
awq_dequantize
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/sampler.py
View file @
e76466dd
...
...
@@ -47,6 +47,32 @@ class Sampler(nn.Module):
# speculative decoding.
self
.
include_gpu_probs_tensor
=
False
def
_init_sampling_tensors
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
):
"""The goal here is to reuse sampling tensors between similar decode
runs. This is possible because sampling logic does not change between
decodes of the same sequences.
"""
_
,
vocab_size
=
logits
.
shape
# First free any existing stored sampling tensors.
# This is necessary because some sampling tensors may
# have pinned memory.
self
.
_sampling_tensors
=
None
# Initialize new sampling tensors
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
=
SamplingTensors
.
from_sampling_metadata
(
sampling_metadata
,
vocab_size
,
logits
.
device
,
logits
.
dtype
)
self
.
_sampling_tensors
=
sampling_tensors
self
.
_do_penalties
=
do_penalties
self
.
_do_top_p_top_k
=
do_top_p_top_k
self
.
_do_min_p
=
do_min_p
def
forward
(
self
,
logits
:
torch
.
Tensor
,
...
...
@@ -60,12 +86,23 @@ class Sampler(nn.Module):
assert
logits
is
not
None
_
,
vocab_size
=
logits
.
shape
logits
=
_apply_min_tokens_penalty
(
logits
,
sampling_metadata
)
# Prepare sampling tensors with pinned memory to avoid blocking.
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
=
SamplingTensors
.
from_sampling_metadata
(
sampling_metadata
,
vocab_size
,
logits
.
device
,
logits
.
dtype
)
if
not
sampling_metadata
.
reuse_sampling_tensors
:
self
.
_init_sampling_tensors
(
logits
,
sampling_metadata
)
elif
self
.
_do_penalties
:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self
.
_init_sampling_tensors
(
logits
,
sampling_metadata
)
assert
self
.
_sampling_tensors
is
not
None
sampling_tensors
=
self
.
_sampling_tensors
do_penalties
=
self
.
_do_penalties
do_top_p_top_k
=
self
.
_do_top_p_top_k
do_min_p
=
self
.
_do_min_p
logits
=
_apply_min_tokens_penalty
(
logits
,
sampling_metadata
)
# Apply presence and frequency penalties.
if
do_penalties
:
...
...
@@ -77,7 +114,7 @@ class Sampler(nn.Module):
# Apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits
.
div_
(
sampling_tensors
.
temperatures
.
unsqueeze
_
(
dim
=
1
))
logits
.
div_
(
sampling_tensors
.
temperatures
.
unsqueeze
(
dim
=
1
))
if
do_top_p_top_k
:
logits
=
_apply_top_k_top_p
(
logits
,
sampling_tensors
.
top_ps
,
...
...
@@ -109,13 +146,19 @@ class Sampler(nn.Module):
on_device_tensors
=
None
# Get the logprobs query results.
prompt_logprobs
,
sample_logprobs
=
_get_logprobs
(
logprobs
,
sampling_metadata
,
sample_results
)
return
_build_sampler_output
(
sample_results
,
sampling_metadata
,
prompt_logprobs
,
sample_logprobs
,
on_device_tensors
=
on_device_tensors
)
prompt_logprobs
=
None
sample_logprobs
=
None
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
prompt_logprobs
,
sample_logprobs
=
_get_logprobs
(
logprobs
,
sampling_metadata
,
sample_results
)
return
_build_sampler_output
(
sample_results
,
sampling_metadata
,
prompt_logprobs
,
sample_logprobs
,
on_device_tensors
=
on_device_tensors
,
skip_sampler_cpu_output
=
sampling_metadata
.
skip_sampler_cpu_output
)
@
property
def
_should_modify_greedy_probs_inplace
(
self
)
->
bool
:
...
...
@@ -535,24 +578,29 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
continue
(
seq_group_id
,
seq_groups
)
=
sample_metadata
[
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
if
not
sampling_metadata
.
skip_sampler_cpu_output
:
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
continue
(
seq_group_id
,
seq_groups
)
=
sample_metadata
[
sampling_type
]
if
sampling_type
==
SamplingType
.
GREEDY
:
sample_results
=
_greedy_sample
(
seq_groups
,
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
sample_results
=
_random_sample
(
seq_groups
,
multinomial_samples
[
sampling_type
])
elif
sampling_type
==
SamplingType
.
BEAM
:
sample_results
=
_beam_search_sample
(
seq_groups
,
beam_search_logprobs
)
sample_results_dict
.
update
(
zip
(
seq_group_id
,
sample_results
))
sample_results
=
[
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
else
:
sample_results
=
[]
sample_results
=
[
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
return
sample_results
,
sampled_token_ids_tensor
...
...
@@ -997,10 +1045,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def
_build_sampler_output
(
sample_results
:
SampleResultType
,
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]],
sample_logprobs
:
List
[
SampleLogprobs
],
prompt_logprobs
:
Optional
[
List
[
Optional
[
PromptLogprobs
]]
]
,
sample_logprobs
:
Optional
[
List
[
SampleLogprobs
]
]
,
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]],
skip_sampler_cpu_output
:
bool
=
False
,
)
->
SamplerOutput
:
"""Construct Python objects with the output of sampling.
...
...
@@ -1010,22 +1059,26 @@ def _build_sampler_output(
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output
:
List
[
CompletionSequenceGroupOutput
]
=
[]
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
sample_results
,
prompt_logprobs
,
sample_logprobs
):
seq_ids
=
seq_group
.
seq_ids
next_token_ids
,
parent_ids
=
sample_result
seq_outputs
:
List
[
SequenceOutput
]
=
[]
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
next_token_ids
,
group_sample_logprobs
):
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
sampler_output
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
if
not
skip_sampler_cpu_output
:
assert
prompt_logprobs
is
not
None
assert
sample_logprobs
is
not
None
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
sample_results
,
prompt_logprobs
,
sample_logprobs
):
seq_ids
=
seq_group
.
seq_ids
next_token_ids
,
parent_ids
=
sample_result
seq_outputs
:
List
[
SequenceOutput
]
=
[]
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
next_token_ids
,
group_sample_logprobs
):
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
sampler_output
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
...
...
vllm/model_executor/sampling_metadata.py
View file @
e76466dd
...
...
@@ -87,6 +87,12 @@ class SamplingMetadata:
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
serialization of token outputs.
reuse_sampling_tensors: Indicates if we want to reuse sampling
tensors that are part of the sampler forward pass. Currently,
it is mainly used for multi-step decode.
"""
def
__init__
(
...
...
@@ -95,11 +101,15 @@ class SamplingMetadata:
selected_token_indices
:
torch
.
Tensor
,
categorized_sample_indices
:
Dict
[
SamplingType
,
torch
.
Tensor
],
num_prompts
:
int
,
skip_sampler_cpu_output
:
bool
=
False
,
reuse_sampling_tensors
:
bool
=
False
,
)
->
None
:
self
.
seq_groups
=
seq_groups
self
.
selected_token_indices
=
selected_token_indices
self
.
categorized_sample_indices
=
categorized_sample_indices
self
.
num_prompts
=
num_prompts
self
.
skip_sampler_cpu_output
=
skip_sampler_cpu_output
self
.
reuse_sampling_tensors
=
reuse_sampling_tensors
@
staticmethod
def
prepare
(
...
...
vllm/spec_decode/draft_model_runner.py
View file @
e76466dd
...
...
@@ -2,17 +2,22 @@ from typing import List, Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
S
equenceGroupMetadata
)
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
S
amplerOutput
)
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
ModelRunner
)
logger
=
init_logger
(
__name__
)
debug_advance_input
=
False
enable_gpu_advance_step
=
True
class
TP1DraftModelRunner
(
ModelRunner
):
"""Specialized model runner for speculative decoding draft model.
...
...
@@ -21,18 +26,9 @@ class TP1DraftModelRunner(ModelRunner):
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
This runner is still under development so there's no performance gain
at this moment. Currently we adopt a temporary solution that caches the
seq_group_metadata_list for multi-step execution, so that we can
leverage existing prepare_model_input to be compatible with the current
execution flow, but we plan to remove this cache and avoid calling
prepare_model_input in execute_model at all.
The detail development plan includes:
1. Use "update_model_input" to update existing model_input without
creating a new one.
2. Improve the performance of "update_model_input" with a GPU kernel.
3. Support TP > 1 (this requires some designs because we do not expect
TODOs:
1. Currently supports only flash-attn, add support for other attn_backends.
2. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""
...
...
@@ -71,51 +67,156 @@ class TP1DraftModelRunner(ModelRunner):
return_hidden_states
=
return_hidden_states
,
)
# TODO: Remove this cache when we are able to update model_input
# directly in advance_step.
self
.
cached_seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
def
_update_flash_attn_metadata
(
self
,
attn_metadata
,
num_seqs
,
num_queries
):
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForGPUWithSamplingMetadata
:
"""A temporary solution that caches the seq_group_metadata_list
for multi-step execution.
TODO: In-place update model_input and remove this function.
"""
self
.
cached_seq_group_metadata_list
=
seq_group_metadata_list
return
super
().
prepare_model_input
(
seq_group_metadata_list
,
finished_requests_ids
=
finished_requests_ids
)
if
num_seqs
!=
num_queries
:
assert
num_seqs
>
num_queries
assert
attn_metadata
.
use_cuda_graph
assert
attn_metadata
.
num_prefills
==
0
assert
attn_metadata
.
num_prefill_tokens
==
0
assert
attn_metadata
.
num_decode_tokens
==
num_seqs
assert
attn_metadata
.
slot_mapping
.
shape
==
(
num_seqs
,
)
assert
len
(
attn_metadata
.
seq_lens
)
==
num_seqs
assert
attn_metadata
.
seq_lens_tensor
.
shape
==
(
num_seqs
,
)
assert
attn_metadata
.
max_query_len
==
1
assert
attn_metadata
.
max_prefill_seq_len
==
0
assert
attn_metadata
.
max_decode_seq_len
==
max
(
attn_metadata
.
seq_lens
)
assert
attn_metadata
.
query_start_loc
.
shape
==
(
num_queries
+
1
,
)
assert
attn_metadata
.
seq_start_loc
.
shape
==
(
num_seqs
+
1
,
)
assert
attn_metadata
.
context_lens_tensor
.
shape
==
(
num_queries
,
)
assert
attn_metadata
.
block_tables
.
shape
[
0
]
==
num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for
i
in
range
(
num_queries
):
attn_metadata
.
seq_lens
[
i
]
+=
1
attn_metadata
.
max_decode_seq_len
=
max
(
attn_metadata
.
seq_lens
)
def
update_model_input
(
def
_update_sampling_metadata
(
self
,
sampling_metadata
,
num_seqs
,
num_queries
):
assert
sampling_metadata
.
num_prompts
==
0
assert
len
(
sampling_metadata
.
seq_groups
)
==
num_queries
assert
sampling_metadata
.
selected_token_indices
.
shape
==
(
num_queries
,
)
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for
i
in
range
(
num_queries
):
seq_group
=
sampling_metadata
.
seq_groups
[
i
]
assert
seq_group
.
is_prompt
is
False
# No prompt
assert
seq_group
.
prompt_logprob_indices
==
[]
# No prompt
assert
seq_group
.
sample_indices
==
[
i
]
# Simple
assert
seq_group
.
seq_len
is
None
# Decode
assert
seq_group
.
query_len
is
None
# Decode
def
_gpu_advance_step
(
self
,
model_input
:
ModelInputForGPUWithSamplingMetadata
,
last_output
:
SamplerOutput
)
->
ModelInputForGPUWithSamplingMetadata
:
"""Prepare the model inputs for the next step.
TODO: In-place update model_input instead of calling
prepare_model_input.
# Currently, we expect "decode mode" only
assert
not
model_input
.
is_prompt
# Get num_seqs
num_seqs
=
len
(
model_input
.
seq_lens
)
num_queries
=
len
(
model_input
.
query_lens
)
# Get output tokens GPU tensor
sampled_token_ids
=
last_output
.
sampled_token_ids
assert
sampled_token_ids
is
not
None
# Update attn_metadata
attn_metadata
=
model_input
.
attn_metadata
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
self
.
_update_flash_attn_metadata
(
attn_metadata
,
num_seqs
,
num_queries
)
# Update GPU tensors
ops
.
advance_step
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
self
.
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
seq_lens
=
attn_metadata
.
seq_lens_tensor
,
slot_mapping
=
attn_metadata
.
slot_mapping
,
block_tables
=
attn_metadata
.
block_tables
)
# Update sampling_metadata
sampling_metadata
=
model_input
.
sampling_metadata
self
.
_update_sampling_metadata
(
sampling_metadata
,
num_seqs
,
num_queries
)
# Create new input
new_model_input
=
self
.
_model_input_cls
(
input_tokens
=
model_input
.
input_tokens
,
input_positions
=
model_input
.
input_positions
,
attn_metadata
=
attn_metadata
,
seq_lens
=
attn_metadata
.
seq_lens
,
query_lens
=
model_input
.
query_lens
,
lora_mapping
=
model_input
.
lora_mapping
,
lora_requests
=
model_input
.
lora_requests
,
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
,
sampling_metadata
=
model_input
.
sampling_metadata
,
is_prompt
=
False
,
)
# Ensure we skip CPU samples
assert
new_model_input
.
sampling_metadata
.
skip_sampler_cpu_output
is
True
# We can reuse sampling tensors since every decode iteration is the same
new_model_input
.
sampling_metadata
.
reuse_sampling_tensors
=
True
if
debug_advance_input
:
logger
.
debug
(
"NEW INPUT: "
)
logger
.
debug
(
" input_tokens = %s"
,
new_model_input
.
input_tokens
)
logger
.
debug
(
" input_positions = %s"
,
new_model_input
.
input_positions
)
logger
.
debug
(
" seq_lens = %d"
,
new_model_input
.
seq_lens
)
logger
.
debug
(
" query_lens = %d"
,
new_model_input
.
query_lens
)
logger
.
debug
(
" attn_metadata:"
)
logger
.
debug
(
" seq_lens_tensor: %s"
,
attn_metadata
.
seq_lens_tensor
)
logger
.
debug
(
" slot_mapping: %s"
,
attn_metadata
.
slot_mapping
)
logger
.
debug
(
" block_tables: %s"
,
attn_metadata
.
block_tables
)
return
new_model_input
def
supports_gpu_multi_step
(
self
,
execute_model_req
:
ExecuteModelRequest
):
"""Determines if draft_model_runner GPU multi-step can be used.
Currently required conditions are:
1. Only decodes
2. Only flash-attn
3. No LORA
4. No prompt_adapter_config
"""
if
not
enable_gpu_advance_step
:
return
False
# Append the output token to the sequence data.
assert
self
.
cached_seq_group_metadata_list
is
not
None
for
seq_group_metadata
,
sequence_group_outputs
in
zip
(
self
.
cached_seq_group_metadata_list
,
last_output
.
outputs
):
seq_group_metadata
.
is_prompt
=
False
# We allow multi-step GPU only in decode mode
for
seq_group
in
execute_model_req
.
seq_group_metadata_list
:
if
seq_group
.
is_prompt
:
return
False
for
seq_output
in
sequence_group_outputs
.
samples
:
seq
=
seq_group_metadata
.
seq_data
[
seq_output
.
parent_seq_id
]
# TODO: Add support for other attn backends
if
self
.
attn_backend
.
get_name
()
!=
"flash-attn"
:
return
False
token_id
=
seq_output
.
output_token
token_logprob
=
seq_output
.
logprobs
[
token_id
]
# TODO: Add support for LORA
if
self
.
lora_config
:
return
False
seq
.
append_token_id
(
token_id
,
token_logprob
.
logprob
)
seq
.
update_num_computed_tokens
(
1
)
# TODO: Add soft-tuning prompt adapter support
if
self
.
prompt_adapter_config
:
return
False
return
self
.
prepare_model_input
(
self
.
cached_seq_group_metadata_list
)
return
True
@
torch
.
inference_mode
()
def
execute_model
(
...
...
@@ -125,42 +226,86 @@ class TP1DraftModelRunner(ModelRunner):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if
not
self
.
is_driver_worker
:
raise
ValueError
(
"TP1DraftModelRunner only supports TP=1."
)
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
Optimizations used:
1. Input tensors are updated on the GPU directly
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
them since we do batch expansion later that uses GPU outputs)
3. Reuses sampling tensors (since we run only decodes and they have
a repeating sampling logic)
"""
if
self
.
prompt_adapter_config
:
assert
model_input
.
prompt_adapter_requests
is
not
None
assert
model_input
.
prompt_adapter_mapping
is
not
None
self
.
set_active_prompt_adapters
(
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
# When num_steps == 1, we execute the fallback here for the GPU
# advance_step, which runs prepare_inputs on CPU and for each spec
# iteration invokes this function only once
# (Look at multi-step-worker code)
is_fallback
=
num_steps
==
1
if
not
is_fallback
:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if
not
self
.
is_driver_worker
:
raise
ValueError
(
"TP1DraftModelRunner only supports TP=1."
)
# Sanity
if
self
.
lora_config
is
not
None
:
raise
ValueError
(
"TP1DraftModelRunner has no support for LORA"
)
if
self
.
prompt_adapter_config
is
not
None
:
raise
ValueError
(
"TP1DraftModelRunner has no support for "
"prompt_adapter_config"
)
if
model_input
.
multi_modal_kwargs
:
raise
ValueError
(
"TP1DraftModelRunner has no support for multi_modal_kwargs"
)
else
:
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
if
self
.
prompt_adapter_config
:
assert
model_input
.
prompt_adapter_requests
is
not
None
assert
model_input
.
prompt_adapter_mapping
is
not
None
self
.
set_active_prompt_adapters
(
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
# Detect exec mode
assert
model_input
.
attn_metadata
is
not
None
use_cuda_graph
=
False
if
model_input
.
attn_metadata
.
num_prefills
>
0
:
# In this case, execute_model(..) was called directly
if
num_steps
>
1
:
raise
ValueError
(
"execute_model(..) of draft_model_runner can be called "
"directly only with a single-step prefill"
)
else
:
# We can skip CPU samples for spec token generation.
# (We do allow CPU samples for num_steps == 1 to support the
# fallback case, where supports_gpu_multi_step(..) does not pass)
model_input
.
sampling_metadata
.
skip_sampler_cpu_output
=
(
not
is_fallback
)
# Attn attr defines if we use cuda graphs
use_cuda_graph
=
model_input
.
attn_metadata
.
use_cuda_graph
# Get model
if
use_cuda_graph
:
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_executable
=
(
self
.
graph_runners
[
model_input
.
virtual_engine
]
[
graph_batch_size
])
else
:
model_executable
=
self
.
model
virtual_engine
=
model_input
.
virtual_engine
outputs
:
List
[
SamplerOutput
]
=
[]
for
step
in
range
(
num_steps
):
# Currently cuda graph is only supported by the decode phase.
assert
model_input
.
attn_metadata
is
not
None
prefill_meta
=
model_input
.
attn_metadata
.
prefill_metadata
decode_meta
=
model_input
.
attn_metadata
.
decode_metadata
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
:
assert
model_input
.
input_tokens
is
not
None
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_executable
=
(
self
.
graph_runners
[
virtual_engine
][
graph_batch_size
])
else
:
model_executable
=
self
.
model
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
# Run model
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
...
...
@@ -181,8 +326,8 @@ class TP1DraftModelRunner(ModelRunner):
sampling_metadata
=
model_input
.
sampling_metadata
,
))
# Prepare
the
inputs for the next step
.
# Prepare inputs for the next step
if
step
!=
num_steps
-
1
:
model_input
=
self
.
update_model_input
(
model_input
,
outputs
[
-
1
])
model_input
=
self
.
_gpu_advance_step
(
model_input
,
outputs
[
-
1
])
return
outputs
vllm/spec_decode/multi_step_worker.py
View file @
e76466dd
...
...
@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
expanded_request
,
indices_of_seq_with_bonus_tokens
=
\
self
.
_expand_execute_model_request
(
execute_model_req
,
seq_ids_with_bonus_token_in_last_step
)
# Run model sample_len times.
model_outputs
:
List
[
SamplerOutput
]
=
[]
if
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
):
if
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
)
and
self
.
model_runner
.
supports_gpu_multi_step
(
expanded_request
):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request
.
num_steps
=
sample_len
model_outputs
=
self
.
execute_model
(
execute_model_req
=
expanded_request
)
else
:
# TODO: Remove this branch once DraftModelRunner supports TP>1.
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for
_
in
range
(
sample_len
):
model_output
:
List
[
SamplerOutput
]
=
super
().
execute_model
(
execute_model_req
=
expanded_request
)
...
...
@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
outputs
=
[
expanded_batch_output
.
outputs
[
i
]
for
i
in
output_indices_to_retain
],
]
if
len
(
expanded_batch_output
.
outputs
)
>
0
else
[]
,
sampled_token_probs
=
(
expanded_batch_output
.
sampled_token_probs
[
output_indices_to_retain
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment