vllm / Commits / Commit 69be658b (unverified)

Support repetition_penalty (#1424)

Authored Oct 30, 2023 by ljss; committed via GitHub Oct 29, 2023.
Parent: beac8dd4

Showing 2 changed files with 35 additions and 7 deletions:
  vllm/model_executor/layers/sampler.py  (+25, -7)
  vllm/sampling_params.py                (+10, -0)
vllm/model_executor/layers/sampler.py
@@ -50,12 +50,13 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(
-            input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = (
+            _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
+        assert len(repetition_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties)
+                                  frequency_penalties, repetition_penalties)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)

@@ -134,14 +135,17 @@ def _prune_hidden_states(

 def _get_penalties(
-        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
+    input_metadata: InputMetadata
+) -> Tuple[List[float], List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
+    repetition_penalties: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
+        r = sampling_params.repetition_penalty
         if (i < input_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: We do not apply presence and frequency penalties for the

@@ -149,9 +153,11 @@ def _get_penalties(
             prompt_len = input_metadata.prompt_lens[i]
             presence_penalties += [0] * (prompt_len - 1)
             frequency_penalties += [0] * (prompt_len - 1)
+            repetition_penalties += [1] * (prompt_len - 1)
         presence_penalties += [p] * len(seq_ids)
         frequency_penalties += [f] * len(seq_ids)
-    return presence_penalties, frequency_penalties
+        repetition_penalties += [r] * len(seq_ids)
+    return presence_penalties, frequency_penalties, repetition_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:

@@ -175,6 +181,7 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
+    repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):

@@ -182,7 +189,9 @@ def _apply_penalties(
             continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
+        r = repetition_penalties[i]
+        if (abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS
+                and abs(r - 1.0) < _SAMPLING_EPS):
             continue
         break
     else:

@@ -206,7 +215,11 @@ def _apply_penalties(
     bin_counts.scatter_add_(1, output_tokens_tensor,
                             torch.ones_like(output_tokens_tensor))
     bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.

+    mask = bin_counts > 0
+    repetition_penalties = torch.tensor(repetition_penalties,
+                                        dtype=logits.dtype,
+                                        device=logits.device)
     frequency_penalties = torch.tensor(frequency_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)

@@ -214,10 +227,15 @@ def _apply_penalties(
                                       dtype=logits.dtype,
                                       device=logits.device)

+    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
+    repetition_penalties[~mask] = 1.0
+    logits = torch.where(logits > 0, logits / repetition_penalties,
+                         logits * repetition_penalties)
+
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
+    logits -= presence_penalties.unsqueeze(dim=1) * mask
     return logits
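The core of the change is the multiplicative rescaling in _apply_penalties: for every token that has already been generated, a positive logit is divided by the penalty and a negative logit is multiplied by it, so a penalty greater than 1 always makes repeats less likely. Below is a minimal standalone sketch of that hunk's arithmetic; the tensor names mirror the diff, and the toy values are illustrative only.

import torch

# Toy setup: 2 sequences over a 5-token vocabulary.
logits = torch.tensor([[2.0, -1.0, 0.5, -0.5, 3.0],
                       [1.0, 2.0, -2.0, 0.0, 0.5]])
num_seqs, vocab_size = logits.shape

# Tokens generated so far for each sequence, and each sequence's penalty.
output_tokens = [[0, 4], [1, 2]]
repetition_penalties = [1.2, 1.5]

# Per-sequence occurrence counts (the diff builds these with scatter_add_).
bin_counts = torch.zeros(num_seqs, vocab_size, dtype=torch.long)
for i, toks in enumerate(output_tokens):
    for t in toks:
        bin_counts[i, t] += 1
mask = bin_counts > 0  # True where a token has already been generated

# Broadcast each sequence's penalty across the vocab; unseen tokens keep a
# penalty of 1.0, i.e. they are left untouched.
r = torch.tensor(repetition_penalties)[:, None].repeat(1, vocab_size)
r[~mask] = 1.0

# Divide positive logits, multiply negative ones: both move the logit down,
# so r > 1 uniformly discourages previously-seen tokens.
logits = torch.where(logits > 0, logits / r, logits * r)
print(logits)
# Sequence 0, token 0: 2.0 -> 2.0 / 1.2 ≈ 1.667
# Sequence 1, token 2: -2.0 -> -2.0 * 1.5 = -3.0

Note that this differs from the additive presence/frequency penalties applied just below it, which subtract a fixed amount per occurrence in the OpenAI style; here "no penalty" is 1.0 rather than 0.0.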
vllm/sampling_params.py
@@ -34,6 +34,10 @@ class SamplingParams:
             frequency in the generated text so far. Values > 0 encourage the
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
+        repetition_penalty: Float that penalizes new tokens based on whether
+            they appear in the generated text so far. Values > 1 encourage the
+            model to use new tokens, while values < 1 encourage the model to
+            repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.

@@ -75,6 +79,7 @@ class SamplingParams:
         best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,

@@ -93,6 +98,7 @@ class SamplingParams:
         self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
+        self.repetition_penalty = repetition_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k

@@ -136,6 +142,9 @@ class SamplingParams:
         if not -2.0 <= self.frequency_penalty <= 2.0:
             raise ValueError("frequency_penalty must be in [-2, 2], got "
                              f"{self.frequency_penalty}.")
+        if not 0.0 < self.repetition_penalty <= 2.0:
+            raise ValueError("repetition_penalty must be in (0, 2], got "
+                             f"{self.repetition_penalty}.")
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}.")

@@ -201,6 +210,7 @@ class SamplingParams:
                 f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
+                f"repetition_penalty={self.repetition_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
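With these changes, repetition_penalty rides along with the other sampling knobs. A hypothetical usage sketch follows; the model name is a placeholder, and LLM/generate are assumed to be vLLM's standard offline-inference entry points at the time of this commit, unchanged by the diff.

from vllm import LLM, SamplingParams

# repetition_penalty defaults to 1.0 (disabled); the validator accepts (0, 2].
params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    repetition_penalty=1.2,  # > 1 discourages repeating generated tokens
)

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["The future of AI is"], params)
print(outputs[0].outputs[0].text)

The multiplicative definition (neutral value 1.0, applied only to already-seen tokens) matches the repetition penalty popularized by the CTRL paper and Hugging Face transformers, which is why it is validated over (0, 2] while the additive OpenAI-style penalties are validated over [-2, 2].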