Commit a7347d9a (unverified) in norm/vllm
Authored by Antoni Baum on Dec 17, 2023; committed via GitHub on Dec 17, 2023
Make sampler less blocking (#1889)
Parent: f8c688d7
Showing 2 changed files, with 310 additions and 198 deletions:

- vllm/model_executor/layers/sampler.py: +123 -198
- vllm/model_executor/sampling_metadata.py: +187 -0
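The core idea of the commit: instead of synchronously materializing many small tensors on the GPU during sampling, the sampling parameters are batched into a handful of CPU tensors allocated in pinned (page-locked) memory and copied to the device asynchronously. A minimal sketch of that pattern, independent of the vLLM code below (variable names here are illustrative, not vLLM's):

```python
import torch

# Sketch of the pinned-memory + non-blocking copy pattern this commit
# applies inside SamplingTensors.from_lists.
per_sequence_temperatures = [0.7, 1.0, 0.9]

if torch.cuda.is_available():
    device = torch.device("cuda")
    # Pinned host memory is required for a truly asynchronous
    # host-to-device copy.
    cpu_t = torch.tensor(per_sequence_temperatures,
                         dtype=torch.float32,
                         pin_memory=True)
    # The copy is enqueued on the current CUDA stream and overlaps with
    # subsequent CPU work; later kernels on the same stream see the data.
    gpu_t = cpu_t.to(device=device, non_blocking=True)
```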
vllm/model_executor/layers/sampler.py @ a7347d9a

(This diff is collapsed on the original page.)
vllm/model_executor/sampling_metadata.py @ a7347d9a
```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl

_SAMPLING_EPS = 1e-5


class SamplingMetadata:
    ...
```
@@ -41,3 +45,186 @@ class SamplingMetadata:

```python
                f"prompt_lens={self.prompt_lens}, "
                f"selected_token_indices={self.selected_token_indices}, "
                f"categorized_sample_indices={self.categorized_sample_indices})")
```
```python
@dataclass
class SamplingTensors:
    """Tensors for sampling."""
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
            cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
            device: torch.device,
            dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
        prompt_tokens: List[List[int]] = []
        output_tokens: List[List[int]] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            seq_ids, sampling_params = seq_group
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            min_p = sampling_params.min_p
            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True
            if (i < sampling_metadata.num_prompts
                    and sampling_params.prompt_logprobs is not None):
                # For prompt tokens, we only need to get their logprobs.
                prompt_len = sampling_metadata.prompt_lens[i]
                temperatures += [temperature] * (prompt_len - 1)
                top_ps += [top_p] * (prompt_len - 1)
                top_ks += [top_k] * (prompt_len - 1)
                min_ps += [min_p] * (prompt_len - 1)
                presence_penalties += [0] * (prompt_len - 1)
                frequency_penalties += [0] * (prompt_len - 1)
                repetition_penalties += [1] * (prompt_len - 1)
                prompt_tokens.extend([] for _ in range(prompt_len - 1))
                output_tokens.extend([] for _ in range(prompt_len - 1))
            for seq_id in seq_ids:
                seq_data = sampling_metadata.seq_data[seq_id]
                prompt_tokens.append(seq_data.prompt_token_ids)
                output_tokens.append(seq_data.output_token_ids)
            temperatures += [temperature] * len(seq_ids)
            top_ps += [top_p] * len(seq_ids)
            top_ks += [top_k] * len(seq_ids)
            min_ps += [min_p] * len(seq_ids)
            presence_penalties += [p] * len(seq_ids)
            frequency_penalties += [f] * len(seq_ids)
            repetition_penalties += [r] * len(seq_ids)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures, top_ps, top_ks, min_ps, presence_penalties,
            frequency_penalties, repetition_penalties, prompt_tokens,
            output_tokens, vocab_size, device, dtype)
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
```
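Note that `do_penalties`, `do_top_p_top_k`, and `do_min_p` are derived from the Python-side lists before anything touches the GPU, so the caller can skip entire sampling kernels without a device synchronization. A hypothetical caller-side sketch (the `apply_*` helpers are placeholders, not functions from this diff):

```python
# Hypothetical usage; logits is assumed to be a (num_tokens, vocab_size)
# tensor already on the target device.
(sampling_tensors, do_penalties, do_top_p_top_k,
 do_min_p) = SamplingTensors.from_sampling_metadata(
     sampling_metadata, vocab_size, logits.device, logits.dtype)
if do_penalties:
    logits = apply_penalties(logits, sampling_tensors)    # placeholder
if do_top_p_top_k:
    logits = apply_top_p_top_k(logits, sampling_tensors)  # placeholder
```

The class body continues with `from_lists`, which builds the pinned CPU tensors and issues the non-blocking copies: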
```python
    @classmethod
    def from_lists(cls, temperatures: List[float], top_ps: List[float],
                   top_ks: List[int], min_ps: List[float],
                   presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float],
                   prompt_tokens: List[List[int]],
                   output_tokens: List[List[int]], vocab_size: int,
                   device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = not in_wsl()
        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
        prompt_padded_tokens = [
            tokens + [vocab_size] * (prompt_max_len - len(tokens))
            for tokens in prompt_tokens
        ]
        output_max_len = max(len(tokens) for tokens in output_tokens)
        output_padded_tokens = [
            tokens + [vocab_size] * (output_max_len - len(tokens))
            for tokens in output_tokens
        ]

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        prompt_tensor = torch.tensor(
            prompt_padded_tokens,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        output_tensor = torch.tensor(
            output_padded_tokens,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        # Because the memory is pinned, we can do non-blocking
        # transfer to device.
        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
            output_tokens=output_tensor.to(device=device, non_blocking=True),
        )
```
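One detail worth noting: token lists are right-padded with `vocab_size`, an id one past the real vocabulary, so the rows stay rectangular. A consumer can then scatter into a `(vocab_size + 1)`-wide buffer and drop the last column to neutralize the padding. The actual penalty kernels live in `sampler.py` (collapsed above); the following is only a sketch of that masking idea, not code from this diff:

```python
import torch

# Sketch: count token occurrences per row while ignoring padding ids,
# which are all equal to vocab_size and land in an extra scratch column.
vocab_size = 8
padded = torch.tensor([[1, 3, 3, vocab_size, vocab_size]])  # one padded row

counts = torch.zeros(1, vocab_size + 1, dtype=torch.long)
counts.scatter_add_(1, padded, torch.ones_like(padded))
counts = counts[:, :vocab_size]  # the padding column is discarded
```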