Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
025a32f9
Unverified
Commit
025a32f9
authored
Jan 11, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 11, 2026
Browse files
[Model Runner V2] Remove async barrier (#32083)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
19504ac0
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
590 additions
and
462 deletions
+590
-462
vllm/v1/worker/gpu/block_table.py
vllm/v1/worker/gpu/block_table.py
+18
-34
vllm/v1/worker/gpu/buffer_utils.py
vllm/v1/worker/gpu/buffer_utils.py
+218
-0
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+9
-6
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+57
-25
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+108
-111
vllm/v1/worker/gpu/sample/gumbel.py
vllm/v1/worker/gpu/sample/gumbel.py
+12
-7
vllm/v1/worker/gpu/sample/metadata.py
vllm/v1/worker/gpu/sample/metadata.py
+14
-127
vllm/v1/worker/gpu/sample/min_p.py
vllm/v1/worker/gpu/sample/min_p.py
+9
-2
vllm/v1/worker/gpu/sample/penalties.py
vllm/v1/worker/gpu/sample/penalties.py
+7
-7
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+2
-1
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+27
-23
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+51
-85
vllm/v1/worker/gpu/structured_outputs.py
vllm/v1/worker/gpu/structured_outputs.py
+58
-34
No files found.
vllm/v1/worker/gpu/block_table.py
View file @
025a32f9
...
@@ -6,9 +6,8 @@ import torch
...
@@ -6,9 +6,8 @@ import torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
class
BlockTables
:
class
BlockTables
:
...
@@ -26,19 +25,16 @@ class BlockTables:
...
@@ -26,19 +25,16 @@ class BlockTables:
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
device
=
device
self
.
device
=
device
if
not
is_uva_available
():
raise
RuntimeError
(
"UVA is not available"
)
self
.
num_kv_cache_groups
=
len
(
self
.
block_sizes
)
self
.
num_kv_cache_groups
=
len
(
self
.
block_sizes
)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self
.
block_tables
:
list
[
UvaBuffe
r
]
=
[]
self
.
block_tables
:
list
[
StagedWriteTenso
r
]
=
[]
for
i
in
range
(
self
.
num_kv_cache_groups
):
for
i
in
range
(
self
.
num_kv_cache_groups
):
block_size
=
self
.
block_sizes
[
i
]
block_size
=
self
.
block_sizes
[
i
]
max_num_blocks
=
cdiv
(
self
.
max_model_len
,
block_size
)
max_num_blocks
=
cdiv
(
self
.
max_model_len
,
block_size
)
block_table
=
UvaBuffer
(
block_table
=
StagedWriteTensor
(
self
.
max_num_reqs
,
(
self
.
max_num_reqs
,
max_num_blocks
),
max_num_blocks
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
,
)
)
self
.
block_tables
.
append
(
block_table
)
self
.
block_tables
.
append
(
block_table
)
self
.
block_table_ptrs
=
self
.
_make_ptr_tensor
(
self
.
block_table_ptrs
=
self
.
_make_ptr_tensor
(
...
@@ -53,9 +49,8 @@ class BlockTables:
...
@@ -53,9 +49,8 @@ class BlockTables:
self
.
block_sizes_tensor
=
torch
.
tensor
(
self
.
block_sizes_tensor
=
torch
.
tensor
(
self
.
block_sizes
,
dtype
=
torch
.
int32
,
device
=
self
.
device
self
.
block_sizes
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
self
.
num_blocks
=
UvaBuffer
(
self
.
num_blocks
=
UvaBackedTensor
(
self
.
num_kv_cache_groups
,
(
self
.
num_kv_cache_groups
,
self
.
max_num_reqs
),
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
)
)
...
@@ -75,13 +70,11 @@ class BlockTables:
...
@@ -75,13 +70,11 @@ class BlockTables:
def
_make_ptr_tensor
(
self
,
x
:
Iterable
[
torch
.
Tensor
])
->
torch
.
Tensor
:
def
_make_ptr_tensor
(
self
,
x
:
Iterable
[
torch
.
Tensor
])
->
torch
.
Tensor
:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
ptrs_tensor_cpu
=
torch
.
tensor
(
return
torch
.
tensor
(
[
t
.
data_ptr
()
for
t
in
x
],
[
t
.
data_ptr
()
for
t
in
x
],
dtype
=
torch
.
uint64
,
dtype
=
torch
.
uint64
,
device
=
"cpu"
,
device
=
self
.
device
,
pin_memory
=
True
,
)
)
return
ptrs_tensor_cpu
.
to
(
self
.
device
,
non_blocking
=
True
)
def
append_block_ids
(
def
append_block_ids
(
self
,
self
,
...
@@ -90,19 +83,17 @@ class BlockTables:
...
@@ -90,19 +83,17 @@ class BlockTables:
overwrite
:
bool
,
overwrite
:
bool
,
)
->
None
:
)
->
None
:
for
i
in
range
(
self
.
num_kv_cache_groups
):
for
i
in
range
(
self
.
num_kv_cache_groups
):
start
=
self
.
num_blocks
.
np
[
i
,
req_index
]
if
not
overwrite
else
0
block_ids
=
new_block_ids
[
i
]
block_ids
=
new_block_ids
[
i
]
num_new_blocks
=
len
(
block_ids
)
self
.
block_tables
[
i
].
stage_write
(
req_index
,
start
,
block_ids
)
if
num_new_blocks
==
0
:
self
.
num_blocks
.
np
[
i
,
req_index
]
=
start
+
len
(
block_ids
)
continue
# TODO(woosuk): Too many Numpy invocations. Optimize this.
def
apply_staged_writes
(
self
)
->
None
:
start
=
self
.
num_blocks
.
np
[
i
,
req_index
]
if
not
overwrite
else
0
# TODO(woosuk): This can be inefficient since it launches one kernel per
end
=
start
+
num_new_blocks
# block table. Implement a kernel to handle all block tables at once.
if
num_new_blocks
==
1
:
for
block_table
in
self
.
block_tables
:
self
.
block_tables
[
i
].
np
[
req_index
,
start
]
=
block_ids
[
0
]
block_table
.
apply_write
()
else
:
self
.
num_blocks
.
copy_to_uva
()
self
.
block_tables
[
i
].
np
[
req_index
,
start
:
end
]
=
block_ids
self
.
num_blocks
.
np
[
i
,
req_index
]
=
end
def
gather_block_tables
(
def
gather_block_tables
(
self
,
self
,
...
@@ -229,10 +220,3 @@ def _load_ptr(ptr_to_ptr, elem_dtype):
...
@@ -229,10 +220,3 @@ def _load_ptr(ptr_to_ptr, elem_dtype):
ptr
=
tl
.
load
(
ptr_to_ptr
)
ptr
=
tl
.
load
(
ptr_to_ptr
)
ptr
=
tl
.
cast
(
ptr
,
tl
.
pointer_type
(
elem_dtype
))
ptr
=
tl
.
cast
(
ptr
,
tl
.
pointer_type
(
elem_dtype
))
return
tl
.
multiple_of
(
ptr
,
16
)
return
tl
.
multiple_of
(
ptr
,
16
)
class
UvaBuffer
:
def
__init__
(
self
,
*
size
,
dtype
:
torch
.
dtype
):
self
.
cpu
=
torch
.
zeros
(
*
size
,
dtype
=
dtype
,
device
=
"cpu"
,
pin_memory
=
True
)
self
.
np
=
self
.
cpu
.
numpy
()
self
.
gpu
=
get_cuda_view_from_cpu_tensor
(
self
.
cpu
)
vllm/v1/worker/gpu/buffer_utils.py
0 → 100644
View file @
025a32f9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
import
numpy
as
np
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
next_power_of_2
from
vllm.utils.platform_utils
import
is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
class UvaBuffer:
    """Zero-initialised pinned CPU tensor plus two alternative handles to it.

    Attributes set by ``__init__``:
        cpu: The pinned CPU tensor (``torch.zeros(..., pin_memory=True)``).
        np: The array returned by ``cpu.numpy()``.
        uva: The tensor returned by ``get_cuda_view_from_cpu_tensor(cpu)``
            (presumably a device-addressable view of the same memory --
            confirm against that helper).

    Raises:
        RuntimeError: If ``is_uva_available()`` reports that UVA is missing.
    """

    def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
        uva_supported = is_uva_available()
        if not uva_supported:
            raise RuntimeError("UVA is not available")

        pinned = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
        self.cpu = pinned
        self.np = pinned.numpy()
        self.uva = get_cuda_view_from_cpu_tensor(pinned)
class
UvaBufferPool
:
def
__init__
(
self
,
size
:
int
|
Sequence
[
int
],
dtype
:
torch
.
dtype
,
max_concurrency
:
int
=
2
,
):
self
.
size
=
size
self
.
dtype
=
dtype
self
.
max_concurrency
=
max_concurrency
# UVA buffers for concurrency
self
.
_uva_bufs
=
[
UvaBuffer
(
size
,
dtype
)
for
_
in
range
(
max_concurrency
)]
# Current buffer index
self
.
_curr
=
0
def
copy_to_uva
(
self
,
x
:
torch
.
Tensor
|
np
.
ndarray
|
list
)
->
torch
.
Tensor
:
# Round robin to the next buffer.
self
.
_curr
=
(
self
.
_curr
+
1
)
%
self
.
max_concurrency
buf
=
self
.
_uva_bufs
[
self
.
_curr
]
# CPU-to-CPU copy
dst
=
buf
.
cpu
if
isinstance
(
x
,
torch
.
Tensor
)
else
buf
.
np
n
=
len
(
x
)
dst
[:
n
]
=
x
return
buf
.
uva
[:
n
]
def
copy_to_gpu
(
self
,
x
:
torch
.
Tensor
|
np
.
ndarray
,
out
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
uva
=
self
.
copy_to_uva
(
x
)
if
out
is
None
:
# CPU-to-GPU copy
return
uva
.
clone
()
# CPU-to-GPU copy
return
out
.
copy_
(
uva
,
non_blocking
=
True
)
class
UvaBackedTensor
:
def
__init__
(
self
,
size
:
int
|
Sequence
[
int
],
dtype
:
torch
.
dtype
,
max_concurrency
:
int
=
2
,
):
self
.
dtype
=
dtype
self
.
max_concurrency
=
max_concurrency
# Source of truth
self
.
cpu
=
torch
.
zeros
(
size
,
dtype
=
dtype
,
device
=
"cpu"
,
pin_memory
=
False
)
self
.
np
=
self
.
cpu
.
numpy
()
# Buffers for concurrency
self
.
pool
=
UvaBufferPool
(
size
,
dtype
,
max_concurrency
)
self
.
gpu
=
self
.
pool
.
copy_to_uva
(
self
.
np
)
def
copy_to_uva
(
self
,
n
:
int
|
None
=
None
)
->
torch
.
Tensor
:
# CPU-to-CPU copy
self
.
gpu
=
self
.
pool
.
copy_to_uva
(
self
.
np
[:
n
]
if
n
is
not
None
else
self
.
np
)
return
self
.
gpu
class
StagedWriteTensor
:
def
__init__
(
self
,
size
:
int
|
Sequence
[
int
],
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
max_concurrency
:
int
=
2
,
uva_instead_of_gpu
:
bool
=
False
,
):
if
dtype
not
in
[
torch
.
int32
,
torch
.
int64
]:
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
: should be either int32 or int64"
)
self
.
num_rows
=
size
if
isinstance
(
size
,
int
)
else
size
[
0
]
self
.
dtype
=
dtype
self
.
max_concurrency
=
max_concurrency
if
not
uva_instead_of_gpu
:
# Create a GPU tensor (default)
self
.
gpu
=
torch
.
zeros
(
size
,
dtype
=
dtype
,
device
=
device
)
else
:
# For a large but not-frequently-accessed tensor, we can use UVA instead of
# GPU to save GPU memory
self
.
_uva_buf
=
UvaBuffer
(
size
,
dtype
)
self
.
gpu
=
self
.
_uva_buf
.
uva
self
.
_staged_write_indices
:
list
[
int
]
=
[]
self
.
_staged_write_starts
:
list
[
int
]
=
[]
self
.
_staged_write_contents
:
list
[
int
]
=
[]
self
.
_staged_write_cu_lens
:
list
[
int
]
=
[]
self
.
write_indices
=
UvaBufferPool
(
self
.
num_rows
,
dtype
=
torch
.
int32
,
max_concurrency
=
max_concurrency
)
self
.
write_starts
=
UvaBufferPool
(
self
.
num_rows
,
dtype
=
torch
.
int32
,
max_concurrency
=
max_concurrency
)
init_size
=
next_power_of_2
(
self
.
num_rows
)
self
.
write_contents
=
UvaBufferPool
(
init_size
,
dtype
=
dtype
,
max_concurrency
=
max_concurrency
)
self
.
write_cu_lens
=
UvaBufferPool
(
self
.
num_rows
,
dtype
=
torch
.
int32
,
max_concurrency
=
max_concurrency
)
def
stage_write
(
self
,
index
:
int
,
start
:
int
,
x
:
list
[
int
])
->
None
:
assert
index
>=
0
assert
start
>=
0
if
not
x
:
return
self
.
_staged_write_indices
.
append
(
index
)
self
.
_staged_write_starts
.
append
(
start
)
self
.
_staged_write_contents
.
extend
(
x
)
self
.
_staged_write_cu_lens
.
append
(
len
(
self
.
_staged_write_contents
))
def
stage_write_elem
(
self
,
index
:
int
,
x
:
int
)
->
None
:
assert
index
>=
0
self
.
_staged_write_indices
.
append
(
index
)
self
.
_staged_write_starts
.
append
(
0
)
self
.
_staged_write_contents
.
append
(
x
)
self
.
_staged_write_cu_lens
.
append
(
len
(
self
.
_staged_write_contents
))
def
apply_write
(
self
)
->
None
:
n
=
len
(
self
.
_staged_write_indices
)
if
n
==
0
:
return
indices_uva
=
self
.
write_indices
.
copy_to_uva
(
self
.
_staged_write_indices
)
starts_uva
=
self
.
write_starts
.
copy_to_uva
(
self
.
_staged_write_starts
)
cu_lens_uva
=
self
.
write_cu_lens
.
copy_to_uva
(
self
.
_staged_write_cu_lens
)
# Special handling for write_contents
diff_len
=
len
(
self
.
_staged_write_contents
)
assert
isinstance
(
self
.
write_contents
.
size
,
int
)
if
diff_len
>
self
.
write_contents
.
size
:
# Re-allocate a larger buffer for the write_contents
new_size
=
next_power_of_2
(
diff_len
)
self
.
write_contents
=
UvaBufferPool
(
new_size
,
dtype
=
self
.
dtype
,
max_concurrency
=
self
.
max_concurrency
)
# NOTE(woosuk): Since the previous write_contents buffer is released,
# we perform a synchronization here to ensure that all data transfers
# involving the old buffer have finished before allocating a new one.
# This prevents potential race conditions. The slight overhead is
# negligible because the reallocations are infrequent in practice.
torch
.
cuda
.
synchronize
()
contents_uva
=
self
.
write_contents
.
copy_to_uva
(
self
.
_staged_write_contents
)
# Write diffs to the GPU buffer
_apply_write_kernel
[(
n
,)](
self
.
gpu
,
self
.
gpu
.
stride
(
0
),
indices_uva
,
starts_uva
,
contents_uva
,
cu_lens_uva
,
BLOCK_SIZE
=
1024
,
)
# Clear the staged writes
self
.
clear_staged_writes
()
def
clear_staged_writes
(
self
)
->
None
:
self
.
_staged_write_indices
.
clear
()
self
.
_staged_write_starts
.
clear
()
self
.
_staged_write_contents
.
clear
()
self
.
_staged_write_cu_lens
.
clear
()
# Applies staged row writes to a 2-D output buffer. Program `k` handles the
# k-th staged write: it copies the slice of `write_contents` delimited by the
# cumulative lengths in `write_cu_lens` into row `write_indices[k]` of the
# output, starting at column `write_starts[k]`.
@triton.jit
def _apply_write_kernel(
    output_ptr,
    output_stride,
    write_indices_ptr,
    write_starts_ptr,
    write_contents_ptr,
    write_cu_lens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    write_id = tl.program_id(0)
    dst_row = tl.load(write_indices_ptr + write_id)
    dst_col = tl.load(write_starts_ptr + write_id)

    # cu_lens stores inclusive end offsets, so the first write starts at 0 and
    # every later write starts where the previous one ended.
    src_begin = tl.load(write_cu_lens_ptr + write_id - 1) if write_id > 0 else 0
    src_end = tl.load(write_cu_lens_ptr + write_id)
    num_elems = src_end - src_begin

    # Copy the payload in BLOCK_SIZE-wide chunks; the mask trims the last one.
    for chunk_start in range(0, num_elems, BLOCK_SIZE):
        lane = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_range = lane < num_elems
        values = tl.load(write_contents_ptr + src_begin + lane, mask=in_range)
        tl.store(
            output_ptr + dst_row * output_stride + dst_col + lane,
            values,
            mask=in_range,
        )
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
025a32f9
...
@@ -228,10 +228,13 @@ def prepare_inputs_to_capture(
...
@@ -228,10 +228,13 @@ def prepare_inputs_to_capture(
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
num_tokens_per_req
=
num_tokens
//
num_reqs
num_tokens_per_req
=
num_tokens
//
num_reqs
query_start_loc
=
input_buffers
.
query_start_loc
query_start_loc
.
np
[:
num_reqs
+
1
]
=
np
.
arange
(
num_reqs
+
1
)
*
num_tokens_per_req
query_start_loc_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
*
num_tokens_per_req
query_start_loc
.
np
[
num_reqs
:]
=
num_tokens
query_start_loc_np
[
-
1
]
=
num_tokens
query_start_loc
.
copy_to_gpu
()
query_start_loc_cpu
=
torch
.
from_numpy
(
query_start_loc_np
)
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
=
query_start_loc_cpu
input_buffers
.
query_start_loc
[
num_reqs
+
1
:]
=
num_tokens
query_start_loc
=
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len.
# rather than max_model_len.
...
@@ -245,8 +248,8 @@ def prepare_inputs_to_capture(
...
@@ -245,8 +248,8 @@ def prepare_inputs_to_capture(
attn_metadata_builders
=
attn_metadata_builders
,
attn_metadata_builders
=
attn_metadata_builders
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
.
gpu
[:
num_reqs
+
1
]
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc
.
cpu
[:
num_reqs
+
1
]
,
query_start_loc_cpu
=
query_start_loc
_
cpu
,
seq_lens
=
input_buffers
.
seq_lens
,
seq_lens
=
input_buffers
.
seq_lens
,
max_seq_len
=
max_model_len
,
max_seq_len
=
max_model_len
,
block_tables
=
input_block_tables
,
block_tables
=
input_block_tables
,
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
025a32f9
...
@@ -8,8 +8,6 @@ import torch
...
@@ -8,8 +8,6 @@ import torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.utils
import
CpuGpuBuffer
class
InputBuffers
:
class
InputBuffers
:
...
@@ -21,30 +19,17 @@ class InputBuffers:
...
@@ -21,30 +19,17 @@ class InputBuffers:
vocab_size
:
int
,
vocab_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
device
:
torch
.
device
,
pin_memory
:
bool
,
):
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_tokens
=
max_num_tokens
self
.
max_num_tokens
=
max_num_tokens
self
.
device
=
device
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
idx_mapping
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
input_ids
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
input_ids
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
query_start_loc
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
self
.
query_start_loc
=
torch
.
zeros
(
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
self
.
cu_num_logits
=
self
.
_make_buffer
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
# Structured outputs.
self
.
bitmask_indices
=
self
.
_make_buffer
(
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
grammar_bitmask
=
self
.
_make_buffer
(
max_num_reqs
,
cdiv
(
vocab_size
,
32
),
dtype
=
torch
.
int32
)
def
_make_buffer
(
self
,
*
args
,
dtype
:
torch
.
dtype
)
->
CpuGpuBuffer
:
return
CpuGpuBuffer
(
*
args
,
dtype
=
dtype
,
pin_memory
=
self
.
pin_memory
,
device
=
self
.
device
)
)
self
.
seq_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
@
dataclass
@
dataclass
...
@@ -56,6 +41,8 @@ class InputBatch:
...
@@ -56,6 +41,8 @@ class InputBatch:
# batch_idx -> req_state_idx
# batch_idx -> req_state_idx
idx_mapping
:
torch
.
Tensor
idx_mapping
:
torch
.
Tensor
idx_mapping_np
:
np
.
ndarray
idx_mapping_np
:
np
.
ndarray
# Identical to idx_mapping except for spec decoding.
expanded_idx_mapping
:
torch
.
Tensor
# [num_reqs]
# [num_reqs]
# batch_idx -> num_scheduled_tokens
# batch_idx -> num_scheduled_tokens
...
@@ -83,6 +70,7 @@ class InputBatch:
...
@@ -83,6 +70,7 @@ class InputBatch:
logits_indices
:
torch
.
Tensor
logits_indices
:
torch
.
Tensor
# [num_reqs + 1]
# [num_reqs + 1]
cu_num_logits
:
torch
.
Tensor
cu_num_logits
:
torch
.
Tensor
cu_num_logits_np
:
np
.
ndarray
@
classmethod
@
classmethod
def
make_dummy
(
def
make_dummy
(
...
@@ -96,33 +84,41 @@ class InputBatch:
...
@@ -96,33 +84,41 @@ class InputBatch:
req_ids
=
[
f
"req_
{
i
}
_
{
random_uuid
()
}
"
for
i
in
range
(
num_reqs
)]
req_ids
=
[
f
"req_
{
i
}
_
{
random_uuid
()
}
"
for
i
in
range
(
num_reqs
)]
idx_mapping_np
=
np
.
arange
(
num_reqs
,
dtype
=
np
.
int32
)
idx_mapping_np
=
np
.
arange
(
num_reqs
,
dtype
=
np
.
int32
)
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
expanded_idx_mapping
=
idx_mapping
num_scheduled_tokens
=
np
.
full
(
num_reqs
,
num_tokens
//
num_reqs
,
dtype
=
np
.
int32
)
num_scheduled_tokens
=
np
.
full
(
num_reqs
,
num_tokens
//
num_reqs
,
dtype
=
np
.
int32
)
num_scheduled_tokens
[
-
1
]
+=
num_tokens
%
num_reqs
num_scheduled_tokens
[
-
1
]
+=
num_tokens
%
num_reqs
assert
int
(
num_scheduled_tokens
.
sum
())
==
num_tokens
assert
int
(
num_scheduled_tokens
.
sum
())
==
num_tokens
input_buffers
.
query_start_loc
.
np
[
0
]
=
0
input_buffers
.
query_start_loc
.
np
[
1
:
num_reqs
+
1
]
=
np
.
cumsum
(
num_scheduled_tokens
)
input_buffers
.
query_start_loc
.
np
[
num_reqs
+
1
:]
=
num_tokens
query_start_loc_np
=
input_buffers
.
query_start_loc
.
np
[:
num_reqs
+
1
]
query_start_loc
=
input_buffers
.
query_start_loc
.
copy_to_gpu
()[:
num_reqs
+
1
]
# seq_len equals to query_len
# seq_len equals to query_len
input_buffers
.
seq_lens
[:
num_reqs
]
=
num_tokens
//
num_reqs
input_buffers
.
seq_lens
[:
num_reqs
]
=
num_tokens
//
num_reqs
input_buffers
.
seq_lens
[
num_reqs
-
1
]
+=
num_tokens
%
num_reqs
input_buffers
.
seq_lens
[
num_reqs
-
1
]
+=
num_tokens
%
num_reqs
# Pad for full CUDA graph mode.
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
input_buffers
.
seq_lens
[
num_reqs
:]
=
0
seq_lens
=
input_buffers
.
seq_lens
[:
num_reqs
]
seq_lens
=
input_buffers
.
seq_lens
[:
num_reqs
]
query_start_loc_np
=
np
.
empty
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
query_start_loc_np
[
0
]
=
0
np
.
cumsum
(
num_scheduled_tokens
,
out
=
query_start_loc_np
[
1
:])
input_buffers
.
query_start_loc
[
0
]
=
0
torch
.
cumsum
(
seq_lens
,
dim
=
0
,
out
=
input_buffers
.
query_start_loc
[
1
:
num_reqs
+
1
]
)
# Pad for full CUDA graph mode.
input_buffers
.
query_start_loc
[
num_reqs
+
1
:]
=
num_tokens
query_start_loc
=
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
input_ids
=
input_buffers
.
input_ids
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
positions
=
input_buffers
.
positions
[:
num_tokens
]
# attn_metadata = defaultdict(lambda: None)
# attn_metadata = defaultdict(lambda: None)
logits_indices
=
query_start_loc
[
1
:]
-
1
logits_indices
=
query_start_loc
[
1
:]
-
1
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_num_logits_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
return
cls
(
return
cls
(
req_ids
=
req_ids
,
req_ids
=
req_ids
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
idx_mapping
=
idx_mapping
,
idx_mapping
=
idx_mapping
,
idx_mapping_np
=
idx_mapping_np
,
idx_mapping_np
=
idx_mapping_np
,
expanded_idx_mapping
=
expanded_idx_mapping
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
num_tokens_after_padding
=
num_tokens
,
num_tokens_after_padding
=
num_tokens
,
...
@@ -135,6 +131,7 @@ class InputBatch:
...
@@ -135,6 +131,7 @@ class InputBatch:
attn_metadata
=
None
,
# type: ignore
attn_metadata
=
None
,
# type: ignore
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
)
)
...
@@ -473,3 +470,38 @@ def post_update(
...
@@ -473,3 +470,38 @@ def post_update(
query_start_loc
,
query_start_loc
,
num_warps
=
1
,
num_warps
=
1
,
)
)
# Program `r` reads idx_mapping[r] and writes it into the output positions
# [cu_num_logits[r], cu_num_logits[r + 1]). There is no loop, so at most
# BLOCK_SIZE positions are written per program.
@triton.jit
def _expand_idx_mapping_kernel(
    idx_mapping_ptr,
    expanded_idx_mapping_ptr,
    cu_num_logits_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    req = tl.program_id(0)
    # Value to replicate for this request.
    state_idx = tl.load(idx_mapping_ptr + req)

    # Output range owned by this request.
    out_begin = tl.load(cu_num_logits_ptr + req)
    out_end = tl.load(cu_num_logits_ptr + req + 1)
    span = out_end - out_begin

    lane = tl.arange(0, BLOCK_SIZE)
    in_span = lane < span
    tl.store(expanded_idx_mapping_ptr + out_begin + lane, state_idx, mask=in_span)
def expand_idx_mapping(
    idx_mapping: torch.Tensor,
    total_num_logits: int,
    cu_num_logits: torch.Tensor,
    max_expand_len: int,
) -> torch.Tensor:
    """Repeat each entry of ``idx_mapping`` over its range in ``cu_num_logits``.

    Args:
        idx_mapping: One entry per request; its first dimension sets the
            number of kernel programs.
        total_num_logits: Length of the returned tensor.
        cu_num_logits: Range boundaries; request ``r`` owns output positions
            ``[cu_num_logits[r], cu_num_logits[r + 1])``.
        max_expand_len: Rounded up to a power of two and used as the kernel
            block size. The kernel writes at most that many positions per
            request, so this presumably has to cover the longest range --
            confirm against callers.

    Returns:
        A tensor created with ``idx_mapping.new_empty(total_num_logits)`` and
        filled by the kernel; positions the kernel does not write stay
        uninitialised.
    """
    launch_grid = (idx_mapping.shape[0],)
    block_size = triton.next_power_of_2(max_expand_len)

    expanded = idx_mapping.new_empty(total_num_logits)
    _expand_idx_mapping_kernel[launch_grid](
        idx_mapping,
        expanded,
        cu_num_logits,
        BLOCK_SIZE=block_size,
    )
    return expanded
vllm/v1/worker/gpu/model_runner.py
View file @
025a32f9
...
@@ -15,7 +15,6 @@ from vllm.forward_context import set_forward_context
...
@@ -15,7 +15,6 @@ from vllm.forward_context import set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
...
@@ -24,7 +23,7 @@ from vllm.v1.outputs import (
...
@@ -24,7 +23,7 @@ from vllm.v1.outputs import (
LogprobsTensors
,
LogprobsTensors
,
ModelRunnerOutput
,
ModelRunnerOutput
,
)
)
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
,
async_barrier
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
from
vllm.v1.worker.gpu.attn_utils
import
(
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
build_attn_metadata
,
get_kv_cache_spec
,
get_kv_cache_spec
,
...
@@ -32,6 +31,7 @@ from vllm.v1.worker.gpu.attn_utils import (
...
@@ -32,6 +31,7 @@ from vllm.v1.worker.gpu.attn_utils import (
init_kv_cache
,
init_kv_cache
,
)
)
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBufferPool
from
vllm.v1.worker.gpu.cudagraph_utils
import
CudaGraphManager
from
vllm.v1.worker.gpu.cudagraph_utils
import
CudaGraphManager
from
vllm.v1.worker.gpu.dp_utils
import
(
from
vllm.v1.worker.gpu.dp_utils
import
(
get_batch_metadata_across_dp
,
get_batch_metadata_across_dp
,
...
@@ -41,22 +41,20 @@ from vllm.v1.worker.gpu.input_batch import (
...
@@ -41,22 +41,20 @@ from vllm.v1.worker.gpu.input_batch import (
InputBatch
,
InputBatch
,
InputBuffers
,
InputBuffers
,
combine_sampled_and_draft_tokens
,
combine_sampled_and_draft_tokens
,
expand_idx_mapping
,
get_num_sampled_and_rejected
,
get_num_sampled_and_rejected
,
post_update
,
post_update
,
prepare_pos_seq_lens
,
prepare_pos_seq_lens
,
prepare_prefill_inputs
,
prepare_prefill_inputs
,
)
)
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.logprob
import
compute_prompt_logprobs
from
vllm.v1.worker.gpu.sample.metadata
import
(
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
SamplingMetadata
,
expand_sampling_metadata
,
)
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode
import
init_speculator
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.gpu.structured_outputs
import
apply_grammar_bitmask
from
vllm.v1.worker.gpu.structured_outputs
import
StructuredOutputsWorker
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
...
@@ -81,7 +79,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -81,7 +79,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
observability_config
=
vllm_config
.
observability_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
device
=
device
self
.
device
=
device
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
self
.
dtype
=
self
.
model_config
.
dtype
self
.
kv_cache_dtype
=
self
.
dtype
self
.
kv_cache_dtype
=
self
.
dtype
if
self
.
cache_config
.
cache_dtype
!=
"auto"
:
if
self
.
cache_config
.
cache_dtype
!=
"auto"
:
...
@@ -123,7 +120,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -123,7 +120,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_speculative_steps
=
self
.
num_speculative_steps
,
num_speculative_steps
=
self
.
num_speculative_steps
,
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
)
)
self
.
input_buffers
=
InputBuffers
(
self
.
input_buffers
=
InputBuffers
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
...
@@ -132,12 +128,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -132,12 +128,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
)
)
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
# CUDA graphs.
# CUDA graphs.
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
vllm_config
,
self
.
device
)
self
.
cudagraph_manager
=
CudaGraphManager
(
self
.
vllm_config
,
self
.
device
)
# Structured outputs worker.
self
.
structured_outputs_worker
=
StructuredOutputsWorker
(
max_num_logits
=
self
.
max_num_reqs
*
(
self
.
num_speculative_steps
+
1
),
vocab_size
=
self
.
vocab_size
,
)
# Buffers for CPU-to-GPU copies.
self
.
tmp_idx_mapping
=
UvaBufferPool
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
tmp_cu_num_logits
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
self
.
tmp_query_start_loc
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
...
@@ -228,16 +233,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -228,16 +233,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
slot_mappings
=
self
.
block_tables
.
get_dummy_slot_mappings
(
slot_mappings
=
self
.
block_tables
.
get_dummy_slot_mappings
(
input_batch
.
num_tokens
input_batch
.
num_tokens
)
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
query_start_loc_gpu
=
query_start_loc
.
gpu
[:
input_batch
.
num_reqs
+
1
]
query_start_loc_cpu
=
query_start_loc
.
cpu
[:
input_batch
.
num_reqs
+
1
]
attn_metadata
=
build_attn_metadata
(
attn_metadata
=
build_attn_metadata
(
attn_metadata_builders
=
self
.
attn_metadata_builders
,
attn_metadata_builders
=
self
.
attn_metadata_builders
,
num_reqs
=
input_batch
.
num_reqs
,
num_reqs
=
input_batch
.
num_reqs
,
num_tokens
=
input_batch
.
num_tokens
,
num_tokens
=
input_batch
.
num_tokens
,
query_start_loc_gpu
=
query_start_loc
_gpu
,
query_start_loc_gpu
=
input_batch
.
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_
cpu
,
query_start_loc_cpu
=
torch
.
from_numpy
(
input_batch
.
query_start_loc_
np
)
,
seq_lens
=
self
.
input_b
uffers
.
seq_lens
,
seq_lens
=
input_b
atch
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
max_seq_len
=
self
.
max_model_len
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
slot_mappings
=
slot_mappings
,
slot_mappings
=
slot_mappings
,
...
@@ -396,8 +398,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -396,8 +398,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
block_tables
.
append_block_ids
(
self
.
block_tables
.
append_block_ids
(
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
)
)
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
prefill_len
.
copy_to_gpu
()
# Add new blocks for the existing requests.
# Add new blocks for the existing requests.
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
...
@@ -409,6 +409,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -409,6 +409,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_index
,
req_new_block_ids
,
overwrite
=
False
req_index
,
req_new_block_ids
,
overwrite
=
False
)
)
self
.
req_states
.
apply_staged_writes
()
self
.
block_tables
.
apply_staged_writes
()
def
prepare_inputs
(
def
prepare_inputs
(
self
,
self
,
scheduler_output
:
SchedulerOutput
,
scheduler_output
:
SchedulerOutput
,
...
@@ -431,19 +434,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -431,19 +434,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_list
=
[
idx_mapping_list
=
[
self
.
req_states
.
req_id_to_index
[
req_id
]
for
req_id
in
req_ids
self
.
req_states
.
req_id_to_index
[
req_id
]
for
req_id
in
req_ids
]
]
idx_mapping
=
self
.
input_buffers
.
idx_mapping
idx_mapping_np
=
np
.
array
(
idx_mapping_list
,
dtype
=
np
.
int32
)
idx_mapping
.
np
[:
num_reqs
]
=
idx_mapping_list
idx_mapping
=
self
.
tmp_idx_mapping
.
copy_to_gpu
(
idx_mapping_np
)
idx_mapping_np
=
idx_mapping
.
np
[:
num_reqs
]
idx_mapping
=
idx_mapping
.
copy_to_gpu
(
num_reqs
)
# Get the number of draft tokens for each request.
# Get the number of draft tokens for each request.
if
not
scheduler_output
.
scheduled_spec_decode_tokens
:
if
not
scheduler_output
.
scheduled_spec_decode_tokens
:
# No draft token scheduled (common case).
# No draft token scheduled (common case).
total_num_draft_tokens
=
0
total_num_draft_tokens
=
0
total_num_logits
=
num_reqs
total_num_logits
=
num_reqs
cu_num_logits_np
=
np
.
arange
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
cu_num_logits
=
torch
.
arange
(
cu_num_logits
=
torch
.
arange
(
num_reqs
+
1
,
device
=
self
.
device
,
dtype
=
torch
.
int32
num_reqs
+
1
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
)
expanded_idx_mapping
=
idx_mapping
else
:
else
:
draft_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
draft_tokens
=
scheduler_output
.
scheduled_spec_decode_tokens
num_draft_tokens
=
np
.
array
(
num_draft_tokens
=
np
.
array
(
...
@@ -456,44 +459,53 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -456,44 +459,53 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
total_num_draft_tokens
=
int
(
num_draft_tokens
.
sum
())
total_num_draft_tokens
=
int
(
num_draft_tokens
.
sum
())
total_num_logits
=
num_reqs
+
total_num_draft_tokens
total_num_logits
=
num_reqs
+
total_num_draft_tokens
np
.
cumsum
(
num_logits
=
num_draft_tokens
+
1
num_draft_tokens
+
1
,
cu_num_logits_np
=
np
.
empty
(
num_reqs
+
1
,
dtype
=
np
.
int32
)
out
=
self
.
input_buffers
.
cu_num_logits
.
np
[
1
:
num_reqs
+
1
],
cu_num_logits_np
[
0
]
=
0
np
.
cumsum
(
num_logits
,
out
=
cu_num_logits_np
[
1
:])
cu_num_logits
=
self
.
tmp_cu_num_logits
.
copy_to_gpu
(
cu_num_logits_np
)
expanded_idx_mapping
=
expand_idx_mapping
(
idx_mapping
,
total_num_logits
,
cu_num_logits
,
max_expand_len
=
self
.
num_speculative_steps
+
1
,
)
)
cu_num_logits
=
self
.
input_buffers
.
cu_num_logits
.
copy_to_gpu
(
num_reqs
+
1
)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables
=
self
.
block_tables
.
gather_block_tables
(
idx_mapping
)
block_tables
=
self
.
block_tables
.
gather_block_tables
(
idx_mapping
)
# Get query_start_loc.
# Get query_start_loc.
np
.
cumsum
(
query_start_loc_np
=
np
.
empty
(
self
.
max_num_reqs
+
1
,
dtype
=
np
.
int32
)
num_scheduled_tokens
,
query_start_loc_np
[
0
]
=
0
out
=
self
.
input_buffers
.
query_start_loc
.
np
[
1
:
num_reqs
+
1
],
np
.
cumsum
(
num_scheduled_tokens
,
out
=
query_start_loc_np
[
1
:
num_reqs
+
1
])
)
# Pad for full CUDA graph mode.
# Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
self
.
input_buffers
.
query_start_loc
.
np
[
num_reqs
+
1
:]
=
num_tokens
query_start_loc_np
[
num_reqs
+
1
:]
=
num_tokens
self
.
input_buffers
.
query_start_loc
.
copy_to_gpu
()
self
.
tmp_query_start_loc
.
copy_to_gpu
(
query_start_loc_gpu
=
self
.
input_buffers
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
query_start_loc_np
,
query_start_loc_cpu
=
self
.
input_buffers
.
query_start_loc
.
cpu
[:
num_reqs
+
1
]
out
=
self
.
input_buffers
.
query_start_loc
,
query_start_loc_np
=
self
.
input_buffers
.
query_start_loc
.
np
[:
num_reqs
+
1
]
)
query_start_loc_np
=
query_start_loc_np
[:
num_reqs
+
1
]
query_start_loc_cpu
=
torch
.
from_numpy
(
query_start_loc_np
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
# Get prefill tokens.
# Get prefill tokens.
prepare_prefill_inputs
(
prepare_prefill_inputs
(
self
.
input_buffers
.
input_ids
,
self
.
input_buffers
.
input_ids
,
self
.
req_states
.
next_prefill_tokens
,
self
.
req_states
.
next_prefill_tokens
,
idx_mapping
,
idx_mapping
,
query_start_loc
_gpu
,
query_start_loc
,
self
.
req_states
.
prefill_token_ids
.
gpu
,
self
.
req_states
.
prefill_token_ids
.
gpu
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
num_computed_tokens
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
)
)
# Prepare positions and seq_lens.
# Prepare positions and seq_lens.
prepare_pos_seq_lens
(
prepare_pos_seq_lens
(
idx_mapping
,
idx_mapping
,
query_start_loc
_gpu
,
query_start_loc
,
self
.
req_states
.
num_computed_tokens
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
input_buffers
.
positions
,
self
.
input_buffers
.
positions
,
self
.
input_buffers
.
seq_lens
,
self
.
input_buffers
.
seq_lens
,
)
)
...
@@ -505,7 +517,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -505,7 +517,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
input_buffers
.
input_ids
,
self
.
input_buffers
.
input_ids
,
idx_mapping
,
idx_mapping
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
last_sampled_tokens
,
query_start_loc
_gpu
,
query_start_loc
,
seq_lens
,
seq_lens
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
prefill_len
.
gpu
,
self
.
req_states
.
draft_tokens
,
self
.
req_states
.
draft_tokens
,
...
@@ -515,7 +527,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -515,7 +527,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc
_gpu
,
self
.
input_buffers
.
positions
[:
num_tokens
]
query_start_loc
,
self
.
input_buffers
.
positions
[:
num_tokens
]
)
)
# Layer name -> attention metadata.
# Layer name -> attention metadata.
...
@@ -523,7 +535,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -523,7 +535,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builders
=
self
.
attn_metadata_builders
,
attn_metadata_builders
=
self
.
attn_metadata_builders
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
query_start_loc_gpu
=
query_start_loc
_gpu
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
self
.
input_buffers
.
seq_lens
,
seq_lens
=
self
.
input_buffers
.
seq_lens
,
max_seq_len
=
self
.
max_model_len
,
max_seq_len
=
self
.
max_model_len
,
...
@@ -539,11 +551,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -539,11 +551,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
idx_mapping
=
idx_mapping
,
idx_mapping
=
idx_mapping
,
idx_mapping_np
=
idx_mapping_np
,
idx_mapping_np
=
idx_mapping_np
,
expanded_idx_mapping
=
expanded_idx_mapping
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
num_tokens
=
num_tokens
,
num_tokens
=
num_tokens
,
num_tokens_after_padding
=
num_tokens_after_padding
,
num_tokens_after_padding
=
num_tokens_after_padding
,
num_draft_tokens
=
total_num_draft_tokens
,
num_draft_tokens
=
total_num_draft_tokens
,
query_start_loc
=
query_start_loc
_gpu
,
query_start_loc
=
query_start_loc
,
query_start_loc_np
=
query_start_loc_np
,
query_start_loc_np
=
query_start_loc_np
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
input_ids
=
input_ids
,
input_ids
=
input_ids
,
...
@@ -551,6 +564,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -551,6 +564,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
logits_indices
=
logits_indices
,
logits_indices
=
logits_indices
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits
=
cu_num_logits
,
cu_num_logits_np
=
cu_num_logits_np
,
)
)
def
sample
(
def
sample
(
...
@@ -564,16 +578,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -564,16 +578,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
if
grammar_output
is
not
None
:
if
grammar_output
is
not
None
:
# Apply grammar bitmask to the logits in-place.
# Apply grammar bitmask to the logits in-place.
# TODO(woosuk): Make compatible with spec decoding.
self
.
structured_outputs_worker
.
apply_grammar_bitmask
(
assert
input_batch
.
num_draft_tokens
==
0
logits
,
with
async_barrier
(
self
.
structured_outputs_event
):
input_batch
,
apply_grammar_bitmask
(
grammar_output
.
structured_output_request_ids
,
logits
,
grammar_output
.
grammar_bitmask
,
input_batch
.
req_ids
,
)
grammar_output
.
structured_output_request_ids
,
grammar_output
.
grammar_bitmask
,
self
.
input_buffers
,
)
# Sample tokens and compute logprobs (if needed).
# Sample tokens and compute logprobs (if needed).
sampler_output
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampler_output
=
self
.
sampler
(
logits
,
sampling_metadata
)
...
@@ -641,8 +651,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -641,8 +651,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Handle chunked prompts.
# Handle chunked prompts.
pos_after_step
=
computed_prefill
+
input_batch
.
num_scheduled_tokens
pos_after_step
=
computed_prefill
+
input_batch
.
num_scheduled_tokens
is_prompt_chunked
=
pos_after_step
<
prompt_lens
is_prompt_chunked
=
pos_after_step
<
prompt_lens
prefill_token_ids
=
self
.
req_states
.
prefill_token_ids
.
np
prefill_token_ids
=
self
.
req_states
.
prefill_token_ids
.
gpu
query_start_loc
=
self
.
input_b
uffers
.
query_start_loc
.
np
query_start_loc
_np
=
input_b
atch
.
query_start_loc
_
np
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
if
not
needs_prompt_logprobs
[
i
]:
if
not
needs_prompt_logprobs
[
i
]:
continue
continue
...
@@ -650,10 +660,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -650,10 +660,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
continue
# The prompt is chunked. Get the next prompt token.
# The prompt is chunked. Get the next prompt token.
req_idx
=
input_batch
.
idx_mapping_np
[
i
]
req_idx
=
input_batch
.
idx_mapping_np
[
i
]
next_prompt_token
=
int
(
prefill_token_ids
[
req_idx
,
pos_after_step
[
i
]])
idx
=
int
(
query_start_loc_np
[
i
+
1
]
-
1
)
idx
=
int
(
query_start_loc
[
i
+
1
]
-
1
)
# NOTE(woosuk): This triggers two GPU operations.
# Set the next prompt token.
next_prompt_token
=
prefill_token_ids
[
req_idx
,
pos_after_step
[
i
]]
# NOTE(woosuk): This triggers a GPU operation.
token_ids
[
idx
]
=
next_prompt_token
token_ids
[
idx
]
=
next_prompt_token
# NOTE(woosuk): We mask out logprobs for negative tokens.
# NOTE(woosuk): We mask out logprobs for negative tokens.
...
@@ -669,8 +678,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -669,8 +678,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
not
needs_prompt_logprobs
[
i
]:
if
not
needs_prompt_logprobs
[
i
]:
continue
continue
start_idx
=
query_start_loc
[
i
]
start_idx
=
query_start_loc
_np
[
i
]
end_idx
=
query_start_loc
[
i
+
1
]
end_idx
=
query_start_loc
_np
[
i
+
1
]
assert
start_idx
<
end_idx
,
(
assert
start_idx
<
end_idx
,
(
f
"start_idx (
{
start_idx
}
) >= end_idx (
{
end_idx
}
)"
f
"start_idx (
{
start_idx
}
) >= end_idx (
{
end_idx
}
)"
)
)
...
@@ -714,7 +723,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -714,7 +723,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update the number of computed tokens.
# Update the number of computed tokens.
post_update
(
post_update
(
input_batch
.
idx_mapping
,
input_batch
.
idx_mapping
,
self
.
req_states
.
num_computed_tokens
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
last_sampled_tokens
,
self
.
req_states
.
output_bin_counts
,
self
.
req_states
.
output_bin_counts
,
sampled_tokens
,
sampled_tokens
,
...
@@ -825,61 +834,49 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -825,61 +834,49 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert
intermediate_tensors
is
None
assert
intermediate_tensors
is
None
if
scheduler_output
.
total_num_scheduled_tokens
==
0
and
not
dummy_run
:
if
scheduler_output
.
total_num_scheduled_tokens
==
0
and
not
dummy_run
:
# No need to run the model.
# No need to run the model.
with
async_barrier
(
self
.
input_prep_event
):
self
.
update_states
(
scheduler_output
)
self
.
update_states
(
scheduler_output
)
return
EMPTY_MODEL_RUNNER_OUTPUT
return
EMPTY_MODEL_RUNNER_OUTPUT
# NOTE: Call this before the async barrier so CPU all-reduce and
# GPU execution can overlap.
cudagraph_mode
,
num_tokens_after_padding
,
num_tokens_across_dp
=
(
cudagraph_mode
,
num_tokens_after_padding
,
num_tokens_across_dp
=
(
self
.
get_cudagraph_and_dp_padding
(
scheduler_output
)
self
.
get_cudagraph_and_dp_padding
(
scheduler_output
)
)
)
with
async_barrier
(
self
.
input_prep_event
):
self
.
update_states
(
scheduler_output
)
self
.
update_states
(
scheduler_output
)
if
num_tokens_after_padding
==
0
:
if
num_tokens_after_padding
==
0
:
# All DP ranks have zero tokens to run.
# All DP ranks have zero tokens to run.
return
EMPTY_MODEL_RUNNER_OUTPUT
return
EMPTY_MODEL_RUNNER_OUTPUT
if
not
dummy_run
:
if
not
dummy_run
:
# Common case.
# Common case.
# Prepare all the inputs and copy to the input buffers.
# Prepare all the inputs and copy to the input buffers.
input_batch
=
self
.
prepare_inputs
(
input_batch
=
self
.
prepare_inputs
(
scheduler_output
,
scheduler_output
,
num_tokens_after_padding
,
num_tokens_after_padding
,
)
)
# NOTE(woosuk): Sampling metadata should be built under the async
pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
# barrier to avoid race conditions.
sampling_metadata
=
self
.
req_states
.
make_sampling_metadata
(
pos
=
input_batch
.
positions
[
input_batch
.
logits_indices
]
input_batch
.
expanded_idx_mapping
,
input_batch
.
idx_mapping_np
,
pos
sampling_metadata
=
self
.
req_states
.
make_sampling_metadata
(
)
input_batch
.
idx_mapping
,
input_batch
.
idx_mapping_np
,
pos
)
if
self
.
lora_config
:
if
input_batch
.
num_draft_tokens
>
0
:
# Activate LoRA adapters.
sampling_metadata
=
expand_sampling_metadata
(
lora_inputs
=
self
.
req_states
.
make_lora_inputs
(
sampling_metadata
,
input_batch
.
req_ids
,
input_batch
.
cu_num_logits
,
input_batch
.
idx_mapping_np
,
max_expand_len
=
self
.
num_speculative_steps
+
1
,
input_batch
.
num_scheduled_tokens
,
)
if
self
.
lora_config
:
# Activate LoRA adapters.
lora_inputs
=
self
.
req_states
.
make_lora_inputs
(
input_batch
.
req_ids
,
input_batch
.
idx_mapping_np
,
input_batch
.
num_scheduled_tokens
,
)
self
.
_set_active_loras
(
*
lora_inputs
)
else
:
# No actual tokens to run. A dummy run for DP.
num_reqs
=
min
(
num_tokens_after_padding
,
self
.
max_num_reqs
)
input_batch
=
InputBatch
.
make_dummy
(
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens_after_padding
,
input_buffers
=
self
.
input_buffers
,
device
=
self
.
device
,
)
)
self
.
prepare_dummy_attn_metadata
(
input_batch
)
self
.
_set_active_loras
(
*
lora_inputs
)
sampling_metadata
=
None
else
:
# No actual tokens to run. A dummy run for DP.
num_reqs
=
min
(
num_tokens_after_padding
,
self
.
max_num_reqs
)
input_batch
=
InputBatch
.
make_dummy
(
num_reqs
=
num_reqs
,
num_tokens
=
num_tokens_after_padding
,
input_buffers
=
self
.
input_buffers
,
device
=
self
.
device
,
)
self
.
prepare_dummy_attn_metadata
(
input_batch
)
sampling_metadata
=
None
# Run model.
# Run model.
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
if
cudagraph_mode
==
CUDAGraphMode
.
FULL
:
...
...
vllm/v1/worker/gpu/sample/gumbel.py
View file @
025a32f9
...
@@ -13,6 +13,7 @@ def _gumbel_sample_kernel(
...
@@ -13,6 +13,7 @@ def _gumbel_sample_kernel(
local_max_stride
,
local_max_stride
,
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
idx_mapping_ptr
,
seeds_ptr
,
seeds_ptr
,
pos_ptr
,
pos_ptr
,
temp_ptr
,
temp_ptr
,
...
@@ -20,22 +21,24 @@ def _gumbel_sample_kernel(
...
@@ -20,22 +21,24 @@ def _gumbel_sample_kernel(
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
APPLY_TEMPERATURE
:
tl
.
constexpr
,
APPLY_TEMPERATURE
:
tl
.
constexpr
,
):
):
req_idx
=
tl
.
program_id
(
0
)
batch_idx
=
tl
.
program_id
(
0
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
block_idx
=
tl
.
program_id
(
1
)
block_idx
=
tl
.
program_id
(
1
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
block
=
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
vocab_size
mask
=
block
<
vocab_size
logits
=
tl
.
load
(
logits
=
tl
.
load
(
logits_ptr
+
req
_idx
*
logits_stride
+
block
,
logits_ptr
+
batch
_idx
*
logits_stride
+
block
,
mask
=
mask
,
mask
=
mask
,
other
=
float
(
"-inf"
),
other
=
float
(
"-inf"
),
)
)
logits
=
logits
.
to
(
tl
.
float32
)
logits
=
logits
.
to
(
tl
.
float32
)
temp
=
tl
.
load
(
temp_ptr
+
req_idx
).
to
(
tl
.
float32
)
temp
=
tl
.
load
(
temp_ptr
+
req_
state_
idx
).
to
(
tl
.
float32
)
if
temp
!=
0.0
:
if
temp
!=
0.0
:
# Calculate the seed for gumbel noise.
# Calculate the seed for gumbel noise.
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
seed
=
tl
.
load
(
seeds_ptr
+
req_
state_
idx
)
pos
=
tl
.
load
(
pos_ptr
+
req
_idx
)
pos
=
tl
.
load
(
pos_ptr
+
batch
_idx
)
gumbel_seed
=
tl
.
randint
(
seed
,
pos
)
gumbel_seed
=
tl
.
randint
(
seed
,
pos
)
# Generate gumbel noise.
# Generate gumbel noise.
...
@@ -55,12 +58,13 @@ def _gumbel_sample_kernel(
...
@@ -55,12 +58,13 @@ def _gumbel_sample_kernel(
idx
=
tl
.
argmax
(
logits
,
axis
=
0
)
idx
=
tl
.
argmax
(
logits
,
axis
=
0
)
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
token_id
=
block_idx
*
BLOCK_SIZE
+
idx
value
=
tl
.
max
(
logits
,
axis
=
0
)
value
=
tl
.
max
(
logits
,
axis
=
0
)
tl
.
store
(
local_argmax_ptr
+
req
_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_argmax_ptr
+
batch
_idx
*
local_argmax_stride
+
block_idx
,
token_id
)
tl
.
store
(
local_max_ptr
+
req
_idx
*
local_max_stride
+
block_idx
,
value
)
tl
.
store
(
local_max_ptr
+
batch
_idx
*
local_max_stride
+
block_idx
,
value
)
def
gumbel_sample
(
def
gumbel_sample
(
logits
:
torch
.
Tensor
,
# [num_reqs, vocab_size]
logits
:
torch
.
Tensor
,
# [num_reqs, vocab_size]
idx_mapping
:
torch
.
Tensor
,
# [num_reqs]
temperature
:
torch
.
Tensor
,
# [num_reqs]
temperature
:
torch
.
Tensor
,
# [num_reqs]
seed
:
torch
.
Tensor
,
# [num_reqs]
seed
:
torch
.
Tensor
,
# [num_reqs]
pos
:
torch
.
Tensor
,
# [num_reqs]
pos
:
torch
.
Tensor
,
# [num_reqs]
...
@@ -88,6 +92,7 @@ def gumbel_sample(
...
@@ -88,6 +92,7 @@ def gumbel_sample(
local_max
.
stride
(
0
),
local_max
.
stride
(
0
),
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
idx_mapping
,
seed
,
seed
,
pos
,
pos
,
temperature
,
temperature
,
...
...
vllm/v1/worker/gpu/sample/metadata.py
View file @
025a32f9
...
@@ -4,20 +4,23 @@ from dataclasses import dataclass
...
@@ -4,20 +4,23 @@ from dataclasses import dataclass
import
torch
import
torch
from
vllm.triton_utils
import
tl
,
triton
@
dataclass
@
dataclass
class
SamplingMetadata
:
class
SamplingMetadata
:
idx_mapping
:
torch
.
Tensor
temperature
:
torch
.
Tensor
temperature
:
torch
.
Tensor
top_p
:
torch
.
Tensor
|
None
top_p
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
top_k
:
torch
.
Tensor
|
None
min_p
:
torch
.
Tensor
|
None
min_p
:
torch
.
Tensor
|
None
# For penalties
repetition_penalty
:
torch
.
Tensor
repetition_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
frequency_penalty
:
torch
.
Tensor
presence_penalty
:
torch
.
Tensor
presence_penalty
:
torch
.
Tensor
prompt_bin_mask
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
seeds
:
torch
.
Tensor
seeds
:
torch
.
Tensor
pos
:
torch
.
Tensor
pos
:
torch
.
Tensor
...
@@ -25,11 +28,6 @@ class SamplingMetadata:
...
@@ -25,11 +28,6 @@ class SamplingMetadata:
# None means no logprobs, 0 means sampled token logprobs only
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs
:
int
|
None
max_num_logprobs
:
int
|
None
# For penalties
idx_mapping
:
torch
.
Tensor
prompt_bin_mask
:
torch
.
Tensor
output_bin_counts
:
torch
.
Tensor
@
classmethod
@
classmethod
def
make_dummy
(
def
make_dummy
(
cls
,
cls
,
...
@@ -37,6 +35,8 @@ class SamplingMetadata:
...
@@ -37,6 +35,8 @@ class SamplingMetadata:
device
:
torch
.
device
,
device
:
torch
.
device
,
)
->
"SamplingMetadata"
:
)
->
"SamplingMetadata"
:
assert
num_reqs
>
0
assert
num_reqs
>
0
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
temperature
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
temperature
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
temperature
[
0
]
=
0.5
temperature
[
0
]
=
0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# TODO(woosuk): Use top-p and top-k for dummy sampler.
...
@@ -51,18 +51,19 @@ class SamplingMetadata:
...
@@ -51,18 +51,19 @@ class SamplingMetadata:
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
repetition_penalty
=
torch
.
ones
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
frequency_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
frequency_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
presence_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
presence_penalty
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
float32
,
device
=
device
)
seeds
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
max_num_logprobs
=
20
idx_mapping
=
torch
.
arange
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
# specialization and re-compilation at runtime.
prompt_bin_mask
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
prompt_bin_mask
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
output_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
output_bin_counts
=
torch
.
zeros
(
num_reqs
,
2
,
dtype
=
torch
.
int32
,
device
=
device
)
seeds
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
pos
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int64
,
device
=
device
)
max_num_logprobs
=
20
return
cls
(
return
cls
(
idx_mapping
=
idx_mapping
,
temperature
=
temperature
,
temperature
=
temperature
,
top_p
=
top_p
,
top_p
=
top_p
,
top_k
=
top_k
,
top_k
=
top_k
,
...
@@ -70,123 +71,9 @@ class SamplingMetadata:
...
@@ -70,123 +71,9 @@ class SamplingMetadata:
repetition_penalty
=
repetition_penalty
,
repetition_penalty
=
repetition_penalty
,
frequency_penalty
=
frequency_penalty
,
frequency_penalty
=
frequency_penalty
,
presence_penalty
=
presence_penalty
,
presence_penalty
=
presence_penalty
,
prompt_bin_mask
=
prompt_bin_mask
,
output_bin_counts
=
output_bin_counts
,
seeds
=
seeds
,
seeds
=
seeds
,
pos
=
pos
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_mask
=
prompt_bin_mask
,
output_bin_counts
=
output_bin_counts
,
)
)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@
triton
.
jit
def
_expand_sampling_metadata_kernel
(
temp_ptr
,
expanded_temp_ptr
,
top_p_ptr
,
expanded_top_p_ptr
,
top_k_ptr
,
expanded_top_k_ptr
,
min_p_ptr
,
expanded_min_p_ptr
,
rep_penalty_ptr
,
expanded_rep_penalty_ptr
,
freq_penalty_ptr
,
expanded_freq_penalty_ptr
,
pres_penalty_ptr
,
expanded_pres_penalty_ptr
,
seeds_ptr
,
expanded_seeds_ptr
,
cu_num_logits_ptr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
)
end_idx
=
tl
.
load
(
cu_num_logits_ptr
+
req_idx
+
1
)
num_tokens
=
end_idx
-
start_idx
block
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
block
<
num_tokens
temp
=
tl
.
load
(
temp_ptr
+
req_idx
)
tl
.
store
(
expanded_temp_ptr
+
start_idx
+
block
,
temp
,
mask
=
mask
)
if
top_p_ptr
is
not
None
:
top_p
=
tl
.
load
(
top_p_ptr
+
req_idx
)
tl
.
store
(
expanded_top_p_ptr
+
start_idx
+
block
,
top_p
,
mask
=
mask
)
if
top_k_ptr
is
not
None
:
top_k
=
tl
.
load
(
top_k_ptr
+
req_idx
)
tl
.
store
(
expanded_top_k_ptr
+
start_idx
+
block
,
top_k
,
mask
=
mask
)
if
min_p_ptr
is
not
None
:
min_p
=
tl
.
load
(
min_p_ptr
+
req_idx
)
tl
.
store
(
expanded_min_p_ptr
+
start_idx
+
block
,
min_p
,
mask
=
mask
)
rep_penalty
=
tl
.
load
(
rep_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_rep_penalty_ptr
+
start_idx
+
block
,
rep_penalty
,
mask
=
mask
)
freq_penalty
=
tl
.
load
(
freq_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_freq_penalty_ptr
+
start_idx
+
block
,
freq_penalty
,
mask
=
mask
)
pres_penalty
=
tl
.
load
(
pres_penalty_ptr
+
req_idx
)
tl
.
store
(
expanded_pres_penalty_ptr
+
start_idx
+
block
,
pres_penalty
,
mask
=
mask
)
seed
=
tl
.
load
(
seeds_ptr
+
req_idx
)
tl
.
store
(
expanded_seeds_ptr
+
start_idx
+
block
,
seed
,
mask
=
mask
)
def
expand_sampling_metadata
(
sampling_metadata
:
SamplingMetadata
,
cu_num_logits
:
torch
.
Tensor
,
max_expand_len
:
int
,
)
->
SamplingMetadata
:
total_num_logits
=
sampling_metadata
.
pos
.
shape
[
0
]
create_empty
=
lambda
x
:
x
.
new_empty
(
total_num_logits
)
if
x
is
not
None
else
None
expanded_temp
=
create_empty
(
sampling_metadata
.
temperature
)
expanded_top_p
=
create_empty
(
sampling_metadata
.
top_p
)
expanded_top_k
=
create_empty
(
sampling_metadata
.
top_k
)
expanded_min_p
=
create_empty
(
sampling_metadata
.
min_p
)
expanded_repetition_penalty
=
create_empty
(
sampling_metadata
.
repetition_penalty
)
expanded_frequency_penalty
=
create_empty
(
sampling_metadata
.
frequency_penalty
)
expanded_presence_penalty
=
create_empty
(
sampling_metadata
.
presence_penalty
)
expanded_seeds
=
create_empty
(
sampling_metadata
.
seeds
)
num_reqs
=
cu_num_logits
.
shape
[
0
]
-
1
_expand_sampling_metadata_kernel
[(
num_reqs
,)](
sampling_metadata
.
temperature
,
expanded_temp
,
sampling_metadata
.
top_p
,
expanded_top_p
,
sampling_metadata
.
top_k
,
expanded_top_k
,
sampling_metadata
.
min_p
,
expanded_min_p
,
sampling_metadata
.
repetition_penalty
,
expanded_repetition_penalty
,
sampling_metadata
.
frequency_penalty
,
expanded_frequency_penalty
,
sampling_metadata
.
presence_penalty
,
expanded_presence_penalty
,
sampling_metadata
.
seeds
,
expanded_seeds
,
cu_num_logits
,
BLOCK_SIZE
=
triton
.
next_power_of_2
(
max_expand_len
),
)
return
SamplingMetadata
(
temperature
=
expanded_temp
,
top_p
=
expanded_top_p
,
top_k
=
expanded_top_k
,
min_p
=
expanded_min_p
,
seeds
=
expanded_seeds
,
repetition_penalty
=
expanded_repetition_penalty
,
frequency_penalty
=
expanded_frequency_penalty
,
presence_penalty
=
expanded_presence_penalty
,
pos
=
sampling_metadata
.
pos
,
max_num_logprobs
=
sampling_metadata
.
max_num_logprobs
,
# TODO(woosuk): Support penalties with spec decoding.
idx_mapping
=
sampling_metadata
.
idx_mapping
,
prompt_bin_mask
=
sampling_metadata
.
prompt_bin_mask
,
output_bin_counts
=
sampling_metadata
.
output_bin_counts
,
)
vllm/v1/worker/gpu/sample/min_p.py
View file @
025a32f9
...
@@ -9,12 +9,14 @@ from vllm.triton_utils import tl, triton
...
@@ -9,12 +9,14 @@ from vllm.triton_utils import tl, triton
def
_min_p_kernel
(
def
_min_p_kernel
(
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
idx_mapping_ptr
,
min_p_ptr
,
min_p_ptr
,
vocab_size
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
req_idx
=
tl
.
program_id
(
0
)
req_idx
=
tl
.
program_id
(
0
)
min_p
=
tl
.
load
(
min_p_ptr
+
req_idx
).
to
(
tl
.
float32
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
req_idx
)
min_p
=
tl
.
load
(
min_p_ptr
+
req_state_idx
).
to
(
tl
.
float32
)
if
min_p
==
0.0
:
if
min_p
==
0.0
:
return
return
...
@@ -39,12 +41,17 @@ def _min_p_kernel(
...
@@ -39,12 +41,17 @@ def _min_p_kernel(
tl
.
store
(
logits_ptr
+
req_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
tl
.
store
(
logits_ptr
+
req_idx
*
logits_stride
+
block
,
logits
,
mask
=
mask
)
def
apply_min_p
(
logits
:
torch
.
Tensor
,
min_p
:
torch
.
Tensor
)
->
None
:
def
apply_min_p
(
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
min_p
:
torch
.
Tensor
,
)
->
None
:
num_reqs
,
vocab_size
=
logits
.
shape
num_reqs
,
vocab_size
=
logits
.
shape
BLOCK_SIZE
=
1024
BLOCK_SIZE
=
1024
_min_p_kernel
[(
num_reqs
,)](
_min_p_kernel
[(
num_reqs
,)](
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
idx_mapping
,
min_p
,
min_p
,
vocab_size
,
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
BLOCK_SIZE
=
BLOCK_SIZE
,
...
...
vllm/v1/worker/gpu/sample/penalties.py
View file @
025a32f9
...
@@ -10,11 +10,11 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
...
@@ -10,11 +10,11 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
def
_penalties_and_temperature_kernel
(
def
_penalties_and_temperature_kernel
(
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
idx_mapping_ptr
,
repetition_penalty_ptr
,
repetition_penalty_ptr
,
frequency_penalty_ptr
,
frequency_penalty_ptr
,
presence_penalty_ptr
,
presence_penalty_ptr
,
temperature_ptr
,
temperature_ptr
,
idx_mapping_ptr
,
prompt_bin_mask_ptr
,
prompt_bin_mask_ptr
,
prompt_bin_mask_stride
,
prompt_bin_mask_stride
,
output_bin_counts_ptr
,
output_bin_counts_ptr
,
...
@@ -23,10 +23,11 @@ def _penalties_and_temperature_kernel(
...
@@ -23,10 +23,11 @@ def _penalties_and_temperature_kernel(
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
batch_idx
=
tl
.
program_id
(
0
)
batch_idx
=
tl
.
program_id
(
0
)
rep_penalty
=
tl
.
load
(
repetition_penalty_ptr
+
batch_idx
)
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
freq_penalty
=
tl
.
load
(
frequency_penalty_ptr
+
batch_idx
)
rep_penalty
=
tl
.
load
(
repetition_penalty_ptr
+
req_state_idx
)
pres_penalty
=
tl
.
load
(
presence_penalty_ptr
+
batch_idx
)
freq_penalty
=
tl
.
load
(
frequency_penalty_ptr
+
req_state_idx
)
temperature
=
tl
.
load
(
temperature_ptr
+
batch_idx
)
pres_penalty
=
tl
.
load
(
presence_penalty_ptr
+
req_state_idx
)
temperature
=
tl
.
load
(
temperature_ptr
+
req_state_idx
)
temperature
=
tl
.
where
(
temperature
==
0.0
,
1.0
,
temperature
)
temperature
=
tl
.
where
(
temperature
==
0.0
,
1.0
,
temperature
)
use_rep_penalty
=
rep_penalty
!=
1.0
use_rep_penalty
=
rep_penalty
!=
1.0
...
@@ -45,7 +46,6 @@ def _penalties_and_temperature_kernel(
...
@@ -45,7 +46,6 @@ def _penalties_and_temperature_kernel(
logits
=
logits
.
to
(
tl
.
float32
)
logits
=
logits
.
to
(
tl
.
float32
)
if
use_penalty
:
if
use_penalty
:
req_state_idx
=
tl
.
load
(
idx_mapping_ptr
+
batch_idx
)
output_bin_counts
=
tl
.
load
(
output_bin_counts
=
tl
.
load
(
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
block
,
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
block
,
mask
=
mask
,
mask
=
mask
,
...
@@ -92,11 +92,11 @@ def apply_penalties_and_temperature(
...
@@ -92,11 +92,11 @@ def apply_penalties_and_temperature(
_penalties_and_temperature_kernel
[(
num_reqs
,
num_blocks
)](
_penalties_and_temperature_kernel
[(
num_reqs
,
num_blocks
)](
logits
,
logits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
repetition_penalty
,
sampling_metadata
.
repetition_penalty
,
sampling_metadata
.
frequency_penalty
,
sampling_metadata
.
frequency_penalty
,
sampling_metadata
.
presence_penalty
,
sampling_metadata
.
presence_penalty
,
sampling_metadata
.
temperature
,
sampling_metadata
.
temperature
,
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
prompt_bin_mask
,
sampling_metadata
.
prompt_bin_mask
,
sampling_metadata
.
prompt_bin_mask
.
stride
(
0
),
sampling_metadata
.
prompt_bin_mask
.
stride
(
0
),
sampling_metadata
.
output_bin_counts
,
sampling_metadata
.
output_bin_counts
,
...
...
vllm/v1/worker/gpu/sample/sampler.py
View file @
025a32f9
...
@@ -71,7 +71,7 @@ class Sampler:
...
@@ -71,7 +71,7 @@ class Sampler:
apply_penalties_and_temperature
(
logits
,
sampling_metadata
)
apply_penalties_and_temperature
(
logits
,
sampling_metadata
)
# Apply min_p in place.
# Apply min_p in place.
if
sampling_metadata
.
min_p
is
not
None
:
if
sampling_metadata
.
min_p
is
not
None
:
apply_min_p
(
logits
,
sampling_metadata
.
min_p
)
apply_min_p
(
logits
,
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
min_p
)
# Apply top_k and/or top_p. This might return a new tensor.
# Apply top_k and/or top_p. This might return a new tensor.
logits
=
apply_top_k_top_p
(
logits
=
apply_top_k_top_p
(
logits
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
logits
,
sampling_metadata
.
top_k
,
sampling_metadata
.
top_p
...
@@ -79,6 +79,7 @@ class Sampler:
...
@@ -79,6 +79,7 @@ class Sampler:
sampled
=
gumbel_sample
(
sampled
=
gumbel_sample
(
logits
,
logits
,
sampling_metadata
.
idx_mapping
,
sampling_metadata
.
temperature
,
sampling_metadata
.
temperature
,
sampling_metadata
.
seeds
,
sampling_metadata
.
seeds
,
sampling_metadata
.
pos
,
sampling_metadata
.
pos
,
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
025a32f9
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
from
typing
import
Any
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -12,7 +11,6 @@ from vllm.forward_context import set_forward_context
...
@@ -12,7 +11,6 @@ from vllm.forward_context import set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
...
@@ -46,7 +44,6 @@ class EagleSpeculator:
...
@@ -46,7 +44,6 @@ class EagleSpeculator:
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
hidden_size
=
self
.
draft_model_config
.
get_hidden_size
()
self
.
inputs_embeds_size
=
self
.
draft_model_config
.
get_inputs_embeds_size
()
self
.
inputs_embeds_size
=
self
.
draft_model_config
.
get_inputs_embeds_size
()
self
.
vocab_size
=
self
.
draft_model_config
.
get_vocab_size
()
self
.
vocab_size
=
self
.
draft_model_config
.
get_vocab_size
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
input_buffers
=
InputBuffers
(
self
.
input_buffers
=
InputBuffers
(
...
@@ -56,7 +53,6 @@ class EagleSpeculator:
...
@@ -56,7 +53,6 @@ class EagleSpeculator:
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
device
,
device
=
device
,
pin_memory
=
self
.
pin_memory
,
)
)
self
.
hidden_states
=
torch
.
zeros
(
self
.
hidden_states
=
torch
.
zeros
(
self
.
max_num_tokens
,
self
.
max_num_tokens
,
...
@@ -64,6 +60,11 @@ class EagleSpeculator:
...
@@ -64,6 +60,11 @@ class EagleSpeculator:
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
device
,
device
=
device
,
)
)
self
.
idx_mapping
=
torch
.
zeros
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
temperature
=
torch
.
zeros
(
self
.
temperature
=
torch
.
zeros
(
self
.
max_num_reqs
,
self
.
max_num_reqs
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
...
@@ -140,7 +141,7 @@ class EagleSpeculator:
...
@@ -140,7 +141,7 @@ class EagleSpeculator:
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
num_tokens_across_dp
:
torch
.
Tensor
|
None
,
)
->
None
:
)
->
None
:
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
for
step
in
range
(
1
,
self
.
num_speculative_steps
):
for
step
in
range
(
1
,
self
.
num_speculative_steps
):
# Run the eagle model.
# Run the eagle model.
last_hidden_states
,
hidden_states
=
self
.
run_model
(
last_hidden_states
,
hidden_states
=
self
.
run_model
(
...
@@ -152,8 +153,9 @@ class EagleSpeculator:
...
@@ -152,8 +153,9 @@ class EagleSpeculator:
# used for draft and target sampling.
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
draft_tokens
=
gumbel_sample
(
logits
,
logits
,
self
.
temperature
[:
num_reqs
],
self
.
idx_mapping
[:
num_reqs
],
self
.
seeds
[:
num_reqs
],
self
.
temperature
,
self
.
seeds
,
pos
+
1
,
pos
+
1
,
apply_temperature
=
True
,
apply_temperature
=
True
,
)
)
...
@@ -237,23 +239,27 @@ class EagleSpeculator:
...
@@ -237,23 +239,27 @@ class EagleSpeculator:
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
num_reqs
=
input_batch
.
num_reqs
num_reqs
=
input_batch
.
num_reqs
cu_num_logits
=
input_batch
.
cu_num_logits
[:
num_reqs
]
# NOTE(woosuk): For draft sampling, we only consider the temperature
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
# affect the output distribution after rejection sampling.
temperature
=
self
.
temperature
[:
num_reqs
]
idx_mapping
=
self
.
idx_mapping
[:
num_reqs
]
seeds
=
self
.
seeds
[:
num_reqs
]
idx_mapping
.
copy_
(
input_batch
.
idx_mapping
)
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
self
.
temperature
.
copy_
(
sampling_metadata
.
temperature
)
self
.
seeds
.
copy_
(
sampling_metadata
.
seeds
)
# Gather the values and copy them to the pre-allocated buffers.
# Gather the values and copy them to the pre-allocated buffers.
torch
.
gather
(
sampling_metadata
.
temperature
,
0
,
cu_num_logits
,
out
=
temperature
)
pos
=
self
.
input_buffers
.
positions
[:
num_reqs
]
torch
.
gather
(
sampling_metadata
.
seeds
,
0
,
cu_num_logits
,
out
=
seeds
)
torch
.
gather
(
input_batch
.
positions
,
0
,
last_token_indices
,
out
=
pos
)
torch
.
gather
(
input_batch
.
positions
,
0
,
last_token_indices
,
out
=
pos
)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
# used for draft and target sampling.
draft_tokens
=
gumbel_sample
(
draft_tokens
=
gumbel_sample
(
logits
,
temperature
,
seeds
,
pos
+
1
,
apply_temperature
=
True
logits
,
idx_mapping
,
self
.
temperature
,
self
.
seeds
,
pos
+
1
,
apply_temperature
=
True
,
)
)
if
self
.
num_speculative_steps
==
1
:
if
self
.
num_speculative_steps
==
1
:
# Early exit.
# Early exit.
...
@@ -273,11 +279,8 @@ class EagleSpeculator:
...
@@ -273,11 +279,8 @@ class EagleSpeculator:
self
.
max_model_len
,
self
.
max_model_len
,
self
.
max_num_reqs
,
self
.
max_num_reqs
,
)
)
query_start_loc
=
self
.
input_buffers
.
query_start_loc
query_start_loc
=
self
.
input_buffers
.
query_start_loc
[:
num_reqs
+
1
]
query_start_loc_gpu
=
query_start_loc
.
gpu
[:
num_reqs
+
1
]
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc
,
pos
)
slot_mappings
=
self
.
block_tables
.
compute_slot_mappings
(
query_start_loc_gpu
,
pos
)
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
num_reqs
)
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
num_reqs
)
if
cudagraph_size
is
not
None
:
if
cudagraph_size
is
not
None
:
...
@@ -286,8 +289,9 @@ class EagleSpeculator:
...
@@ -286,8 +289,9 @@ class EagleSpeculator:
return
self
.
draft_tokens
[:
num_reqs
]
return
self
.
draft_tokens
[:
num_reqs
]
# Run eager mode.
# Run eager mode.
query_start_loc
.
np
[:
num_reqs
+
1
]
=
np
.
arange
(
num_reqs
+
1
)
query_start_loc_cpu
=
torch
.
arange
(
query_start_loc_cpu
=
query_start_loc
.
cpu
[:
num_reqs
+
1
]
num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
block_tables
=
[
x
[:
num_reqs
]
for
x
in
self
.
block_tables
.
input_block_tables
]
block_tables
=
[
x
[:
num_reqs
]
for
x
in
self
.
block_tables
.
input_block_tables
]
# FIXME(woosuk): This is UNSAFE!!
# FIXME(woosuk): This is UNSAFE!!
...
@@ -295,7 +299,7 @@ class EagleSpeculator:
...
@@ -295,7 +299,7 @@ class EagleSpeculator:
attn_metadata_builders
=
self
.
attn_metadata_builders
,
attn_metadata_builders
=
self
.
attn_metadata_builders
,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_tokens
=
num_reqs
,
num_tokens
=
num_reqs
,
query_start_loc_gpu
=
query_start_loc
_gpu
,
query_start_loc_gpu
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
],
seq_lens
=
self
.
input_buffers
.
seq_lens
[:
num_reqs
],
max_seq_len
=
self
.
max_model_len
,
max_seq_len
=
self
.
max_model_len
,
...
@@ -484,7 +488,7 @@ def prepare_eagle_decode(
...
@@ -484,7 +488,7 @@ def prepare_eagle_decode(
input_buffers
.
positions
,
input_buffers
.
positions
,
input_hidden_states
,
input_hidden_states
,
input_hidden_states
.
stride
(
0
),
input_hidden_states
.
stride
(
0
),
input_buffers
.
query_start_loc
.
gpu
,
input_buffers
.
query_start_loc
,
input_buffers
.
seq_lens
,
input_buffers
.
seq_lens
,
hidden_size
,
hidden_size
,
max_model_len
,
max_model_len
,
...
...
vllm/v1/worker/gpu/states.py
View file @
025a32f9
...
@@ -8,10 +8,8 @@ import torch
...
@@ -8,10 +8,8 @@ import torch
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
is_uva_available
from
vllm.utils.torch_utils
import
get_cuda_view_from_cpu_tensor
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.
utils
import
CpuGpuBuffe
r
from
vllm.v1.
worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTenso
r
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu.sample.penalties
import
bincount
from
vllm.v1.worker.gpu.sample.penalties
import
bincount
...
@@ -29,7 +27,6 @@ class RequestState:
...
@@ -29,7 +27,6 @@ class RequestState:
num_speculative_steps
:
int
,
num_speculative_steps
:
int
,
vocab_size
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
pin_memory
:
bool
,
):
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
...
@@ -37,7 +34,6 @@ class RequestState:
...
@@ -37,7 +34,6 @@ class RequestState:
self
.
num_speculative_steps
=
num_speculative_steps
self
.
num_speculative_steps
=
num_speculative_steps
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
device
=
device
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
req_id_to_index
:
dict
[
str
,
int
]
=
{}
self
.
req_id_to_index
:
dict
[
str
,
int
]
=
{}
self
.
index_to_req_id
:
dict
[
int
,
str
]
=
{}
self
.
index_to_req_id
:
dict
[
int
,
str
]
=
{}
...
@@ -47,16 +43,18 @@ class RequestState:
...
@@ -47,16 +43,18 @@ class RequestState:
self
.
prompt_len
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
prompt_len
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
# depending on the configured max_num_reqs and max_model_len.
self
.
prefill_token_ids
=
UvaBuffer
(
# To save GPU memory, we use UVA instead of GPU for this tensor.
self
.
max_num_reqs
,
self
.
max_model_len
,
dtype
=
torch
.
int32
self
.
prefill_token_ids
=
StagedWriteTensor
(
(
self
.
max_num_reqs
,
self
.
max_model_len
),
dtype
=
torch
.
int32
,
device
=
device
,
uva_instead_of_gpu
=
True
,
)
)
# NOTE(woosuk): We don't use UVA for prefill_len because its GPU view
self
.
prefill_len
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
# can be used outside of update_states and prepare_inputs.
# Without async barrier, using UVA can cause race conditions.
self
.
prefill_len
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
# Number of computed tokens.
# Number of computed tokens.
self
.
num_computed_prefill_tokens
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_prefill_tokens
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens
=
torch
.
zeros
(
self
.
num_computed_tokens
=
StagedWriteTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
)
...
@@ -84,14 +82,16 @@ class RequestState:
...
@@ -84,14 +82,16 @@ class RequestState:
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
# Sampling parameters.
# Sampling parameters.
self
.
temperature
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
temperature
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
top_p
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
top_p
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
top_k
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
top_k
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int32
)
self
.
min_p
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
min_p
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
repetition_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
repetition_penalty
=
UvaBackedTensor
(
self
.
frequency_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
self
.
max_num_reqs
,
dtype
=
torch
.
float32
self
.
presence_penalty
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
float32
)
)
self
.
seeds
=
self
.
_make_param
(
self
.
max_num_reqs
,
torch
.
int64
)
self
.
frequency_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
presence_penalty
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
float32
)
self
.
seeds
=
UvaBackedTensor
(
self
.
max_num_reqs
,
dtype
=
torch
.
int64
)
self
.
num_logprobs
=
np
.
empty
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_logprobs
=
np
.
empty
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
# -1 means no logprobs are requested.
# -1 means no logprobs are requested.
...
@@ -111,13 +111,7 @@ class RequestState:
...
@@ -111,13 +111,7 @@ class RequestState:
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
self
.
max_num_reqs
,
self
.
vocab_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
def
_make_param
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
)
->
"Param"
:
self
.
_penalties_reqs
:
list
[
int
]
=
[]
return
Param
(
size
,
dtype
=
dtype
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
def
_make_buffer
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
)
->
CpuGpuBuffer
:
return
CpuGpuBuffer
(
size
,
dtype
=
dtype
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
)
@
property
@
property
def
num_reqs
(
self
)
->
int
:
def
num_reqs
(
self
)
->
int
:
...
@@ -144,12 +138,9 @@ class RequestState:
...
@@ -144,12 +138,9 @@ class RequestState:
f
"prefill_len
{
prefill_len
}
< prompt_len
{
prompt_len
}
"
f
"prefill_len
{
prefill_len
}
< prompt_len
{
prompt_len
}
"
)
)
self
.
prefill_len
.
np
[
req_idx
]
=
prefill_len
self
.
prefill_len
.
np
[
req_idx
]
=
prefill_len
self
.
prefill_token_ids
.
np
[
req_idx
,
:
prefill_len
]
=
prefill_token_ids
self
.
prefill_token_ids
.
stage_write
(
req_idx
,
0
,
prefill_token_ids
)
self
.
num_computed_prefill_tokens
[
req_idx
]
=
num_computed_tokens
self
.
num_computed_prefill_tokens
[
req_idx
]
=
num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
self
.
num_computed_tokens
.
stage_write_elem
(
req_idx
,
num_computed_tokens
)
# Optimize this.
self
.
num_computed_tokens
[
req_idx
]
=
num_computed_tokens
if
lora_request
is
not
None
:
if
lora_request
is
not
None
:
self
.
lora_ids
[
req_idx
]
=
lora_request
.
lora_int_id
self
.
lora_ids
[
req_idx
]
=
lora_request
.
lora_int_id
...
@@ -169,13 +160,7 @@ class RequestState:
...
@@ -169,13 +160,7 @@ class RequestState:
self
.
presence_penalty
.
np
[
req_idx
]
=
sampling_params
.
presence_penalty
self
.
presence_penalty
.
np
[
req_idx
]
=
sampling_params
.
presence_penalty
if
use_penalty
(
sampling_params
):
if
use_penalty
(
sampling_params
):
bincount
(
self
.
_penalties_reqs
.
append
(
req_idx
)
self
.
prefill_token_ids
.
gpu
[
req_idx
],
prefill_len
,
prompt_len
,
self
.
prompt_bin_mask
[
req_idx
],
self
.
output_bin_counts
[
req_idx
],
)
if
sampling_params
.
seed
is
not
None
:
if
sampling_params
.
seed
is
not
None
:
seed
=
sampling_params
.
seed
seed
=
sampling_params
.
seed
...
@@ -193,6 +178,22 @@ class RequestState:
...
@@ -193,6 +178,22 @@ class RequestState:
needs_prompt_logprobs
=
sampling_params
.
prompt_logprobs
is
not
None
needs_prompt_logprobs
=
sampling_params
.
prompt_logprobs
is
not
None
self
.
needs_prompt_logprobs
[
req_idx
]
=
needs_prompt_logprobs
self
.
needs_prompt_logprobs
[
req_idx
]
=
needs_prompt_logprobs
def
apply_staged_writes
(
self
)
->
None
:
self
.
prefill_len
.
copy_to_uva
()
self
.
prefill_token_ids
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
# TODO(woosuk): Optimize this.
for
req_idx
in
self
.
_penalties_reqs
:
bincount
(
self
.
prefill_token_ids
.
gpu
[
req_idx
],
int
(
self
.
prefill_len
.
np
[
req_idx
]),
int
(
self
.
prompt_len
[
req_idx
]),
self
.
prompt_bin_mask
[
req_idx
],
self
.
output_bin_counts
[
req_idx
],
)
self
.
_penalties_reqs
.
clear
()
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
self
.
extra_data
.
pop
(
req_id
,
None
)
self
.
extra_data
.
pop
(
req_id
,
None
)
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
...
@@ -208,30 +209,25 @@ class RequestState:
...
@@ -208,30 +209,25 @@ class RequestState:
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
)
->
SamplingMetadata
:
)
->
SamplingMetadata
:
temperature
=
self
.
temperature
.
np
[
idx_mapping_np
]
temperature
=
self
.
temperature
.
copy_to_uva
()
temperature
=
self
.
temperature
.
copy_np_to_gpu
(
temperature
)
top_p
=
self
.
top_p
.
np
[
idx_mapping_np
]
top_p
=
self
.
top_p
.
np
[
idx_mapping_np
]
no_top_p
=
np
.
all
(
top_p
==
1.0
)
no_top_p
=
np
.
all
(
top_p
==
1.0
)
top_p
=
self
.
top_p
.
copy_
np_to_gpu
(
top_p
)
if
not
no_top_p
else
None
top_p
=
self
.
top_p
.
copy_
to_uva
()[
idx_mapping
]
if
not
no_top_p
else
None
top_k
=
self
.
top_k
.
np
[
idx_mapping_np
]
top_k
=
self
.
top_k
.
np
[
idx_mapping_np
]
no_top_k
=
np
.
all
(
top_k
==
self
.
vocab_size
)
no_top_k
=
np
.
all
(
top_k
==
self
.
vocab_size
)
top_k
=
self
.
top_k
.
copy_
np_to_gpu
(
top_k
)
if
not
no_top_k
else
None
top_k
=
self
.
top_k
.
copy_
to_uva
()[
idx_mapping
]
if
not
no_top_k
else
None
min_p
=
self
.
min_p
.
np
[
idx_mapping_np
]
min_p
=
self
.
min_p
.
np
[
idx_mapping_np
]
no_min_p
=
np
.
all
(
min_p
==
0.0
)
no_min_p
=
np
.
all
(
min_p
==
0.0
)
min_p
=
self
.
min_p
.
copy_
np_to_gpu
(
min_p
)
if
not
no_min_p
else
None
min_p
=
self
.
min_p
.
copy_
to_uva
(
)
if
not
no_min_p
else
None
rep_penalty
=
self
.
repetition_penalty
.
np
[
idx_mapping_np
]
rep_penalty
=
self
.
repetition_penalty
.
copy_to_uva
()
rep_penalty
=
self
.
repetition_penalty
.
copy_np_to_gpu
(
rep_penalty
)
freq_penalty
=
self
.
frequency_penalty
.
copy_to_uva
()
freq_penalty
=
self
.
frequency_penalty
.
np
[
idx_mapping_np
]
pres_penalty
=
self
.
presence_penalty
.
copy_to_uva
()
freq_penalty
=
self
.
frequency_penalty
.
copy_np_to_gpu
(
freq_penalty
)
pres_penalty
=
self
.
presence_penalty
.
np
[
idx_mapping_np
]
pres_penalty
=
self
.
presence_penalty
.
copy_np_to_gpu
(
pres_penalty
)
seeds
=
self
.
seeds
.
np
[
idx_mapping_np
]
seeds
=
self
.
seeds
.
copy_to_uva
()
seeds
=
self
.
seeds
.
copy_np_to_gpu
(
seeds
)
num_logprobs
=
self
.
num_logprobs
[
idx_mapping_np
]
num_logprobs
=
self
.
num_logprobs
[
idx_mapping_np
]
max_num_logprobs
:
int
|
None
=
int
(
np
.
max
(
num_logprobs
))
max_num_logprobs
:
int
|
None
=
int
(
np
.
max
(
num_logprobs
))
...
@@ -239,6 +235,7 @@ class RequestState:
...
@@ -239,6 +235,7 @@ class RequestState:
max_num_logprobs
=
None
max_num_logprobs
=
None
return
SamplingMetadata
(
return
SamplingMetadata
(
idx_mapping
=
idx_mapping
,
temperature
=
temperature
,
temperature
=
temperature
,
top_p
=
top_p
,
top_p
=
top_p
,
top_k
=
top_k
,
top_k
=
top_k
,
...
@@ -246,12 +243,11 @@ class RequestState:
...
@@ -246,12 +243,11 @@ class RequestState:
repetition_penalty
=
rep_penalty
,
repetition_penalty
=
rep_penalty
,
frequency_penalty
=
freq_penalty
,
frequency_penalty
=
freq_penalty
,
presence_penalty
=
pres_penalty
,
presence_penalty
=
pres_penalty
,
prompt_bin_mask
=
self
.
prompt_bin_mask
,
output_bin_counts
=
self
.
output_bin_counts
,
seeds
=
seeds
,
seeds
=
seeds
,
pos
=
pos
,
pos
=
pos
,
max_num_logprobs
=
max_num_logprobs
,
max_num_logprobs
=
max_num_logprobs
,
idx_mapping
=
idx_mapping
,
prompt_bin_mask
=
self
.
prompt_bin_mask
,
output_bin_counts
=
self
.
output_bin_counts
,
)
)
def
make_lora_inputs
(
def
make_lora_inputs
(
...
@@ -272,42 +268,12 @@ class RequestState:
...
@@ -272,42 +268,12 @@ class RequestState:
return
prompt_lora_mapping
,
token_lora_mapping
,
active_lora_requests
return
prompt_lora_mapping
,
token_lora_mapping
,
active_lora_requests
class
Param
:
def
__init__
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
pin_memory
:
bool
,
):
self
.
buffer
=
CpuGpuBuffer
(
size
,
dtype
=
dtype
,
device
=
device
,
pin_memory
=
pin_memory
,
)
self
.
np
=
np
.
zeros_like
(
self
.
buffer
.
np
)
def
copy_np_to_gpu
(
self
,
x
:
np
.
ndarray
)
->
torch
.
Tensor
:
n
=
x
.
shape
[
0
]
self
.
buffer
.
np
[:
n
]
=
x
return
self
.
buffer
.
copy_to_gpu
(
n
)
@
dataclass
@
dataclass
class
ExtraData
:
class
ExtraData
:
lora_request
:
LoRARequest
|
None
lora_request
:
LoRARequest
|
None
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
in_progress_prompt_logprobs
:
list
[
LogprobsTensors
]
=
field
(
default_factory
=
list
)
class
UvaBuffer
:
def
__init__
(
self
,
*
size
:
int
|
torch
.
SymInt
,
dtype
:
torch
.
dtype
):
assert
is_uva_available
()
self
.
cpu
=
torch
.
zeros
(
*
size
,
dtype
=
dtype
,
device
=
"cpu"
,
pin_memory
=
True
)
self
.
np
=
self
.
cpu
.
numpy
()
self
.
gpu
=
get_cuda_view_from_cpu_tensor
(
self
.
cpu
)
def
use_penalty
(
sampling_params
:
SamplingParams
)
->
bool
:
def
use_penalty
(
sampling_params
:
SamplingParams
)
->
bool
:
return
(
return
(
sampling_params
.
repetition_penalty
!=
1.0
sampling_params
.
repetition_penalty
!=
1.0
...
...
vllm/v1/worker/gpu/structured_outputs.py
View file @
025a32f9
...
@@ -4,38 +4,65 @@ import numpy as np
...
@@ -4,38 +4,65 @@ import numpy as np
import
torch
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.worker.gpu.input_batch
import
InputBuffers
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.worker.gpu.buffer_utils
import
UvaBufferPool
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
def
apply_grammar_bitmask
(
class
StructuredOutputsWorker
:
logits
:
torch
.
Tensor
,
def
__init__
(
req_ids
:
list
[
str
],
self
,
grammar_req_ids
:
list
[
str
],
max_num_logits
:
int
,
grammar_bitmask
:
np
.
ndarray
,
vocab_size
:
int
,
input_buffers
:
InputBuffers
,
):
)
->
None
:
# NOTE(woosuk): Here, we use UvaBufferPool instead of UvaBackedTensor
input_buffers
.
grammar_bitmask
.
np
[:
grammar_bitmask
.
shape
[
0
]]
=
grammar_bitmask
# to save a unnecessary CPU-to-CPU copy.
input_buffers
.
grammar_bitmask
.
copy_to_gpu
(
grammar_bitmask
.
shape
[
0
])
self
.
logits_indices
=
UvaBufferPool
(
max_num_logits
,
torch
.
int32
)
self
.
grammar_bitmask
=
UvaBufferPool
(
(
max_num_logits
,
cdiv
(
vocab_size
,
32
)),
torch
.
int32
)
batch_size
=
logits
.
shape
[
0
]
def
apply_grammar_bitmask
(
grammar_req_id_to_idx
=
{
req_id
:
i
for
i
,
req_id
in
enumerate
(
grammar_req_ids
)}
self
,
# logits -> bitmask mapping
logits
:
torch
.
Tensor
,
mapping
=
[
grammar_req_id_to_idx
.
get
(
req_id
,
-
1
)
for
req_id
in
req_ids
]
input_batch
:
InputBatch
,
input_buffers
.
bitmask_indices
.
np
[:
batch_size
]
=
mapping
grammar_req_ids
:
list
[
str
],
input_buffers
.
bitmask_indices
.
copy_to_gpu
(
batch_size
)
grammar_bitmask
:
np
.
ndarray
,
)
->
None
:
if
not
grammar_req_ids
:
return
vocab_size
=
logits
.
shape
[
-
1
]
# Construct bitmask -> logits mapping
BLOCK_SIZE
=
8192
mapping
:
list
[
int
]
=
[]
grid
=
(
batch_size
,
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
))
req_ids
=
input_batch
.
req_ids
_apply_grammar_bitmask_kernel
[
grid
](
cu_num_logits
=
input_batch
.
cu_num_logits_np
.
tolist
()
logits
,
req_id_to_idx
=
{
req_id
:
i
for
i
,
req_id
in
enumerate
(
req_ids
)}
logits
.
stride
(
0
),
for
grammar_req_id
in
grammar_req_ids
:
input_buffers
.
grammar_bitmask
.
gpu
,
req_idx
=
req_id_to_idx
[
grammar_req_id
]
input_buffers
.
grammar_bitmask
.
gpu
.
stride
(
0
),
logits_start_idx
=
cu_num_logits
[
req_idx
]
input_buffers
.
bitmask_indices
.
gpu
,
logits_end_idx
=
cu_num_logits
[
req_idx
+
1
]
vocab_size
,
mapping
.
extend
(
range
(
logits_start_idx
,
logits_end_idx
))
BLOCK_SIZE
=
BLOCK_SIZE
,
# Copy the mapping.
)
mapping_np
=
np
.
array
(
mapping
,
dtype
=
np
.
int32
)
logits_indices
=
self
.
logits_indices
.
copy_to_uva
(
mapping_np
)
# Copy the bitmask.
bitmask
=
self
.
grammar_bitmask
.
copy_to_uva
(
grammar_bitmask
)
num_masks
=
bitmask
.
shape
[
0
]
assert
num_masks
==
len
(
mapping
)
vocab_size
=
logits
.
shape
[
-
1
]
BLOCK_SIZE
=
8192
grid
=
(
num_masks
,
triton
.
cdiv
(
vocab_size
,
BLOCK_SIZE
))
_apply_grammar_bitmask_kernel
[
grid
](
logits
,
logits
.
stride
(
0
),
logits_indices
,
bitmask
,
bitmask
.
stride
(
0
),
vocab_size
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
# Adapted from
# Adapted from
...
@@ -44,17 +71,14 @@ def apply_grammar_bitmask(
...
@@ -44,17 +71,14 @@ def apply_grammar_bitmask(
def
_apply_grammar_bitmask_kernel
(
def
_apply_grammar_bitmask_kernel
(
logits_ptr
,
logits_ptr
,
logits_stride
,
logits_stride
,
logits_indices_ptr
,
bitmask_ptr
,
bitmask_ptr
,
bitmask_stride
,
bitmask_stride
,
bitmask_indices_ptr
,
vocab_size
,
vocab_size
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
):
logits_idx
=
tl
.
program_id
(
0
)
bitmask_idx
=
tl
.
program_id
(
0
)
bitmask_idx
=
tl
.
load
(
bitmask_indices_ptr
+
logits_idx
)
logits_idx
=
tl
.
load
(
logits_indices_ptr
+
bitmask_idx
)
if
bitmask_idx
==
-
1
:
# No bitmask to apply.
return
# Load the bitmask.
# Load the bitmask.
block_id
=
tl
.
program_id
(
1
)
block_id
=
tl
.
program_id
(
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment