Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
025a32f9
Unverified
Commit
025a32f9
authored
Jan 11, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 11, 2026
Browse files
[Model Runner V2] Remove async barrier (#32083)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
19504ac0
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
590 additions
and
462 deletions
+590
-462
vllm/v1/worker/gpu/block_table.py
vllm/v1/worker/gpu/block_table.py
+18
-34
vllm/v1/worker/gpu/buffer_utils.py
vllm/v1/worker/gpu/buffer_utils.py
+218
-0
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+9
-6
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+57
-25
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+108
-111
vllm/v1/worker/gpu/sample/gumbel.py
vllm/v1/worker/gpu/sample/gumbel.py
+12
-7
vllm/v1/worker/gpu/sample/metadata.py
vllm/v1/worker/gpu/sample/metadata.py
+14
-127
vllm/v1/worker/gpu/sample/min_p.py
vllm/v1/worker/gpu/sample/min_p.py
+9
-2
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+7
-7
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+2
-1
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+27
-23
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+51
-85
vllm/v1/worker/gpu/structured_outputs.py
vllm/v1/worker/gpu/structured_outputs.py
+58
-34
No files found.
vllm/v1/worker/gpu/block_table.py
View file @
025a32f9
...
...
@@ -6,9 +6,8 @@ import torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
class
BlockTables
:
...
...
@@ -26,19 +25,16 @@ class BlockTables:
self
.
max_model_len
=
max_model_len
self
.
device
=
device
if
not
is_uva_available
():
raise
RuntimeError
(
"UVA is not available"
)
self
.
num_kv_cache_groups
=
len
(
self
.
block_sizes
)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self
.
block_tables
:
list
[
UvaBuffe
r
]
=
[]
self
.
block_tables
:
list
[
StagedWriteTenso
r
]
=
[]
for
i
in
range
(
self
.
num_kv_cache_groups
):
block_size
=
self
.
block_sizes
[
i
]
max_num_blocks
=
cdiv
(
self
.
max_model_len
,
block_size
)
block_table
=
UvaBuffer
(
self
.
max_num_reqs
,
max_num_blocks
,
block_table
=
StagedWriteTensor
(
(
self
.
max_num_reqs
,
max_num_blocks
),
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
block_tables
.
append
(
block_table
)
self
.
block_table_ptrs
=
self
.
_make_ptr_tensor
(
...
...
@@ -53,9 +49,8 @@ class BlockTables:
self
.
block_sizes_tensor
=
torch
.
tensor
(
self
.
block_sizes
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
num_blocks
=
UvaBuffer
(
self
.
num_kv_cache_groups
,
self
.
max_num_reqs
,
self
.
num_blocks
=
UvaBackedTensor
(
(
self
.
num_kv_cache_groups
,
self
.
max_num_reqs
),
dtype
=
torch
.
int32
,
)
...
...
@@ -75,13 +70,11 @@ class BlockTables:
def
_make_ptr_tensor
(
self
,
x
:
Iterable
[
torch
.
Tensor
])
->
torch
.
Tensor
:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
ptrs_tensor_cpu
=
torch
.
tensor
(
return
torch
.
tensor
(
[
t
.
data_ptr
()
for
t
in
x
],
dtype
=
torch
.
uint64
,
device
=
"cpu"
,
pin_memory
=
True
,
device
=
self
.
device
,
)
return
ptrs_tensor_cpu
.
to
(
self
.
device
,
non_blocking
=
True
)
def
append_block_ids
(
self
,
...
...
@@ -90,19 +83,17 @@ class BlockTables:
overwrite
:
bool
,
)
->
None
:
for
i
in
range
(
self
.
num_kv_cache_groups
):
start
=
self
.
num_blocks
.
np
[
i
,
req_index
]
if
not
overwrite
else
0
block_ids
=
new_block_ids
[
i
]
num_new_blocks
=
len
(
block_ids
)
if
num_new_blocks
==
0
:
continue
self
.
block_tables
[
i
].
stage_write
(
req_index
,
start
,
block_ids
)
self
.
num_blocks
.
np
[
i
,
req_index
]
=
start
+
len
(
block_ids
)
# TODO(woosuk): Too many Numpy invocations. Optimize this.
start
=
self
.
num_blocks
.
np
[
i
,
req_index
]
if
not
overwrite
else
0
end
=
start
+
num_new_blocks
if
num_new_blocks
==
1
:
self
.
block_tables
[
i
].
np
[
req_index
,
start
]
=
block_ids
[
0
]
else
:
self
.
block_tables
[
i
].
np
[
req_index
,
start
:
end
]
=
block_ids
self
.
num_blocks
.
np
[
i
,
req_index
]
=
end
def
apply_staged_writes
(
self
)
->
None
:
# TODO(woosuk): This can be inefficient since it launches one kernel per
# block table. Implement a kernel to handle all block tables at once.
for
block_table
in
self
.
block_tables
:
block_table
.
apply_write
()
self
.
num_blocks
.
copy_to_uva
()
def
gather_block_tables
(
self
,
...
...
@@ -229,10 +220,3 @@ def _load_ptr(ptr_to_ptr, elem_dtype):
ptr
=
tl
.
load
(
ptr_to_ptr
)
ptr
=
tl
.
cast
(
ptr
,
tl
.
pointer_type
(
elem_dtype
))
return
tl
.
multiple_of
(
ptr
,
16
)
class
UvaBuffer
:
def
__init__
(
self
,
*
size
,
dtype
:
torch
.
dtype
):
self
.
cpu
=
torch
.
zeros
(
*
size
,
dtype
=
dtype
,
device
=
"cpu"
,
pin_memory
=
True
)
self
.
np
=
self
.
cpu
.
numpy
()
self
.
gpu
=
get_cuda_view_from_cpu_tensor
(
self
.
cpu
)
vllm/v1/worker/gpu/buffer_utils.py
0 → 100644
View file @
025a32f9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
import
numpy
as
np
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
next_power_of_2
from
vllm.utils.platform_utils
import
is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
class
UvaBuffer
:
def
__init__
(
self
,
size
:
int
|
Sequence
[
int
],
dtype
:
torch
.
dtype
):
if
not
is_uva_available
():
raise
RuntimeError
(
"UVA is not available"
)
self
.
cpu
=
torch
.
zeros
(
size
,
dtype
=
dtype
,
device
=
"cpu"
,
pin_memory
=
True
)
self
.
np
=
self
.
cpu
.
numpy
()
self
.
uva
=
get_cuda_view_from_cpu_tensor
(
self
.
cpu
)
class
UvaBufferPool
:
def
__init__
(
self
,
size
:
int
|
Sequence
[
int
],
dtype
:
torch
.
dtype
,
max_concurrency
:
int
=
2
,
):
self
.
size
=
size
self
.
dtype
=
dtype
self
.
max_concurrency
=
max_concurrency
# UVA buffers for concurrency
self
.
_uva_bufs
=
[
UvaBuffer
(
size
,
dtype
)
for
_
in
range
(
max_concurrency
)]
# Current buffer index
self
.
_curr
=
0
def
copy_to_uva
(
self
,
x
:
torch
.
Tensor
|
np
.
ndarray
|
list
)
->
torch
.
Tensor
:
# Round robin to the next buffer.
self
.
_curr
=
(
self
.
_curr
+
1
)
%
self
.
max_concurrency
buf
=
self
.
_uva_bufs
[
self
.
_curr
]
# CPU-to-CPU copy
dst
=
buf
.
cpu
if
isinstance
(
x
,
torch
.
Tensor
)
else
buf
.
np
n
=
len
(
x
)
dst
[:
n
]
=
x
return
buf
.
uva
[:
n
]
def
copy_to_gpu
(
self
,
x
:
torch
.
Tensor
|
np
.
ndarray
,
out
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
uva
=
self
.
copy_to_uva
(
x
)
if
out
is
None
:
# CPU-to-GPU copy
return
uva
.
clone
()
# CPU-to-GPU copy
return
out
.
copy_
(
uva
,
non_blocking
=
True
)
class
UvaBackedTensor
:
def
__init__
(
self
,
size
:
int
|
Sequence
[
int
],
dtype
:
torch
.
dtype
,
max_concurrency
:
int
=
2
,
):
self
.
dtype
=
dtype
self
.
max_concurrency
=
max_concurrency
# Source of truth
self
.
cpu
=
torch
.
zeros
(
size
,
dtype
=
dtype
,
device
=
"cpu"
,
pin_memory
=
False
)
self
.
np
=
self
.
cpu
.
numpy
()
# Buffers for concurrency
self
.
pool
=
UvaBufferPool
(
size
,
dtype
,
max_concurrency
)
self
.
gpu
=
self
.
pool
.
copy_to_uva
(
self
.
np
)
def
copy_to_uva
(
self
,
n
:
int
|
None
=
None
)
->
torch
.
Tensor
:
# CPU-to-CPU copy
self
.
gpu
=
self
.
pool
.
copy_to_uva
(
self
.
np
[:
n
]
if
n
is
not
None
else
self
.
np
)
return
self
.
gpu
class
StagedWriteTensor
:
def
__init__
(
self
,
size
:
int
|
Sequence
[
int
],
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
max_concurrency
:
int
=
2
,
uva_instead_of_gpu
:
bool
=
False
,
):
if
dtype
not
in
[
torch
.
int32
,
torch
.
int64
]:
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
: should be either int32 or int64"
)
self
.
num_rows
=
size
if
isinstance
(
size
,
int
)
else
size
[
0
]
self
.
dtype
=
dtype
self
.
max_concurrency
=
max_concurrency
if
not
uva_instead_of_gpu
:
# Create a GPU tensor (default)
self
.
gpu
=
torch
.
zeros
(
size
,
dtype
=
dtype
,
device
=
device
)
else
:
# For a large but not-frequently-accessed tensor, we can use UVA instead of
# GPU to save GPU memory
self
.
_uva_buf
=
UvaBuffer
(
size
,
dtype
)
self
.
gpu
=
self
.
_uva_buf
.
uva
self
.
_staged_write_indices
:
list
[
int
]
=
[]
self
.
_staged_write_starts
:
list
[
int
]
=
[]
self
.
_staged_write_contents
:
list
[
int
]
=
[]
self
.
_staged_write_cu_lens
:
list
[
int
]
=
[]
self
.
write_indices
=
UvaBufferPool
(
self
.
num_rows
,
dtype
=
torch
.
int32
,
max_concurrency
=
max_concurrency
)
self
.
write_starts
=
UvaBufferPool
(
self
.
num_rows
,
dtype
=
torch
.
int32
,
max_concurrency
=
max_concurrency
)
init_size
=
next_power_of_2
(
self
.
num_rows
)
self
.
write_contents
=
UvaBufferPool
(
init_size
,
dtype
=
dtype
,
max_concurrency
=
max_concurrency
)
self
.
write_cu_lens
=
UvaBufferPool
(
self
.
num_rows
,
dtype
=
torch
.
int32
,
max_concurrency
=
max_concurrency
)
def
stage_write
(
self
,
index
:
int
,
start
:
int
,
x
:
list
[
int
])
->
None
:
assert
index
>=
0
assert
start
>=
0
if
not
x
:
return
self
.
_staged_write_indices
.
append
(
index
)
self
.
_staged_write_starts
.
append
(
start
)
self
.
_staged_write_contents
.
extend
(
x
)
self
.
_staged_write_cu_lens
.
append
(
len
(
self
.
_staged_write_contents
))
def
stage_write_elem
(
self
,
index
:
int
,
x
:
int
)
->
None
:
assert
index
>=
0
self
.
_staged_write_indices
.
append
(
index
)
self
.
_staged_write_starts
.
append
(
0
)
self
.
_staged_write_contents
.
append
(
x
)
self
.
_staged_write_cu_lens
.
append
(
len
(
self
.
_staged_write_contents
))
def
apply_write
(
self
)
->
None
:
n
=
len
(
self
.
_staged_write_indices
)
if
n
==
0
:
return
indices_uva
=
self
.
write_indices
.
copy_to_uva
(
self
.
_staged_write_indices
)
starts_uva
=
self
.
write_starts
.
copy_to_uva
(
self
.
_staged_write_starts
)
cu_lens_uva
=
self
.
write_cu_lens
.
copy_to_uva
(
self
.
_staged_write_cu_lens
)
# Special handling for write_contents
diff_len
=
len
(
self
.
_staged_write_contents
)
assert
isinstance
(
self
.
write_contents
.
size
,
int
)
if
diff_len
>
self
.
write_contents
.
size
:
# Re-allocate a larger buffer for the write_contents
new_size
=
next_power_of_2
(
diff_len
)
self
.
write_contents
=
UvaBufferPool
(
new_size
,
dtype
=
self
.
dtype
,
max_concurrency
=
self
.
max_concurrency
)
# NOTE(woosuk): Since the previous write_contents buffer is released,
# we perform a synchronization here to ensure that all data transfers
# involving the old buffer have finished before allocating a new one.
# This prevents potential race conditions. The slight overhead is
# negligible because the reallocations are infrequent in practice.
torch
.
cuda
.
synchronize
()
contents_uva
=
self
.
write_contents
.
copy_to_uva
(
self
.
_staged_write_contents
)
# Write diffs to the GPU buffer
_apply_write_kernel
[(
n
,)](
self
.
gpu
,
self
.
gpu
.
stride
(
0
),
indices_uva
,
starts_uva
,
contents_uva
,
cu_lens_uva
,
BLOCK_SIZE
=
1024
,
)
# Clear the staged writes
self
.
clear_staged_writes
()
def
clear_staged_writes
(
self
)
->
None
:
self
.
_staged_write_indices
.
clear
()
self
.
_staged_write_starts
.
clear
()
self
.
_staged_write_contents
.
clear
()
self
.
_staged_write_cu_lens
.
clear
()
@
triton
.
jit
def
_apply_write_kernel
(
output_ptr
,
output_stride
,
write_indices_ptr
,
write_starts_ptr
,
write_contents_ptr
,
write_cu_lens_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
row_idx
=
tl
.
load
(
write_indices_ptr
+
pid
)
start_idx
=
tl
.
load
(
write_starts_ptr
+
pid
)
cu_start
=
tl
.
load
(
write_cu_lens_ptr
+
pid
-
1
)
if
pid
>
0
else
0
cu_end
=
tl
.
load
(
write_cu_lens_ptr
+
pid
)
content_len
=
cu_end
-
cu_start
for
i
in
range
(
0
,
content_len
,
BLOCK_SIZE
):
block
=
i
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
content_len
content
=
tl
.
load
(
write_contents_ptr
+
cu_start
+
block
,
mask
=
mask
)
tl
.
store
(
output_ptr
+
row_idx
*
output_stride
+
start_idx
+
block
,
content
,
mask
=
mask
)
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
025a32f9
...
...
@@ -228,10 +228,13 @@ def prepare_inputs_to_capture(
kv_cache_config
:
KVCacheConfig
,
)
->
dict
[
str
,
Any
]:
num_tokens_per_req
=
num_tokens
//
num_reqs
query_start_loc
=
input_buffers
.
query_start_loc
query_start_loc
.
np
[:
num_reqs
+
1
]
=
np
.
arange
(
num_reqs
+
1
)
*
num_tokens_per_req
query_start_loc
.
np
[
num_reqs
:]
=
num_tokens
query_start_loc
.
copy_to_gpu
()
query_start_loc_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
*
num_tokens_per_req
query_start_loc_np
[
-
1
]
=
num_tokens
query_start_loc_cpu
=
torch
.
from_numpy
(
query_start_loc_np
)
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
=
query_start_loc_cpu
input_buffers
.
query_start_loc
[
num_reqs
+
1
:]
=
num_tokens
query_start_loc
=
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len.
...
...
@@ -245,8 +248,8 @@ def prepare_inputs_to_capture(
attn_metadata_builders
=
attn_metadata_builders
,
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
.
gpu
[:
num_reqs
+
1
]
,
query_start_loc_cpu
=
query_start_loc
.
cpu
[:
num_reqs
+
1
]
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc
_
cpu
,
seq_lens
=
input_buffers
.
seq_lens
,
max_seq_len
=
max_model_len
,
block_tables
=
input_block_tables
,
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
025a32f9
...
...
@@ -8,8 +8,6 @@ import torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
random_uuid
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.utils
import
CpuGpuBuffer
class
InputBuffers
:
...
...
@@ -21,30 +19,17 @@ class InputBuffers:
vocab_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
pin_memory
:
bool
,
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_tokens
=
max_num_tokens
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
idx_mapping
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
input_ids
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
query_start_loc
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
cu_num_logits
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
# Structured outputs.
self
.
bitmask_indices
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
grammar_bitmask
=
self
.
_make_buffer
(
max_num_reqs
,
cdiv
(
vocab_size
,
32
),
dtype
=
torch
.
int32
)
def
_make_buffer
(
self
,
*
args
,
dtype
:
torch
.
dtype
)
->
CpuGpuBuffer
:
return
CpuGpuBuffer
(
*
args
,
dtype
=
dtype
,
pin_memory
=
self
.
pin_memory
,
device
=
self
.
device
self
.
query_start_loc
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
@
dataclass
...
...
@@ -56,6 +41,8 @@ class InputBatch:
# batch_idx -> req_state_idx
idx_mapping
:
torch
.
Tensor
idx_mapping_np
:
np
.
ndarray
# Identical to idx_mapping except for spec decoding.
expanded_idx_mapping
:
torch
.
Tensor
# [num_reqs]
# batch_idx -> num_scheduled_tokens
...
...
@@ -83,6 +70,7 @@ class InputBatch:
logits_indices
:
torch
.
Tensor
# [num_reqs + 1]
cu_num_logits
:
torch
.
Tensor
cu_num_logits_np
:
np
.
ndarray
@
classmethod
def
make_dummy
(
...
...
@@ -96,33 +84,41 @@ class InputBatch:
req_ids
=
[
f
"req_
{
i
}
_
{
random_uuid
()
}
"
for
i
in
range
(
num_reqs
)]
idx_mapping_np
=
np
.
arange
(
num_reqs
,
dtype
=
np
.
int32
)
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
expanded_idx_mapping
=
idx_mapping
num_scheduled_tokens
=
np
.
full
(
num_reqs
,
num_tokens
//
num_reqs
,
dtype
=
np
.
int32
)
num_scheduled_tokens
[
-
1
]
+=
num_tokens
%
num_reqs
assert
int
(
num_scheduled_tokens
.
sum
())
==
num_tokens
input_buffers
.
query_start_loc
.
np
[
0
]
=
0
input_buffers
.
query_start_loc
.
np
[
1
:
num_reqs
+
1
]
=
np
.
cumsum
(
num_scheduled_tokens
)
input_buffers
.
query_start_loc
.
np
[
num_reqs
+
1
:]
=
num_tokens
query_start_loc_np
=
input_buffers
.
query_start_loc
.
np
[:
num_reqs
+
1
]
query_start_loc
=
input_buffers
.
query_start_loc
.
copy_to_gpu
()[:
num_reqs
+
1
]
# seq_len equals to query_len
input_buffers
.
seq_lens
[:
num_reqs
]
=
num_tokens
//
num_reqs
input_buffers
.
seq_lens
[
num_reqs
-
1
]
+=
num_tokens
%
num_reqs
# Pad for full CUDA graph mode.
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
seq_lens
=
input_buffers
.
seq_lens
[:
num_reqs
]
query_start_loc_np
=
np
.
empty
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
query_start_loc_np
[
0
]
=
0
np
.
cumsum
(
num_scheduled_tokens
,
out
=
query_start_loc_np
[
1
:])
input_buffers
.
query_start_loc
[
0
]
=
0
torch
.
cumsum
(
seq_lens
,
dim
=
0
,
out
=
input_buffers
.
query_start_loc
[
1
:
num_reqs
+
1
]
)
# Pad for full CUDA graph mode.
input_buffers
.
query_start_loc
[
num_reqs
+
1
:]
=
num_tokens
query_start_loc
=
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
# attn_metadata = defaultdict(lambda: None)
logits_indices
=
query_start_loc
[
1
:]
-
1
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_num_logits_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
return
cls
(
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
idx_mapping
=
idx_mapping
,
idx_mapping_np
=
idx_mapping_np
,
expanded_idx_mapping
=
expanded_idx_mapping
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_tokens
=
num_tokens
,
num_tokens_after_padding
=
num_tokens
,
...
...
@@ -135,6 +131,7 @@ class InputBatch:
attn_metadata
=
None
,
# type: ignore
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
)
...
...
@@ -473,3 +470,38 @@ def post_update(
query_start_loc
,
num_warps
=
1
,
)
@
triton
.
jit
def
_expand_idx_mapping_kernel
(
idx_mapping_ptr
,
expanded_idx_mapping_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
num_tokens
=
end_idx
-
start_idx
block
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
num_tokens
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
tl
.
store
(
expanded_idx_mapping_ptr
+
start_idx
+
block
,
req_state_idx
,
mask
=
mask
)
def
expand_idx_mapping
(
idx_mapping
:
torch
.
Tensor
,
total_num_logits
:
int
,
cu_num_logits
:
torch
.
Tensor
,
max_expand_len
:
int
,
)
->
torch
.
Tensor
:
num_reqs
=
idx_mapping
.
shape
[
0
]
expanded_idx_mapping
=
idx_mapping
.
new_empty
(
total_num_logits
)
_expand_idx_mapping_kernel
[(
num_reqs
,)](
idx_mapping
,
expanded_idx_mapping
,
cu_num_logits
,
BLOCK_SIZE
=
triton
.
next_power_of_2
(
max_expand_len
),
)
return
expanded_idx_mapping
vllm/v1/worker/gpu/model_runner.py
View file @
025a32f9
...
...
@@ -15,7 +15,6 @@ from vllm.forward_context import set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
...
...
@@ -24,7 +23,7 @@ from vllm.v1.outputs import (
LogprobsTensors
,
ModelRunnerOutput
,
)
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
,
async_barrier
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
get_kv_cache_spec
,
...
...
@@ -32,6 +31,7 @@ from vllm.v1.worker.gpu.attn_utils import (
init_kv_cache
,
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBufferPool
from
vllm.v1.worker.gpu.cudagraph_utils
import
CudaGraphManager
from
vllm.v1.worker.gpu.dp_utils
import
(
get_batch_metadata_across_dp
,
...
...
@@ -41,22 +41,20 @@ from vllm.v1.worker.gpu.input_batch import (
InputBatch
,
InputBuffers
,
combine_sampled_and_draft_tokens
,
expand_idx_mapping
,
get_num_sampled_and_rejected
,
post_update
,
prepare_pos_seq_lens
,
prepare_prefill_inputs
,
)
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
(
SamplingMetadata
,
expand_sampling_metadata
,
)
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
from
vllm.v1.worker.gpu.structured_outputs
import
StructuredOutputsWorker
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
...
...
@@ -81,7 +79,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
observability_config
=
vllm_config
.
observability_config
self
.
device
=
device
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
self
.
kv_cache_dtype
=
self
.
dtype
if
self
.
cache_config
.
cache_dtype
!=
"auto"
:
...
...
@@ -123,7 +120,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_speculative_steps
=
self
.
num_speculative_steps
,
vocab_size
=
self
.
vocab_size
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
)
self
.
input_buffers
=
InputBuffers
(
max_num_reqs
=
self
.
max_num_reqs
,
...
...
@@ -132,12 +128,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
vocab_size
=
self
.
vocab_size
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
)
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
vllm_config
,
self
.
device
)
# Structured outputs worker.
self
.
structured_outputs_worker
=
StructuredOutputsWorker
(
max_num_logits
=
self
.
max_num_reqs
*
(
self
.
num_speculative_steps
+
1
),
vocab_size
=
self
.
vocab_size
,
)
# Buffers for CPU-to-GPU copies.
self
.
tmp_idx_mapping
=
UvaBufferPool
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
tmp_cu_num_logits
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
self
.
tmp_query_start_loc
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
...
...
@@ -228,16 +233,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
slot_mappings
=
self
.
block_tables
.
get_dummy_slot_mappings
(
input_batch
.
num_tokens
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
query_start_loc_gpu
=
query_start_loc
.
gpu
[:
input_batch
.
num_reqs
+
1
]
query_start_loc_cpu
=
query_start_loc
.
cpu
[:
input_batch
.
num_reqs
+
1
]
attn_metadata
=
build_attn_metadata
(
attn_metadata_builders
=
self
.
attn_metadata_builders
,
num_reqs
=
input_batch
.
num_reqs
,
num_tokens
=
input_batch
.
num_tokens
,
query_start_loc_gpu
=
query_start_loc
_gpu
,
query_start_loc_cpu
=
query_start_loc_
cpu
,
seq_lens
=
self
.
input_b
uffers
.
seq_lens
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_
np
)
,
seq_lens
=
input_b
atch
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
...
...
@@ -396,8 +398,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
block_tables
.
append_block_ids
(
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
)
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
prefill_len
.
copy_to_gpu
()
# Add new blocks for the existing requests.
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
...
...
@@ -409,6 +409,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_index
,
req_new_block_ids
,
overwrite
=
False
)
self
.
req_states
.
apply_staged_writes
()
self
.
block_tables
.
apply_staged_writes
()
def
prepare_inputs
(
self
,
scheduler_output
:
SchedulerOutput
,
...
...
@@ -431,19 +434,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_list
=
[
self
.
req_states
.
req_id_to_index
[
req_id
]
for
req_id
in
req_ids
]
idx_mapping
=
self
.
input_buffers
.
idx_mapping
idx_mapping
.
np
[:
num_reqs
]
=
idx_mapping_list
idx_mapping_np
=
idx_mapping
.
np
[:
num_reqs
]
idx_mapping
=
idx_mapping
.
copy_to_gpu
(
num_reqs
)
idx_mapping_np
=
np
.
array
(
idx_mapping_list
,
dtype
=
np
.
int32
)
idx_mapping
=
self
.
tmp_idx_mapping
.
copy_to_gpu
(
idx_mapping_np
)
# Get the number of draft tokens for each request.
if
not
scheduler_output
.
scheduled_spec_decode_tokens
:
# No draft token scheduled (common case).
total_num_draft_tokens
=
0
total_num_logits
=
num_reqs
cu_num_logits_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
expanded_idx_mapping
=
idx_mapping
else
:
draft_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
num_draft_tokens
=
np
.
array
(
...
...
@@ -456,44 +459,53 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
total_num_draft_tokens
=
int
(
num_draft_tokens
.
sum
())
total_num_logits
=
num_reqs
+
total_num_draft_tokens
np
.
cumsum
(
num_draft_tokens
+
1
,
out
=
self
.
input_buffers
.
cu_num_logits
.
np
[
1
:
num_reqs
+
1
],
num_logits
=
num_draft_tokens
+
1
cu_num_logits_np
=
np
.
empty
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
cu_num_logits_np
[
0
]
=
0
np
.
cumsum
(
num_logits
,
out
=
cu_num_logits_np
[
1
:])
cu_num_logits
=
self
.
tmp_cu_num_logits
.
copy_to_gpu
(
cu_num_logits_np
)
expanded_idx_mapping
=
expand_idx_mapping
(
idx_mapping
,
total_num_logits
,
cu_num_logits
,
max_expand_len
=
self
.
num_speculative_steps
+
1
,
)
cu_num_logits
=
self
.
input_buffers
.
cu_num_logits
.
copy_to_gpu
(
num_reqs
+
1
)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables
=
self
.
block_tables
.
gather_block_tables
(
idx_mapping
)
# Get query_start_loc.
np
.
cumsum
(
num_scheduled_tokens
,
out
=
self
.
input_buffers
.
query_start_loc
.
np
[
1
:
num_reqs
+
1
],
)
query_start_loc_np
=
np
.
empty
(
self
.
max_num_reqs
+
1
,
dtype
=
np
.
int32
)
query_start_loc_np
[
0
]
=
0
np
.
cumsum
(
num_scheduled_tokens
,
out
=
query_start_loc_np
[
1
:
num_reqs
+
1
])
# Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
self
.
input_buffers
.
query_start_loc
.
np
[
num_reqs
+
1
:]
=
num_tokens
self
.
input_buffers
.
query_start_loc
.
copy_to_gpu
()
query_start_loc_gpu
=
self
.
input_buffers
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
query_start_loc_cpu
=
self
.
input_buffers
.
query_start_loc
.
cpu
[:
num_reqs
+
1
]
query_start_loc_np
=
self
.
input_buffers
.
query_start_loc
.
np
[:
num_reqs
+
1
]
query_start_loc_np
[
num_reqs
+
1
:]
=
num_tokens
self
.
tmp_query_start_loc
.
copy_to_gpu
(
query_start_loc_np
,
out
=
self
.
input_buffers
.
query_start_loc
,
)
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
+
1
]
query_start_loc_cpu
=
torch
.
from_numpy
(
query_start_loc_np
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
# Get prefill tokens.
prepare_prefill_inputs
(
self
.
input_buffers
.
input_ids
,
self
.
req_states
.
next_prefill_tokens
,
idx_mapping
,
query_start_loc
_gpu
,
query_start_loc
,
self
.
req_states
.
prefill_token_ids
.
gpu
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
num_computed_tokens
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
)
# Prepare positions and seq_lens.
prepare_pos_seq_lens
(
idx_mapping
,
query_start_loc
_gpu
,
self
.
req_states
.
num_computed_tokens
,
query_start_loc
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
input_buffers
.
positions
,
self
.
input_buffers
.
seq_lens
,
)
...
...
@@ -505,7 +517,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
input_buffers
.
input_ids
,
idx_mapping
,
self
.
req_states
.
last_sampled_tokens
,
query_start_loc
_gpu
,
query_start_loc
,
seq_lens
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
draft_tokens
,
...
...
@@ -515,7 +527,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc
_gpu
,
self
.
input_buffers
.
positions
[:
num_tokens
]
query_start_loc
,
self
.
input_buffers
.
positions
[:
num_tokens
]
)
# Layer name -> attention metadata.
...
...
@@ -523,7 +535,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builders
=
self
.
attn_metadata_builders
,
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
_gpu
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
self
.
input_buffers
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
...
...
@@ -539,11 +551,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_reqs
=
num_reqs
,
idx_mapping
=
idx_mapping
,
idx_mapping_np
=
idx_mapping_np
,
expanded_idx_mapping
=
expanded_idx_mapping
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_tokens
=
num_tokens
,
num_tokens_after_padding
=
num_tokens_after_padding
,
num_draft_tokens
=
total_num_draft_tokens
,
query_start_loc
=
query_start_loc
_gpu
,
query_start_loc
=
query_start_loc
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
input_ids
=
input_ids
,
...
...
@@ -551,6 +564,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata
=
attn_metadata
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
)
def
sample
(
...
...
@@ -564,15 +578,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
if
grammar_output
is
not
None
:
# Apply grammar bitmask to the logits in-place.
# TODO(woosuk): Make compatible with spec decoding.
assert
input_batch
.
num_draft_tokens
==
0
with
async_barrier
(
self
.
structured_outputs_event
):
apply_grammar_bitmask
(
self
.
structured_outputs_worker
.
apply_grammar_bitmask
(
logits
,
input_batch
.
req_ids
,
input_batch
,
grammar_output
.
structured_output_request_ids
,
grammar_output
.
grammar_bitmask
,
self
.
input_buffers
,
)
# Sample tokens and compute logprobs (if needed).
...
...
@@ -641,8 +651,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Handle chunked prompts.
pos_after_step
=
computed_prefill
+
input_batch
.
num_scheduled_tokens
is_prompt_chunked
=
pos_after_step
<
prompt_lens
prefill_token_ids
=
self
.
req_states
.
prefill_token_ids
.
np
query_start_loc
=
self
.
input_b
uffers
.
query_start_loc
.
np
prefill_token_ids
=
self
.
req_states
.
prefill_token_ids
.
gpu
query_start_loc
_np
=
input_b
atch
.
query_start_loc
_
np
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
if
not
needs_prompt_logprobs
[
i
]:
continue
...
...
@@ -650,10 +660,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
# The prompt is chunked. Get the next prompt token.
req_idx
=
input_batch
.
idx_mapping_np
[
i
]
next_prompt_token
=
int
(
prefill_token_ids
[
req_idx
,
pos_after_step
[
i
]])
idx
=
int
(
query_start_loc
[
i
+
1
]
-
1
)
# Set the next prompt token.
# NOTE(woosuk): This triggers a GPU operation.
idx
=
int
(
query_start_loc_np
[
i
+
1
]
-
1
)
# NOTE(woosuk): This triggers two GPU operations.
next_prompt_token
=
prefill_token_ids
[
req_idx
,
pos_after_step
[
i
]]
token_ids
[
idx
]
=
next_prompt_token
# NOTE(woosuk): We mask out logprobs for negative tokens.
...
...
@@ -669,8 +678,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
not
needs_prompt_logprobs
[
i
]:
continue
start_idx
=
query_start_loc
[
i
]
end_idx
=
query_start_loc
[
i
+
1
]
start_idx
=
query_start_loc
_np
[
i
]
end_idx
=
query_start_loc
_np
[
i
+
1
]
assert
start_idx
<
end_idx
,
(
f
"start_idx (
{
start_idx
}
) >= end_idx (
{
end_idx
}
)"
)
...
...
@@ -714,7 +723,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update the number of computed tokens.
post_update
(
input_batch
.
idx_mapping
,
self
.
req_states
.
num_computed_tokens
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
output_bin_counts
,
sampled_tokens
,
...
...
@@ -825,16 +834,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert
intermediate_tensors
is
None
if
scheduler_output
.
total_num_scheduled_tokens
==
0
and
not
dummy_run
:
# No need to run the model.
with
async_barrier
(
self
.
input_prep_event
):
self
.
update_states
(
scheduler_output
)
return
EMPTY_MODEL_RUNNER_OUTPUT
# NOTE: Call this before the async barrier so CPU all-reduce and
# GPU execution can overlap.
cudagraph_mode
,
num_tokens_after_padding
,
num_tokens_across_dp
=
(
self
.
get_cudagraph_and_dp_padding
(
scheduler_output
)
)
with
async_barrier
(
self
.
input_prep_event
):
self
.
update_states
(
scheduler_output
)
if
num_tokens_after_padding
==
0
:
# All DP ranks have zero tokens to run.
...
...
@@ -848,17 +853,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_after_padding
,
)
# NOTE(woosuk): Sampling metadata should be built under the async
# barrier to avoid race conditions.
pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
sampling_metadata
=
self
.
req_states
.
make_sampling_metadata
(
input_batch
.
idx_mapping
,
input_batch
.
idx_mapping_np
,
pos
)
if
input_batch
.
num_draft_tokens
>
0
:
sampling_metadata
=
expand_sampling_metadata
(
sampling_metadata
,
input_batch
.
cu_num_logits
,
max_expand_len
=
self
.
num_speculative_steps
+
1
,
input_batch
.
expanded_idx_mapping
,
input_batch
.
idx_mapping_np
,
pos
)
if
self
.
lora_config
:
...
...
vllm/v1/worker/gpu/sample/gumbel.py
View file @
025a32f9
...
...
@@ -13,6 +13,7 @@ def _gumbel_sample_kernel(
local_max_stride
,
logits_ptr
,
logits_stride
,
idx_mapping_ptr
,
seeds_ptr
,
pos_ptr
,
temp_ptr
,
...
...
@@ -20,22 +21,24 @@ def _gumbel_sample_kernel(
BLOCK_SIZE
:
tl
.
constexpr
,
APPLY_TEMPERATURE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
batch_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
block_idx
=
tl
.
program_id
(
1
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits_ptr
+
req
_idx
*
logits_stride
+
block
,
logits_ptr
+
batch
_idx
*
logits_stride
+
block
,
mask
=
mask
,
other
=
float
(
"-inf"
),
)
logits
=
logits
.
to
(
tl
.
float32
)
temp
=
tl
.
load
(
temp_ptr
+
req_idx
).
to
(
tl
.
float32
)
temp
=
tl
.
load
(
temp_ptr
+
req_
state_
idx
).
to
(
tl
.
float32
)
if
temp
!=
0.0
:
# Calculate the seed for gumbel noise.
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
pos
=
tl
.
load
(
pos_ptr
+
req
_idx
)
seed
=
tl
.
load
(
seeds_ptr
+
req_
state_
idx
)
pos
=
tl
.
load
(
pos_ptr
+
batch
_idx
)
gumbel_seed
=
tl
.
randint
(
seed
,
pos
)
# Generate gumbel noise.
...
...
@@ -55,12 +58,13 @@ def _gumbel_sample_kernel(
idx
=
tl
.
argmax
(
logits
,
axis
=
0
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
value
=
tl
.
max
(
logits
,
axis
=
0
)
tl
.
store
(
local_argmax_ptr
+
req
_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_max_ptr
+
req
_idx
*
local_max_stride
+
block_idx
,
value
)
tl
.
store
(
local_argmax_ptr
+
batch
_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_max_ptr
+
batch
_idx
*
local_max_stride
+
block_idx
,
value
)
def
gumbel_sample
(
logits
:
torch
.
Tensor
,
# [num_reqs, vocab_size]
idx_mapping
:
torch
.
Tensor
,
# [num_reqs]
temperature
:
torch
.
Tensor
,
# [num_reqs]
seed
:
torch
.
Tensor
,
# [num_reqs]
pos
:
torch
.
Tensor
,
# [num_reqs]
...
...
@@ -88,6 +92,7 @@ def gumbel_sample(
local_max
.
stride
(
0
),
logits
,
logits
.
stride
(
0
),
idx_mapping
,
seed
,
pos
,
temperature
,
...
...
vllm/v1/worker/gpu/sample/metadata.py
View file @
025a32f9
...
...
@@ -4,20 +4,23 @@ from dataclasses import dataclass
import
torch
from
vllm.triton_utils
import
tl
,
triton
@
dataclass
class
SamplingMetadata
:
idx_mapping
:
torch
.
Tensor
temperature
:
torch
.
Tensor
top_p
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
min_p
:
torch
.
Tensor
|
None
# For penalties
repetition_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
presence_penalty
:
torch
.
Tensor
prompt_bin_mask
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
seeds
:
torch
.
Tensor
pos
:
torch
.
Tensor
...
...
@@ -25,11 +28,6 @@ class SamplingMetadata:
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs
:
int
|
None
# For penalties
idx_mapping
:
torch
.
Tensor
prompt_bin_mask
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
@
classmethod
def
make_dummy
(
cls
,
...
...
@@ -37,6 +35,8 @@ class SamplingMetadata:
device
:
torch
.
device
,
)
->
"SamplingMetadata"
:
assert
num_reqs
>
0
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
temperature
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
temperature
[
0
]
=
0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
...
...
@@ -51,18 +51,19 @@ class SamplingMetadata:
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
frequency_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
presence_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
seeds
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
max_num_logprobs
=
20
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_mask
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
output_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
seeds
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
max_num_logprobs
=
20
return
cls
(
idx_mapping
=
idx_mapping
,
temperature
=
temperature
,
top_p
=
top_p
,
top_k
=
top_k
,
...
...
@@ -70,123 +71,9 @@ class SamplingMetadata:
repetition_penalty
=
repetition_penalty
,
frequency_penalty
=
frequency_penalty
,
presence_penalty
=
presence_penalty
,
prompt_bin_mask
=
prompt_bin_mask
,
output_bin_counts
=
output_bin_counts
,
seeds
=
seeds
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_mask
=
prompt_bin_mask
,
output_bin_counts
=
output_bin_counts
,
)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@
triton
.
jit
def
_expand_sampling_metadata_kernel
(
temp_ptr
,
expanded_temp_ptr
,
top_p_ptr
,
expanded_top_p_ptr
,
top_k_ptr
,
expanded_top_k_ptr
,
min_p_ptr
,
expanded_min_p_ptr
,
rep_penalty_ptr
,
expanded_rep_penalty_ptr
,
freq_penalty_ptr
,
expanded_freq_penalty_ptr
,
pres_penalty_ptr
,
expanded_pres_penalty_ptr
,
seeds_ptr
,
expanded_seeds_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
num_tokens
=
end_idx
-
start_idx
block
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
num_tokens
temp
=
tl
.
load
(
temp_ptr
+
req_idx
)
tl
.
store
(
expanded_temp_ptr
+
start_idx
+
block
,
temp
,
mask
=
mask
)
if
top_p_ptr
is
not
None
:
top_p
=
tl
.
load
(
top_p_ptr
+
req_idx
)
tl
.
store
(
expanded_top_p_ptr
+
start_idx
+
block
,
top_p
,
mask
=
mask
)
if
top_k_ptr
is
not
None
:
top_k
=
tl
.
load
(
top_k_ptr
+
req_idx
)
tl
.
store
(
expanded_top_k_ptr
+
start_idx
+
block
,
top_k
,
mask
=
mask
)
if
min_p_ptr
is
not
None
:
min_p
=
tl
.
load
(
min_p_ptr
+
req_idx
)
tl
.
store
(
expanded_min_p_ptr
+
start_idx
+
block
,
min_p
,
mask
=
mask
)
rep_penalty
=
tl
.
load
(
rep_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_rep_penalty_ptr
+
start_idx
+
block
,
rep_penalty
,
mask
=
mask
)
freq_penalty
=
tl
.
load
(
freq_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_freq_penalty_ptr
+
start_idx
+
block
,
freq_penalty
,
mask
=
mask
)
pres_penalty
=
tl
.
load
(
pres_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_pres_penalty_ptr
+
start_idx
+
block
,
pres_penalty
,
mask
=
mask
)
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
tl
.
store
(
expanded_seeds_ptr
+
start_idx
+
block
,
seed
,
mask
=
mask
)
def
expand_sampling_metadata
(
sampling_metadata
:
SamplingMetadata
,
cu_num_logits
:
torch
.
Tensor
,
max_expand_len
:
int
,
)
->
SamplingMetadata
:
total_num_logits
=
sampling_metadata
.
pos
.
shape
[
0
]
create_empty
=
lambda
x
:
x
.
new_empty
(
total_num_logits
)
if
x
is
not
None
else
None
expanded_temp
=
create_empty
(
sampling_metadata
.
temperature
)
expanded_top_p
=
create_empty
(
sampling_metadata
.
top_p
)
expanded_top_k
=
create_empty
(
sampling_metadata
.
top_k
)
expanded_min_p
=
create_empty
(
sampling_metadata
.
min_p
)
expanded_repetition_penalty
=
create_empty
(
sampling_metadata
.
repetition_penalty
)
expanded_frequency_penalty
=
create_empty
(
sampling_metadata
.
frequency_penalty
)
expanded_presence_penalty
=
create_empty
(
sampling_metadata
.
presence_penalty
)
expanded_seeds
=
create_empty
(
sampling_metadata
.
seeds
)
num_reqs
=
cu_num_logits
.
shape
[
0
]
-
1
_expand_sampling_metadata_kernel
[(
num_reqs
,)](
sampling_metadata
.
temperature
,
expanded_temp
,
sampling_metadata
.
top_p
,
expanded_top_p
,
sampling_metadata
.
top_k
,
expanded_top_k
,
sampling_metadata
.
min_p
,
expanded_min_p
,
sampling_metadata
.
repetition_penalty
,
expanded_repetition_penalty
,
sampling_metadata
.
frequency_penalty
,
expanded_frequency_penalty
,
sampling_metadata
.
presence_penalty
,
expanded_presence_penalty
,
sampling_metadata
.
seeds
,
expanded_seeds
,
cu_num_logits
,
BLOCK_SIZE
=
triton
.
next_power_of_2
(
max_expand_len
),
)
return
SamplingMetadata
(
temperature
=
expanded_temp
,
top_p
=
expanded_top_p
,
top_k
=
expanded_top_k
,
min_p
=
expanded_min_p
,
seeds
=
expanded_seeds
,
repetition_penalty
=
expanded_repetition_penalty
,
frequency_penalty
=
expanded_frequency_penalty
,
presence_penalty
=
expanded_presence_penalty
,
pos
=
sampling_metadata
.
pos
,
max_num_logprobs
=
sampling_metadata
.
max_num_logprobs
,
# TODO(woosuk): Support penalties with spec decoding.
idx_mapping
=
sampling_metadata
.
idx_mapping
,
prompt_bin_mask
=
sampling_metadata
.
prompt_bin_mask
,
output_bin_counts
=
sampling_metadata
.
output_bin_counts
,
)
vllm/v1/worker/gpu/sample/min_p.py
View file @
025a32f9
...
...
@@ -9,12 +9,14 @@ from vllm.triton_utils import tl, triton
def
_min_p_kernel
(
logits_ptr
,
logits_stride
,
idx_mapping_ptr
,
min_p_ptr
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
min_p
=
tl
.
load
(
min_p_ptr
+
req_idx
).
to
(
tl
.
float32
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
min_p
=
tl
.
load
(
min_p_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
if
min_p
==
0.0
:
return
...
...
@@ -39,12 +41,17 @@ def _min_p_kernel(
tl
.
store
(
logits_ptr
+
req_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
def
apply_min_p
(
logits
:
torch
.
Tensor
,
min_p
:
torch
.
Tensor
)
->
None
:
def
apply_min_p
(
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
min_p
:
torch
.
Tensor
,
)
->
None
:
num_reqs
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
1024
_min_p_kernel
[(
num_reqs
,)](
logits
,
logits
.
stride
(
0
),
idx_mapping
,
min_p
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
...
...
vllm/v1/worker/gpu/sample/penalties.py
View file @
025a32f9
...
...
@@ -10,11 +10,11 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
def
_penalties_and_temperature_kernel
(
logits_ptr
,
logits_stride
,
idx_mapping_ptr
,
repetition_penalty_ptr
,
frequency_penalty_ptr
,
presence_penalty_ptr
,
temperature_ptr
,
idx_mapping_ptr
,
prompt_bin_mask_ptr
,
prompt_bin_mask_stride
,
output_bin_counts_ptr
,
...
...
@@ -23,10 +23,11 @@ def _penalties_and_temperature_kernel(
BLOCK_SIZE
:
tl
.
constexpr
,
):
batch_idx
=
tl
.
program_id
(
0
)
rep_penalty
=
tl
.
load
(
repetition_penalty_ptr
+
batch_idx
)
freq_penalty
=
tl
.
load
(
frequency_penalty_ptr
+
batch_idx
)
pres_penalty
=
tl
.
load
(
presence_penalty_ptr
+
batch_idx
)
temperature
=
tl
.
load
(
temperature_ptr
+
batch_idx
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
rep_penalty
=
tl
.
load
(
repetition_penalty_ptr
+
req_state_idx
)
freq_penalty
=
tl
.
load
(
frequency_penalty_ptr
+
req_state_idx
)
pres_penalty
=
tl
.
load
(
presence_penalty_ptr
+
req_state_idx
)
temperature
=
tl
.
load
(
temperature_ptr
+
req_state_idx
)
temperature
=
tl
.
where
(
temperature
==
0.0
,
1.0
,
temperature
)
use_rep_penalty
=
rep_penalty
!=
1.0
...
...
@@ -45,7 +46,6 @@ def _penalties_and_temperature_kernel(
logits
=
logits
.
to
(
tl
.
float32
)
if
use_penalty
:
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
output_bin_counts
=
tl
.
load
(
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
block
,
mask
=
mask
,
...
...
@@ -92,11 +92,11 @@ def apply_penalties_and_temperature(
_penalties_and_temperature_kernel
[(
num_reqs
,
num_blocks
)](
logits
,
logits
.
stride
(
0
),
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
repetition_penalty
,
sampling_metadata
.
frequency_penalty
,
sampling_metadata
.
presence_penalty
,
sampling_metadata
.
temperature
,
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
prompt_bin_mask
,
sampling_metadata
.
prompt_bin_mask
.
stride
(
0
),
sampling_metadata
.
output_bin_counts
,
...
...
vllm/v1/worker/gpu/sample/sampler.py
View file @
025a32f9
...
...
@@ -71,7 +71,7 @@ class Sampler:
apply_penalties_and_temperature
(
logits
,
sampling_metadata
)
# Apply min_p in place.
if
sampling_metadata
.
min_p
is
not
None
:
apply_min_p
(
logits
,
sampling_metadata
.
min_p
)
apply_min_p
(
logits
,
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
min_p
)
# Apply top_k and/or top_p. This might return a new tensor.
logits
=
apply_top_k_top_p
(
logits
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
...
...
@@ -79,6 +79,7 @@ class Sampler:
sampled
=
gumbel_sample
(
logits
,
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
temperature
,
sampling_metadata
.
seeds
,
sampling_metadata
.
pos
,
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
025a32f9
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
numpy
as
np
import
torch
import
torch.nn
as
nn
...
...
@@ -12,7 +11,6 @@ from vllm.forward_context import set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
...
...
@@ -46,7 +44,6 @@ class EagleSpeculator:
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
inputs_embeds_size
=
self
.
draft_model_config
.
get_inputs_embeds_size
()
self
.
vocab_size
=
self
.
draft_model_config
.
get_vocab_size
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
input_buffers
=
InputBuffers
(
...
...
@@ -56,7 +53,6 @@ class EagleSpeculator:
vocab_size
=
self
.
vocab_size
,
dtype
=
self
.
dtype
,
device
=
device
,
pin_memory
=
self
.
pin_memory
,
)
self
.
hidden_states
=
torch
.
zeros
(
self
.
max_num_tokens
,
...
...
@@ -64,6 +60,11 @@ class EagleSpeculator:
dtype
=
self
.
dtype
,
device
=
device
,
)
self
.
idx_mapping
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
temperature
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
,
...
...
@@ -140,7 +141,7 @@ class EagleSpeculator:
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
)
->
None
:
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
for
step
in
range
(
1
,
self
.
num_speculative_steps
):
# Run the eagle model.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
...
...
@@ -152,8 +153,9 @@ class EagleSpeculator:
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
logits
,
self
.
temperature
[:
num_reqs
],
self
.
seeds
[:
num_reqs
],
self
.
idx_mapping
[:
num_reqs
],
self
.
temperature
,
self
.
seeds
,
pos
+
1
,
apply_temperature
=
True
,
)
...
...
@@ -237,23 +239,27 @@ class EagleSpeculator:
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
num_reqs
=
input_batch
.
num_reqs
cu_num_logits
=
input_batch
.
cu_num_logits
[:
num_reqs
]
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
temperature
=
self
.
temperature
[:
num_reqs
]
seeds
=
self
.
seeds
[:
num_reqs
]
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
idx_mapping
=
self
.
idx_mapping
[:
num_reqs
]
idx_mapping
.
copy_
(
input_batch
.
idx_mapping
)
self
.
temperature
.
copy_
(
sampling_metadata
.
temperature
)
self
.
seeds
.
copy_
(
sampling_metadata
.
seeds
)
# Gather the values and copy them to the pre-allocated buffers.
torch
.
gather
(
sampling_metadata
.
temperature
,
0
,
cu_num_logits
,
out
=
temperature
)
torch
.
gather
(
sampling_metadata
.
seeds
,
0
,
cu_num_logits
,
out
=
seeds
)
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
torch
.
gather
(
input_batch
.
positions
,
0
,
last_token_indices
,
out
=
pos
)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
logits
,
temperature
,
seeds
,
pos
+
1
,
apply_temperature
=
True
logits
,
idx_mapping
,
self
.
temperature
,
self
.
seeds
,
pos
+
1
,
apply_temperature
=
True
,
)
if
self
.
num_speculative_steps
==
1
:
# Early exit.
...
...
@@ -273,11 +279,8 @@ class EagleSpeculator:
self
.
max_model_len
,
self
.
max_num_reqs
,
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
query_start_loc_gpu
=
query_start_loc
.
gpu
[:
num_reqs
+
1
]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc_gpu
,
pos
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc
,
pos
)
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
num_reqs
)
if
cudagraph_size
is
not
None
:
...
...
@@ -286,8 +289,9 @@ class EagleSpeculator:
return
self
.
draft_tokens
[:
num_reqs
]
# Run eager mode.
query_start_loc
.
np
[:
num_reqs
+
1
]
=
np
.
arange
(
num_reqs
+
1
)
query_start_loc_cpu
=
query_start_loc
.
cpu
[:
num_reqs
+
1
]
query_start_loc_cpu
=
torch
.
arange
(
num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
block_tables
=
[
x
[:
num_reqs
]
for
x
in
self
.
block_tables
.
input_block_tables
]
# FIXME(woosuk): This is UNSAFE!!
...
...
@@ -295,7 +299,7 @@ class EagleSpeculator:
attn_metadata_builders
=
self
.
attn_metadata_builders
,
num_reqs
=
num_reqs
,
num_tokens
=
num_reqs
,
query_start_loc_gpu
=
query_start_loc
_gpu
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
],
max_seq_len
=
self
.
max_model_len
,
...
...
@@ -484,7 +488,7 @@ def prepare_eagle_decode(
input_buffers
.
positions
,
input_hidden_states
,
input_hidden_states
.
stride
(
0
),
input_buffers
.
query_start_loc
.
gpu
,
input_buffers
.
query_start_loc
,
input_buffers
.
seq_lens
,
hidden_size
,
max_model_len
,
...
...
vllm/v1/worker/gpu/states.py
View file @
025a32f9
...
...
@@ -8,10 +8,8 @@ import torch
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.
utils
import
CpuGpuBuffe
r
from
vllm.v1.
worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTenso
r
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.penalties
import
bincount
...
...
@@ -29,7 +27,6 @@ class RequestState:
num_speculative_steps
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
pin_memory
:
bool
,
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
...
...
@@ -37,7 +34,6 @@ class RequestState:
self
.
num_speculative_steps
=
num_speculative_steps
self
.
vocab_size
=
vocab_size
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
req_id_to_index
:
dict
[
str
,
int
]
=
{}
self
.
index_to_req_id
:
dict
[
int
,
str
]
=
{}
...
...
@@ -47,16 +43,18 @@ class RequestState:
self
.
prompt_len
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
self
.
prefill_token_ids
=
UvaBuffer
(
self
.
max_num_reqs
,
self
.
max_model_len
,
dtype
=
torch
.
int32
# To save GPU memory, we use UVA instead of GPU for this tensor.
self
.
prefill_token_ids
=
StagedWriteTensor
(
(
self
.
max_num_reqs
,
self
.
max_model_len
),
dtype
=
torch
.
int32
,
device
=
device
,
uva_instead_of_gpu
=
True
,
)
# NOTE(woosuk): We don't use UVA for prefill_len because its GPU view
# can be used outside of update_states and prepare_inputs.
# Without async barrier, using UVA can cause race conditions.
self
.
prefill_len
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
prefill_len
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
# Number of computed tokens.
self
.
num_computed_prefill_tokens
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens
=
torch
.
zeros
(
self
.
num_computed_tokens
=
StagedWriteTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
...
...
@@ -84,14 +82,16 @@ class RequestState:
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
# Sampling parameters.
self
.
temperature
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
top_p
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
top_k
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
min_p
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
repetition_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
frequency_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
presence_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
seeds
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
int64
)
self
.
temperature
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
top_p
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
top_k
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
min_p
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
repetition_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
frequency_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
presence_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
seeds
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int64
)
self
.
num_logprobs
=
np
.
empty
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
# -1 means no logprobs are requested.
...
...
@@ -111,13 +111,7 @@ class RequestState:
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
def
_make_param
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
)
->
"Param"
:
return
Param
(
size
,
dtype
=
dtype
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
def
_make_buffer
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
)
->
CpuGpuBuffer
:
return
CpuGpuBuffer
(
size
,
dtype
=
dtype
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
self
.
_penalties_reqs
:
list
[
int
]
=
[]
@
property
def
num_reqs
(
self
)
->
int
:
...
...
@@ -144,12 +138,9 @@ class RequestState:
f
"prefill_len
{
prefill_len
}
< prompt_len
{
prompt_len
}
"
)
self
.
prefill_len
.
np
[
req_idx
]
=
prefill_len
self
.
prefill_token_ids
.
np
[
req_idx
,
:
prefill_len
]
=
prefill_token_ids
self
.
prefill_token_ids
.
stage_write
(
req_idx
,
0
,
prefill_token_ids
)
self
.
num_computed_prefill_tokens
[
req_idx
]
=
num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
# Optimize this.
self
.
num_computed_tokens
[
req_idx
]
=
num_computed_tokens
self
.
num_computed_tokens
.
stage_write_elem
(
req_idx
,
num_computed_tokens
)
if
lora_request
is
not
None
:
self
.
lora_ids
[
req_idx
]
=
lora_request
.
lora_int_id
...
...
@@ -169,13 +160,7 @@ class RequestState:
self
.
presence_penalty
.
np
[
req_idx
]
=
sampling_params
.
presence_penalty
if
use_penalty
(
sampling_params
):
bincount
(
self
.
prefill_token_ids
.
gpu
[
req_idx
],
prefill_len
,
prompt_len
,
self
.
prompt_bin_mask
[
req_idx
],
self
.
output_bin_counts
[
req_idx
],
)
self
.
_penalties_reqs
.
append
(
req_idx
)
if
sampling_params
.
seed
is
not
None
:
seed
=
sampling_params
.
seed
...
...
@@ -193,6 +178,22 @@ class RequestState:
needs_prompt_logprobs
=
sampling_params
.
prompt_logprobs
is
not
None
self
.
needs_prompt_logprobs
[
req_idx
]
=
needs_prompt_logprobs
def
apply_staged_writes
(
self
)
->
None
:
self
.
prefill_len
.
copy_to_uva
()
self
.
prefill_token_ids
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
# TODO(woosuk): Optimize this.
for
req_idx
in
self
.
_penalties_reqs
:
bincount
(
self
.
prefill_token_ids
.
gpu
[
req_idx
],
int
(
self
.
prefill_len
.
np
[
req_idx
]),
int
(
self
.
prompt_len
[
req_idx
]),
self
.
prompt_bin_mask
[
req_idx
],
self
.
output_bin_counts
[
req_idx
],
)
self
.
_penalties_reqs
.
clear
()
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
self
.
extra_data
.
pop
(
req_id
,
None
)
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
...
...
@@ -208,30 +209,25 @@ class RequestState:
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
)
->
SamplingMetadata
:
temperature
=
self
.
temperature
.
np
[
idx_mapping_np
]
temperature
=
self
.
temperature
.
copy_np_to_gpu
(
temperature
)
temperature
=
self
.
temperature
.
copy_to_uva
()
top_p
=
self
.
top_p
.
np
[
idx_mapping_np
]
no_top_p
=
np
.
all
(
top_p
==
1.0
)
top_p
=
self
.
top_p
.
copy_
np_to_gpu
(
top_p
)
if
not
no_top_p
else
None
top_p
=
self
.
top_p
.
copy_
to_uva
()[
idx_mapping
]
if
not
no_top_p
else
None
top_k
=
self
.
top_k
.
np
[
idx_mapping_np
]
no_top_k
=
np
.
all
(
top_k
==
self
.
vocab_size
)
top_k
=
self
.
top_k
.
copy_
np_to_gpu
(
top_k
)
if
not
no_top_k
else
None
top_k
=
self
.
top_k
.
copy_
to_uva
()[
idx_mapping
]
if
not
no_top_k
else
None
min_p
=
self
.
min_p
.
np
[
idx_mapping_np
]
no_min_p
=
np
.
all
(
min_p
==
0.0
)
min_p
=
self
.
min_p
.
copy_
np_to_gpu
(
min_p
)
if
not
no_min_p
else
None
min_p
=
self
.
min_p
.
copy_
to_uva
(
)
if
not
no_min_p
else
None
rep_penalty
=
self
.
repetition_penalty
.
np
[
idx_mapping_np
]
rep_penalty
=
self
.
repetition_penalty
.
copy_np_to_gpu
(
rep_penalty
)
freq_penalty
=
self
.
frequency_penalty
.
np
[
idx_mapping_np
]
freq_penalty
=
self
.
frequency_penalty
.
copy_np_to_gpu
(
freq_penalty
)
pres_penalty
=
self
.
presence_penalty
.
np
[
idx_mapping_np
]
pres_penalty
=
self
.
presence_penalty
.
copy_np_to_gpu
(
pres_penalty
)
rep_penalty
=
self
.
repetition_penalty
.
copy_to_uva
()
freq_penalty
=
self
.
frequency_penalty
.
copy_to_uva
()
pres_penalty
=
self
.
presence_penalty
.
copy_to_uva
()
seeds
=
self
.
seeds
.
np
[
idx_mapping_np
]
seeds
=
self
.
seeds
.
copy_np_to_gpu
(
seeds
)
seeds
=
self
.
seeds
.
copy_to_uva
()
num_logprobs
=
self
.
num_logprobs
[
idx_mapping_np
]
max_num_logprobs
:
int
|
None
=
int
(
np
.
max
(
num_logprobs
))
...
...
@@ -239,6 +235,7 @@ class RequestState:
max_num_logprobs
=
None
return
SamplingMetadata
(
idx_mapping
=
idx_mapping
,
temperature
=
temperature
,
top_p
=
top_p
,
top_k
=
top_k
,
...
...
@@ -246,12 +243,11 @@ class RequestState:
repetition_penalty
=
rep_penalty
,
frequency_penalty
=
freq_penalty
,
presence_penalty
=
pres_penalty
,
prompt_bin_mask
=
self
.
prompt_bin_mask
,
output_bin_counts
=
self
.
output_bin_counts
,
seeds
=
seeds
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_mask
=
self
.
prompt_bin_mask
,
output_bin_counts
=
self
.
output_bin_counts
,
)
def
make_lora_inputs
(
...
...
@@ -272,42 +268,12 @@ class RequestState:
return
prompt_lora_mapping
,
token_lora_mapping
,
active_lora_requests
class
Param
:
def
__init__
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
pin_memory
:
bool
,
):
self
.
buffer
=
CpuGpuBuffer
(
size
,
dtype
=
dtype
,
device
=
device
,
pin_memory
=
pin_memory
,
)
self
.
np
=
np
.
zeros_like
(
self
.
buffer
.
np
)
def
copy_np_to_gpu
(
self
,
x
:
np
.
ndarray
)
->
torch
.
Tensor
:
n
=
x
.
shape
[
0
]
self
.
buffer
.
np
[:
n
]
=
x
return
self
.
buffer
.
copy_to_gpu
(
n
)
@
dataclass
class
ExtraData
:
lora_request
:
LoRARequest
|
None
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
class
UvaBuffer
:
def
__init__
(
self
,
*
size
:
int
|
torch
.
SymInt
,
dtype
:
torch
.
dtype
):
assert
is_uva_available
()
self
.
cpu
=
torch
.
zeros
(
*
size
,
dtype
=
dtype
,
device
=
"cpu"
,
pin_memory
=
True
)
self
.
np
=
self
.
cpu
.
numpy
()
self
.
gpu
=
get_cuda_view_from_cpu_tensor
(
self
.
cpu
)
def
use_penalty
(
sampling_params
:
SamplingParams
)
->
bool
:
return
(
sampling_params
.
repetition_penalty
!=
1.0
...
...
vllm/v1/worker/gpu/structured_outputs.py
View file @
025a32f9
...
...
@@ -4,35 +4,62 @@ import numpy as np
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBufferPool
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
def
apply_grammar_bitmask
(
class
StructuredOutputsWorker
:
def
__init__
(
self
,
max_num_logits
:
int
,
vocab_size
:
int
,
):
# NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
# to save a unnecessary CPU-to-CPU copy.
self
.
logits_indices
=
UvaBufferPool
(
max_num_logits
,
torch
.
int32
)
self
.
grammar_bitmask
=
UvaBufferPool
(
(
max_num_logits
,
cdiv
(
vocab_size
,
32
)),
torch
.
int32
)
def
apply_grammar_bitmask
(
self
,
logits
:
torch
.
Tensor
,
req_ids
:
list
[
str
]
,
input_batch
:
InputBatch
,
grammar_req_ids
:
list
[
str
],
grammar_bitmask
:
np
.
ndarray
,
input_buffers
:
InputBuffers
,
)
->
None
:
input_buffers
.
grammar_bitmask
.
np
[:
grammar_bitmask
.
shape
[
0
]]
=
grammar_bitmask
input_buffers
.
grammar_bitmask
.
copy_to_gpu
(
grammar_bitmask
.
shape
[
0
])
)
->
None
:
if
not
grammar_req_ids
:
return
batch_size
=
logits
.
shape
[
0
]
grammar_req_id_to_idx
=
{
req_id
:
i
for
i
,
req_id
in
enumerate
(
grammar_req_ids
)}
# logits -> bitmask mapping
mapping
=
[
grammar_req_id_to_idx
.
get
(
req_id
,
-
1
)
for
req_id
in
req_ids
]
input_buffers
.
bitmask_indices
.
np
[:
batch_size
]
=
mapping
input_buffers
.
bitmask_indices
.
copy_to_gpu
(
batch_size
)
# Construct bitmask -> logits mapping
mapping
:
list
[
int
]
=
[]
req_ids
=
input_batch
.
req_ids
cu_num_logits
=
input_batch
.
cu_num_logits_np
.
tolist
()
req_id_to_idx
=
{
req_id
:
i
for
i
,
req_id
in
enumerate
(
req_ids
)}
for
grammar_req_id
in
grammar_req_ids
:
req_idx
=
req_id_to_idx
[
grammar_req_id
]
logits_start_idx
=
cu_num_logits
[
req_idx
]
logits_end_idx
=
cu_num_logits
[
req_idx
+
1
]
mapping
.
extend
(
range
(
logits_start_idx
,
logits_end_idx
))
# Copy the mapping.
mapping_np
=
np
.
array
(
mapping
,
dtype
=
np
.
int32
)
logits_indices
=
self
.
logits_indices
.
copy_to_uva
(
mapping_np
)
# Copy the bitmask.
bitmask
=
self
.
grammar_bitmask
.
copy_to_uva
(
grammar_bitmask
)
num_masks
=
bitmask
.
shape
[
0
]
assert
num_masks
==
len
(
mapping
)
vocab_size
=
logits
.
shape
[
-
1
]
BLOCK_SIZE
=
8192
grid
=
(
batch_size
,
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
))
grid
=
(
num_masks
,
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
))
_apply_grammar_bitmask_kernel
[
grid
](
logits
,
logits
.
stride
(
0
),
input_buffers
.
grammar_bitmask
.
gpu
,
input_buffers
.
grammar_bitmask
.
gpu
.
stride
(
0
)
,
input_buffers
.
bitmask_indices
.
gpu
,
logits_indices
,
bitmask
,
bitmask
.
stride
(
0
)
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
...
...
@@ -44,17 +71,14 @@ def apply_grammar_bitmask(
def
_apply_grammar_bitmask_kernel
(
logits_ptr
,
logits_stride
,
logits_indices_ptr
,
bitmask_ptr
,
bitmask_stride
,
bitmask_indices_ptr
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
logits_idx
=
tl
.
program_id
(
0
)
bitmask_idx
=
tl
.
load
(
bitmask_indices_ptr
+
logits_idx
)
if
bitmask_idx
==
-
1
:
# No bitmask to apply.
return
bitmask_idx
=
tl
.
program_id
(
0
)
logits_idx
=
tl
.
load
(
logits_indices_ptr
+
bitmask_idx
)
# Load the bitmask.
block_id
=
tl
.
program_id
(
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment