Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1f1b6d6e
Unverified
Commit
1f1b6d6e
authored
Nov 03, 2024
by
Nick Hill
Committed by
GitHub
Nov 03, 2024
Browse files
[V1] Support per-request seed (#9945)
Signed-off-by:
Nick Hill
<
nickhill@us.ibm.com
>
parent
3bb4befe
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
48 deletions
+41
-48
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+2
-3
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+10
-13
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+29
-32
No files found.
vllm/v1/sample/metadata.py
View file @
1f1b6d6e
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
from
typing
import
Dict
import
torch
import
torch
...
@@ -16,7 +16,6 @@ class SamplingMetadata:
...
@@ -16,7 +16,6 @@ class SamplingMetadata:
no_top_p
:
bool
no_top_p
:
bool
no_top_k
:
bool
no_top_k
:
bool
generators
:
List
[
Optional
[
torch
.
Generator
]]
generators
:
Dict
[
int
,
torch
.
Generator
]
no_generator
:
bool
max_num_logprobs
:
int
max_num_logprobs
:
int
vllm/v1/sample/sampler.py
View file @
1f1b6d6e
"""A layer that samples the next tokens from the model's outputs."""
"""A layer that samples the next tokens from the model's outputs."""
from
typing
import
List
,
Optional
from
typing
import
Dict
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -84,22 +84,21 @@ class Sampler(nn.Module):
...
@@ -84,22 +84,21 @@ class Sampler(nn.Module):
def
random_sample
(
def
random_sample
(
self
,
self
,
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
generators
:
Dict
[
int
,
torch
.
Generator
],
no_generator
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
q
=
torch
.
empty_like
(
probs
)
q
=
torch
.
empty_like
(
probs
)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
# that have their own seeds.
q
.
exponential_
()
if
len
(
generators
)
!=
probs
.
shape
[
0
]:
if
not
no_generator
:
# This might still be done here unnecessarily if there are greedies
assert
len
(
generators
)
==
probs
.
shape
[
0
]
q
.
exponential_
()
if
generators
:
# TODO(woosuk): This can be slow because we handle each request
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
# one by one. Optimize this.
for
i
,
generator
in
enumerate
(
generators
):
for
i
,
generator
in
generators
.
items
():
if
generator
is
not
None
:
q
[
i
].
exponential_
(
generator
=
generator
)
q
[
i
].
exponential_
(
generator
=
generator
)
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
)
return
probs
.
div_
(
q
).
argmax
(
dim
=-
1
).
view
(
-
1
)
def
sample
(
def
sample
(
...
@@ -112,13 +111,11 @@ class Sampler(nn.Module):
...
@@ -112,13 +111,11 @@ class Sampler(nn.Module):
if
sampling_metadata
.
all_greedy
:
if
sampling_metadata
.
all_greedy
:
return
self
.
greedy_sample
(
probs
)
return
self
.
greedy_sample
(
probs
)
if
sampling_metadata
.
all_random
:
if
sampling_metadata
.
all_random
:
return
self
.
random_sample
(
probs
,
sampling_metadata
.
generators
,
return
self
.
random_sample
(
probs
,
sampling_metadata
.
generators
)
sampling_metadata
.
no_generator
)
greedy_sampled
=
self
.
greedy_sample
(
probs
)
greedy_sampled
=
self
.
greedy_sample
(
probs
)
random_sampled
=
self
.
random_sample
(
probs
,
random_sampled
=
self
.
random_sample
(
probs
,
sampling_metadata
.
generators
,
sampling_metadata
.
generators
)
sampling_metadata
.
no_generator
)
sampled
=
torch
.
where
(
sampled
=
torch
.
where
(
sampling_metadata
.
temperature
<
_SAMPLING_EPS
,
sampling_metadata
.
temperature
<
_SAMPLING_EPS
,
greedy_sampled
,
greedy_sampled
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
1f1b6d6e
...
@@ -128,13 +128,20 @@ class GPUModelRunner:
...
@@ -128,13 +128,20 @@ class GPUModelRunner:
# Add new requests to the cached states.
# Add new requests to the cached states.
for
req_data
in
scheduler_output
.
scheduled_new_reqs
:
for
req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
req_data
.
req_id
req_id
=
req_data
.
req_id
sampling_params
=
req_data
.
sampling_params
if
sampling_params
.
seed
is
not
None
:
generator
=
torch
.
Generator
(
device
=
self
.
device
)
generator
.
manual_seed
(
sampling_params
.
seed
)
else
:
generator
=
None
self
.
requests
[
req_id
]
=
CachedRequestState
(
self
.
requests
[
req_id
]
=
CachedRequestState
(
req_id
=
req_id
,
req_id
=
req_id
,
prompt_token_ids
=
req_data
.
prompt_token_ids
,
prompt_token_ids
=
req_data
.
prompt_token_ids
,
prompt
=
req_data
.
prompt
,
prompt
=
req_data
.
prompt
,
multi_modal_data
=
req_data
.
multi_modal_data
,
multi_modal_data
=
req_data
.
multi_modal_data
,
sampling_params
=
req_data
.
sampling_params
,
sampling_params
=
sampling_params
,
generator
=
None
,
# TODO
generator
=
generator
,
block_ids
=
req_data
.
block_ids
,
block_ids
=
req_data
.
block_ids
,
num_computed_tokens
=
req_data
.
num_computed_tokens
,
num_computed_tokens
=
req_data
.
num_computed_tokens
,
output_token_ids
=
[],
output_token_ids
=
[],
...
@@ -342,11 +349,9 @@ class GPUModelRunner:
...
@@ -342,11 +349,9 @@ class GPUModelRunner:
else
:
else
:
# Ignore the sampled token from the partial request.
# Ignore the sampled token from the partial request.
# Rewind the generator state as if the token was not sampled.
# Rewind the generator state as if the token was not sampled.
generator
=
self
.
input_batch
.
generators
[
i
]
generator
=
self
.
input_batch
.
generators
.
get
(
i
)
if
generator
is
not
None
:
if
generator
is
not
None
:
offset
=
generator
.
get_offset
()
generator
.
set_offset
(
generator
.
get_offset
()
-
1
)
generator
=
generator
.
set_offset
(
offset
-
1
)
self
.
input_batch
.
generators
[
i
]
=
generator
if
sampler_output
.
logprob_token_ids
is
None
:
if
sampler_output
.
logprob_token_ids
is
None
:
logprob_token_ids
=
None
logprob_token_ids
=
None
...
@@ -494,8 +499,8 @@ class InputBatch:
...
@@ -494,8 +499,8 @@ class InputBatch:
self
.
top_k_cpu
=
self
.
top_k_cpu_tensor
.
numpy
()
self
.
top_k_cpu
=
self
.
top_k_cpu_tensor
.
numpy
()
self
.
top_k_reqs
:
Set
[
str
]
=
set
()
self
.
top_k_reqs
:
Set
[
str
]
=
set
()
self
.
generators
:
List
[
Optional
[
torch
.
Generator
]]
=
[
None
# req_index -> generator
]
*
max_num_reqs
self
.
generators
:
Dict
[
int
,
torch
.
Generator
]
=
{}
self
.
num_logprobs
:
Dict
[
str
,
int
]
=
{}
self
.
num_logprobs
:
Dict
[
str
,
int
]
=
{}
self
.
prompt_logprob_reqs
:
Set
[
str
]
=
set
()
self
.
prompt_logprob_reqs
:
Set
[
str
]
=
set
()
...
@@ -509,8 +514,9 @@ class InputBatch:
...
@@ -509,8 +514,9 @@ class InputBatch:
req_index
=
self
.
num_reqs
req_index
=
self
.
num_reqs
assert
req_index
<
self
.
max_num_reqs
assert
req_index
<
self
.
max_num_reqs
self
.
req_ids
[
req_index
]
=
request
.
req_id
req_id
=
request
.
req_id
self
.
req_id_to_index
[
request
.
req_id
]
=
req_index
self
.
req_ids
[
req_index
]
=
req_id
self
.
req_id_to_index
[
req_id
]
=
req_index
# Copy the prompt token ids and output token ids.
# Copy the prompt token ids and output token ids.
num_prompt_tokens
=
len
(
request
.
prompt_token_ids
)
num_prompt_tokens
=
len
(
request
.
prompt_token_ids
)
...
@@ -528,27 +534,24 @@ class InputBatch:
...
@@ -528,27 +534,24 @@ class InputBatch:
sampling_params
=
request
.
sampling_params
sampling_params
=
request
.
sampling_params
self
.
temperature_cpu
[
req_index
]
=
sampling_params
.
temperature
self
.
temperature_cpu
[
req_index
]
=
sampling_params
.
temperature
if
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
:
if
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
:
self
.
greedy_reqs
.
add
(
req_index
)
self
.
greedy_reqs
.
add
(
req_id
)
elif
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM
:
else
:
self
.
random_reqs
.
add
(
req_index
)
self
.
random_reqs
.
add
(
req_id
)
elif
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
# TODO(woosuk): Support per-request random seed.
raise
NotImplementedError
(
"Per-request seed is not supported yet."
)
self
.
top_p_cpu
[
req_index
]
=
sampling_params
.
top_p
self
.
top_p_cpu
[
req_index
]
=
sampling_params
.
top_p
if
sampling_params
.
top_p
<
1
:
if
sampling_params
.
top_p
<
1
:
self
.
top_p_reqs
.
add
(
req_i
ndex
)
self
.
top_p_reqs
.
add
(
req_i
d
)
self
.
top_k_cpu
[
req_index
]
=
sampling_params
.
top_k
self
.
top_k_cpu
[
req_index
]
=
sampling_params
.
top_k
if
sampling_params
.
top_k
>
0
:
if
sampling_params
.
top_k
>
0
:
self
.
top_k_reqs
.
add
(
req_i
ndex
)
self
.
top_k_reqs
.
add
(
req_i
d
)
self
.
generators
[
req_index
]
=
request
.
generator
self
.
generators
[
req_index
]
=
request
.
generator
num_logprobs
=
sampling_params
.
logprobs
num_logprobs
=
sampling_params
.
logprobs
if
num_logprobs
is
not
None
and
num_logprobs
>
0
:
if
num_logprobs
is
not
None
and
num_logprobs
>
0
:
self
.
num_logprobs
[
request
.
req_id
]
=
num_logprobs
self
.
num_logprobs
[
req_id
]
=
num_logprobs
if
sampling_params
.
prompt_logprobs
:
if
sampling_params
.
prompt_logprobs
:
self
.
prompt_logprob_reqs
.
add
(
req_i
ndex
)
self
.
prompt_logprob_reqs
.
add
(
req_i
d
)
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
def
remove_request
(
self
,
req_id
:
str
)
->
Optional
[
int
]:
req_index
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
req_index
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
...
@@ -560,7 +563,7 @@ class InputBatch:
...
@@ -560,7 +563,7 @@ class InputBatch:
self
.
random_reqs
.
discard
(
req_id
)
self
.
random_reqs
.
discard
(
req_id
)
self
.
top_p_reqs
.
discard
(
req_id
)
self
.
top_p_reqs
.
discard
(
req_id
)
self
.
top_k_reqs
.
discard
(
req_id
)
self
.
top_k_reqs
.
discard
(
req_id
)
self
.
generators
[
req_index
]
=
None
self
.
generators
.
pop
(
req_index
,
None
)
self
.
num_logprobs
.
pop
(
req_id
,
None
)
self
.
num_logprobs
.
pop
(
req_id
,
None
)
self
.
prompt_logprob_reqs
.
discard
(
req_id
)
self
.
prompt_logprob_reqs
.
discard
(
req_id
)
return
req_index
return
req_index
...
@@ -612,7 +615,9 @@ class InputBatch:
...
@@ -612,7 +615,9 @@ class InputBatch:
last_req_index
]
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_k_cpu
[
empty_index
]
=
self
.
top_k_cpu
[
last_req_index
]
self
.
top_k_cpu
[
empty_index
]
=
self
.
top_k_cpu
[
last_req_index
]
self
.
generators
[
empty_index
]
=
self
.
generators
[
last_req_index
]
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
# Decrement last_req_index since it is now empty.
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
last_req_index
-=
1
...
@@ -636,8 +641,7 @@ class InputBatch:
...
@@ -636,8 +641,7 @@ class InputBatch:
top_k
=
self
.
top_k
[:
self
.
num_reqs
],
top_k
=
self
.
top_k
[:
self
.
num_reqs
],
no_top_p
=
self
.
no_top_p
,
no_top_p
=
self
.
no_top_p
,
no_top_k
=
self
.
no_top_k
,
no_top_k
=
self
.
no_top_k
,
generators
=
self
.
generators
[:
self
.
num_reqs
],
generators
=
self
.
generators
,
no_generator
=
self
.
no_generator
,
max_num_logprobs
=
self
.
max_num_logprobs
,
max_num_logprobs
=
self
.
max_num_logprobs
,
)
)
...
@@ -661,16 +665,9 @@ class InputBatch:
...
@@ -661,16 +665,9 @@ class InputBatch:
def
no_top_k
(
self
)
->
bool
:
def
no_top_k
(
self
)
->
bool
:
return
len
(
self
.
top_k_reqs
)
==
0
return
len
(
self
.
top_k_reqs
)
==
0
@
property
def
no_generator
(
self
)
->
bool
:
return
len
(
self
.
generators
)
==
0
@
property
@
property
def
max_num_logprobs
(
self
)
->
int
:
def
max_num_logprobs
(
self
)
->
int
:
if
self
.
num_logprobs
:
return
max
(
self
.
num_logprobs
.
values
())
if
self
.
num_logprobs
else
0
return
max
(
self
.
num_logprobs
.
values
())
else
:
return
0
@
property
@
property
def
no_logprob
(
self
)
->
bool
:
def
no_logprob
(
self
)
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment