norm / vllm · Commits · a7347d9a
"...text-generation-inference.git" did not exist on "a072660bf51022f9e1d59b64efc5954a0e1eee45"
Unverified commit a7347d9a, authored Dec 17, 2023 by Antoni Baum, committed by GitHub on Dec 17, 2023

Make sampler less blocking (#1889)
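For context, the change keeps the sampler from stalling on host-to-device traffic: per-request sampling parameters are batched into CPU tensors backed by pinned (page-locked) memory and copied to the GPU asynchronously, instead of triggering many small synchronizing copies. A minimal sketch of the pattern, assuming a CUDA device (names here are illustrative, not from the diff):

    import torch

    device = torch.device("cuda")
    temperatures = [1.0, 0.7, 0.8]

    # Copying out of ordinary pageable host memory blocks the calling
    # thread until the transfer finishes.
    t_blocking = torch.tensor(temperatures).to(device)

    # The pattern this commit adopts: stage the values in pinned host
    # memory, then issue an asynchronous copy that can overlap with GPU
    # work already in flight.
    t_pinned = torch.tensor(temperatures, device="cpu", pin_memory=True)
    t_async = t_pinned.to(device=device, non_blocking=True)

The diff below applies this to every sampling parameter at once, packing them into a handful of batched tensors rather than copying per request.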
parent f8c688d7

Showing 2 changed files with 310 additions and 198 deletions (+310 −198)
vllm/model_executor/layers/sampler.py (+123 −198)
vllm/model_executor/sampling_metadata.py (+187 −0)
vllm/model_executor/layers/sampler.py @ a7347d9a (diff collapsed)
vllm/model_executor/sampling_metadata.py @ a7347d9a
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl

_SAMPLING_EPS = 1e-5


class SamplingMetadata:

@@ -41,3 +45,186 @@ class SamplingMetadata:
            f"prompt_lens={self.prompt_lens}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices})")
@dataclass
class SamplingTensors:
    """Tensors for sampling."""
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor
    @classmethod
    def from_sampling_metadata(
            cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
            device: torch.device,
            dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
        prompt_tokens: List[List[int]] = []
        output_tokens: List[List[int]] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            seq_ids, sampling_params = seq_group
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            min_p = sampling_params.min_p
            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True
            if (i < sampling_metadata.num_prompts
                    and sampling_params.prompt_logprobs is not None):
                # For tokens in the prompt that we only need to get
                # their logprobs.
                prompt_len = sampling_metadata.prompt_lens[i]
                temperatures += [temperature] * (prompt_len - 1)
                top_ps += [top_p] * (prompt_len - 1)
                top_ks += [top_k] * (prompt_len - 1)
                min_ps += [min_p] * (prompt_len - 1)
                presence_penalties += [0] * (prompt_len - 1)
                frequency_penalties += [0] * (prompt_len - 1)
                repetition_penalties += [1] * (prompt_len - 1)
                prompt_tokens.extend([] for _ in range(prompt_len - 1))
                output_tokens.extend([] for _ in range(prompt_len - 1))
            for seq_id in seq_ids:
                seq_data = sampling_metadata.seq_data[seq_id]
                prompt_tokens.append(seq_data.prompt_token_ids)
                output_tokens.append(seq_data.output_token_ids)
            temperatures += [temperature] * len(seq_ids)
            top_ps += [top_p] * len(seq_ids)
            top_ks += [top_k] * len(seq_ids)
            min_ps += [min_p] * len(seq_ids)
            presence_penalties += [p] * len(seq_ids)
            frequency_penalties += [f] * len(seq_ids)
            repetition_penalties += [r] * len(seq_ids)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures, top_ps, top_ks, min_ps, presence_penalties,
            frequency_penalties, repetition_penalties, prompt_tokens,
            output_tokens, vocab_size, device, dtype)
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
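Alongside the packed tensors, from_sampling_metadata returns three booleans so the caller can skip whole sampling stages when no sequence in the batch needs them. A hedged sketch of that gating idea follows; the real consumer is sampler.py (diff collapsed above), and the masking here is a simplified stand-in rather than vLLM's implementation:

    import torch

    def gate_sampling_stages(logits: torch.Tensor, temperatures: torch.Tensor,
                             top_ks: torch.Tensor, min_ps: torch.Tensor,
                             do_top_p_top_k: bool,
                             do_min_p: bool) -> torch.Tensor:
        # Temperature scaling always runs; zero temperatures were rewritten
        # to 1.0 upstream, so the division is safe.
        logits = logits / temperatures.unsqueeze(dim=1)
        if do_top_p_top_k:
            # Simplified: keep each row's top-k logits (one k for the batch).
            kth = logits.topk(int(top_ks.max()), dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        if do_min_p:
            # Mask tokens whose probability is below min_p * max probability.
            probs = logits.softmax(dim=-1)
            cutoff = (min_ps.unsqueeze(dim=1) *
                      probs.max(dim=-1, keepdim=True).values)
            logits = logits.masked_fill(probs < cutoff, float("-inf"))
        return logits

    out = gate_sampling_stages(torch.randn(2, 16),
                               temperatures=torch.tensor([1.0, 0.7]),
                               top_ks=torch.tensor([4, 4]),
                               min_ps=torch.tensor([0.05, 0.0]),
                               do_top_p_top_k=True, do_min_p=False)

With an all-defaults batch, every flag is False and none of these extra kernels launch.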
    @classmethod
    def from_lists(cls, temperatures: List[float], top_ps: List[float],
                   top_ks: List[int], min_ps: List[float],
                   presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float],
                   prompt_tokens: List[List[int]],
                   output_tokens: List[List[int]], vocab_size: int,
                   device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = not in_wsl()
        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
        prompt_padded_tokens = [
            tokens + [vocab_size] * (prompt_max_len - len(tokens))
            for tokens in prompt_tokens
        ]
        output_max_len = max(len(tokens) for tokens in output_tokens)
        output_padded_tokens = [
            tokens + [vocab_size] * (output_max_len - len(tokens))
            for tokens in output_tokens
        ]

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        prompt_tensor = torch.tensor(
            prompt_padded_tokens,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        output_tensor = torch.tensor(
            output_padded_tokens,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        # Because the memory is pinned, we can do non-blocking
        # transfer to device.
        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
            output_tokens=output_tensor.to(device=device, non_blocking=True),
        )
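One detail worth noting: from_lists pads the ragged token lists to a rectangle with vocab_size, an id one past the largest valid token. An out-of-range pad id lets downstream penalty code histogram token counts into one extra bin and slice the padding off afterwards. A small illustration of why that works (hypothetical values, not code from this commit):

    import torch

    vocab_size = 8
    # Two sequences' output tokens, padded with vocab_size (an invalid id).
    output_tokens = torch.tensor([[1, 3, 3, vocab_size],
                                  [2, 2, 2, 2]])
    # Histogram into vocab_size + 1 bins; all padding lands in the last bin.
    counts = torch.zeros(2, vocab_size + 1, dtype=torch.long)
    counts.scatter_add_(1, output_tokens, torch.ones_like(output_tokens))
    counts = counts[:, :vocab_size]  # drop the padding bin
    # counts[0] -> [0, 1, 0, 2, 0, 0, 0, 0]
    # counts[1] -> [0, 0, 4, 0, 0, 0, 0, 0]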