norm / vllm · Commits

Commit f1878779 (Unverified)
[FIX] Simplify sampler logic (#1156)
Authored Sep 23, 2023 by Zhuohan Li; committed by GitHub on Sep 23, 2023
Parent: 947b7941

Showing 1 changed file with 10 additions and 37 deletions: vllm/model_executor/layers/sampler.py (+10 −37)
@@ -133,37 +133,22 @@ def _get_penalties(
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            presence_penalties.append(p)
-            frequency_penalties.append(f)
-        else:
-            # A generation token.
-            presence_penalties += [p] * len(seq_ids)
-            frequency_penalties += [f] * len(seq_ids)
+        presence_penalties += [p] * len(seq_ids)
+        frequency_penalties += [f] * len(seq_ids)
     return presence_penalties, frequency_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
     output_tokens: List[List[int]] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, _ = seq_group
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            # NOTE: While the prompt input usually has no output tokens,
-            # it may have output tokens in the case of recomputation.
-            seq_id = seq_ids[0]
-            seq_data = input_metadata.seq_data[seq_id]
-            output_tokens.append(seq_data.output_token_ids)
-        else:
-            # A generation token.
-            for seq_id in seq_ids:
-                seq_data = input_metadata.seq_data[seq_id]
-                output_tokens.append(seq_data.output_token_ids)
+        for seq_id in seq_ids:
+            seq_data = input_metadata.seq_data[seq_id]
+            output_tokens.append(seq_data.output_token_ids)
     return output_tokens
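The pattern in both functions is the same: instead of branching on whether a sequence group is a prompt or a generation group, every group now contributes one entry per sequence via len(seq_ids). A minimal sketch of that unified expansion, using made-up stand-in data rather than vLLM's InputMetadata:

from typing import List, Tuple

# Stand-in for input_metadata.seq_groups as (seq_ids, presence_penalty) pairs.
# The data and names here are illustrative only, not the vLLM API.
seq_groups: List[Tuple[List[int], float]] = [
    ([0], 0.5),     # a prompt group with a single sequence
    ([1, 2], 1.0),  # a generation group with two parallel sequences
]

presence_penalties: List[float] = []
for seq_ids, p in seq_groups:
    # One penalty entry per sequence, prompt or generation alike:
    presence_penalties += [p] * len(seq_ids)

print(presence_penalties)  # [0.5, 1.0, 1.0]

Since a prompt group holds a single sequence at sampling time, the expansion degenerates to the old single append in the prompt case, which is why the branch could be dropped.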
@@ -221,7 +206,7 @@ def _apply_penalties(

 def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
         if temperature < _SAMPLING_EPS:
@@ -229,13 +214,7 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            temperatures.append(temperature)
-        else:
-            # A generation token.
-            temperatures += [temperature] * len(seq_ids)
+        temperatures += [temperature] * len(seq_ids)
     return temperatures
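The _SAMPLING_EPS guard this hunk leaves in place maps a near-zero temperature to 1.0 so that the later division by temperature is safe; greedy sampling or beam search ignores the scaling anyway. A hedged sketch of that guard (the epsilon value and function name are assumed for illustration, not taken from vLLM):

from typing import List

_SAMPLING_EPS = 1e-5  # assumed value for illustration

def scale_logits(logits: List[float], temperature: float) -> List[float]:
    # Near-zero temperature signals greedy sampling or beam search; dividing
    # by it would explode, so substitute 1.0 and let argmax decide elsewhere.
    if temperature < _SAMPLING_EPS:
        temperature = 1.0
    return [logit / temperature for logit in logits]

print(scale_logits([2.0, 1.0], 0.0))  # [2.0, 1.0] -- passed through unscaled
print(scale_logits([2.0, 1.0], 0.5))  # [4.0, 2.0]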
@@ -245,21 +224,15 @@ def _get_top_p_top_k(
 ) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            top_ps.append(top_p)
-            top_ks.append(top_k)
-        else:
-            # A generation token.
-            top_ps += [top_p] * len(seq_ids)
-            top_ks += [top_k] * len(seq_ids)
+        top_ps += [top_p] * len(seq_ids)
+        top_ks += [top_k] * len(seq_ids)
     return top_ps, top_ks
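The two top_k lines kept as context clamp the requested k into a usable range: never larger than the vocabulary, with -1 mapped to the full vocabulary (no truncation). A small sketch of that resolution, with an assumed vocabulary size and a hypothetical helper name:

vocab_size = 32000  # assumed for illustration

def resolve_top_k(requested: int) -> int:
    # k should not be greater than the vocab size; k == -1 means no truncation.
    top_k = min(requested, vocab_size)
    return vocab_size if top_k == -1 else top_k

print(resolve_top_k(50))     # 50
print(resolve_top_k(-1))     # 32000 (no truncation)
print(resolve_top_k(10**9))  # 32000 (clamped to the vocabulary)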