norm / vllm / Commits / de23687d

Unverified commit de23687d, authored Nov 23, 2023 by ljss, committed by GitHub Nov 22, 2023

Fix repetition penalty aligned with huggingface (#1577)
parent 4cea74c7

Showing 2 changed files with 50 additions and 32 deletions:

  vllm/model_executor/layers/sampler.py  (+47, -29)
  vllm/sampling_params.py                (+3, -3)
vllm/model_executor/layers/sampler.py
@@ -21,7 +21,7 @@ class Sampler(nn.Module):
     1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
+    3. Apply presence, frequency and repetition penalties.
     4. Apply temperature scaling.
     5. Apply top-p and top-k truncation.
     6. Sample the next tokens.
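For reference, the repetition penalty this commit aligns with is HuggingFace's: multiplicative and sign-aware, unlike the additive presence and frequency penalties taken from the OpenAI API. A minimal sketch of all three on toy tensors (illustrative names, not vLLM's API; a repetition value of 1.0 leaves logits untouched):

import torch

def penalties_sketch(logits: torch.Tensor,
                     prompt_counts: torch.Tensor,
                     output_counts: torch.Tensor,
                     presence: float, frequency: float,
                     repetition: float) -> torch.Tensor:
    # Repetition penalty (HF convention) considers prompt AND output tokens:
    # divide positive logits, multiply negative ones, so every previously
    # seen token becomes less likely.
    seen = (prompt_counts + output_counts) > 0
    rep = torch.where(seen, torch.full_like(logits, repetition),
                      torch.ones_like(logits))
    logits = torch.where(logits > 0, logits / rep, logits * rep)
    # Presence/frequency penalties (OpenAI convention) are additive and
    # consider only the generated output.
    out_mask = (output_counts > 0).float()
    return logits - frequency * output_counts.float() - presence * out_mask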
@@ -50,14 +50,12 @@ class Sampler(nn.Module):
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, input_metadata)
         # Apply presence and frequency penalties.
-        output_tokens = _get_output_tokens(input_metadata)
-        assert len(output_tokens) == logits.shape[0]
         presence_penalties, frequency_penalties, repetition_penalties = (
             _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+        logits = _apply_penalties(logits, input_metadata, presence_penalties,
                                   frequency_penalties, repetition_penalties)
         # Apply temperature scaling.
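Passing input_metadata instead of the precomputed output_tokens lets _apply_penalties recover prompt tokens on its own, which the repetition penalty now requires; the asserts pin every penalty list to one entry per logits row. A tiny illustration of that alignment, with hypothetical values:

# Hypothetical batch: two sequences each sampling one next token, so the
# logits tensor has two rows and every penalty list needs two entries.
presence_penalties = [0.0, 0.5]
frequency_penalties = [0.0, 0.1]
repetition_penalties = [1.0, 1.2]
num_rows = 2  # logits.shape[0]
assert len(presence_penalties) == num_rows
assert len(frequency_penalties) == num_rows
assert len(repetition_penalties) == num_rows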
@@ -146,7 +144,10 @@ def _get_penalties(
     return presence_penalties, frequency_penalties, repetition_penalties


-def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
+def _get_prompt_and_output_tokens(
+        input_metadata: InputMetadata
+) -> Tuple[List[List[int]], List[List[int]]]:
+    prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
@@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
             prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
+            prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
-    return output_tokens
+    return prompt_tokens, output_tokens
+
+
+def _get_bin_counts_and_mask(
+    logits: torch.Tensor,
+    tokens: List[List[int]],
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    max_len = max(len(tokens) for tokens in tokens)
+    padded_tokens = [
+        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
+    ]
+    tokens_tensor = torch.tensor(padded_tokens,
+                                 dtype=torch.long,
+                                 device=logits.device)
+
+    # Compute the bin counts for the tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask


 def _apply_logits_processors(logits: torch.Tensor,
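Factoring the histogram into _get_bin_counts_and_mask lets the caller run it once for prompt tokens and once for output tokens; the empty lists appended for extra prompt positions simply yield all-zero rows. A runnable toy example of the same scatter_add_ bin-count trick, with made-up data:

import torch

vocab_size = 5
tokens = [[0, 2, 2], [4]]  # two sequences of previously seen token ids
max_len = max(len(t) for t in tokens)
padded = [t + [vocab_size] * (max_len - len(t)) for t in tokens]
tokens_tensor = torch.tensor(padded, dtype=torch.long)
# tokens_tensor is [[0, 2, 2], [4, 5, 5]]; id 5 is the padding bin.

# vocab_size + 1 bins so the padding id has somewhere harmless to land.
bin_counts = torch.zeros((2, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin

print(bin_counts)      # tensor([[1, 0, 2, 0, 0],
                       #         [0, 0, 0, 0, 1]])
print(bin_counts > 0)  # per-token "seen at least once" mask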
@@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor,

 def _apply_penalties(
     logits: torch.Tensor,
-    output_tokens: List[List[int]],
+    input_metadata: InputMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
-        if not output_tokens[i]:
-            continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
         r = repetition_penalties[i]
@@ -206,24 +233,15 @@ def _apply_penalties(
         # Return early if all sequences have zero penalties.
         return logits

-    max_output_len = max(len(tokens) for tokens in output_tokens)
-    padded_output_tokens = [
-        tokens + [vocab_size] * (max_output_len - len(tokens))
-        for tokens in output_tokens
-    ]
-    output_tokens_tensor = torch.tensor(padded_output_tokens,
-                                        dtype=torch.long,
-                                        device=logits.device)
+    prompt_tokens, output_tokens = (
+        _get_prompt_and_output_tokens(input_metadata))
+    assert len(prompt_tokens) == logits.shape[0]
+    assert len(output_tokens) == logits.shape[0]

-    # Compute the bin counts for the output tokens.
-    # vocab_size + 1 for padding.
-    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
-                             dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, output_tokens_tensor,
-                            torch.ones_like(output_tokens_tensor))
-    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
-    mask = bin_counts > 0
+    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
+        logits, prompt_tokens, vocab_size, num_seqs)
+    output_bin_counts, output_mask = _get_bin_counts_and_mask(
+        logits, output_tokens, vocab_size, num_seqs)

     repetition_penalties = torch.tensor(repetition_penalties,
                                         dtype=logits.dtype,
@@ -236,14 +254,14 @@ def _apply_penalties(
                                         device=logits.device)

     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~mask] = 1.0
+    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
     logits = torch.where(logits > 0, logits / repetition_penalties,
                          logits * repetition_penalties)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits
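The behavioral change is easiest to see on a toy tensor: previously a token that appeared only in the prompt escaped the repetition penalty, while prompt_mask | output_mask now catches it. A small self-contained check with made-up numbers:

import torch

logits = torch.tensor([[2.0, -2.0, 1.0]])
prompt_mask = torch.tensor([[True, False, False]])  # token 0 seen in prompt
output_mask = torch.tensor([[False, True, False]])  # token 1 was generated

rep = torch.full_like(logits, 1.5)
rep[~(prompt_mask | output_mask)] = 1.0             # unseen token 2 untouched
penalized = torch.where(logits > 0, logits / rep, logits * rep)
print(penalized)  # tensor([[ 1.3333, -3.0000,  1.0000]])
# Both seen tokens lose probability mass: the positive logit shrinks and
# the negative logit becomes more negative.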
vllm/sampling_params.py
@@ -42,9 +42,9 @@ class SamplingParams:
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
         repetition_penalty: Float that penalizes new tokens based on whether
-            they appear in the generated text so far. Values > 1 encourage the
-            model to use new tokens, while values < 1 encourage the model to
-            repeat tokens.
+            they appear in the prompt and the generated text so far. Values > 1
+            encourage the model to use new tokens, while values < 1 encourage
+            the model to repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
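A hedged usage sketch: the docstring above treats repetition_penalty as an ordinary SamplingParams field, so enabling it alongside the additive penalties would look roughly like this (the values are arbitrary):

from vllm import SamplingParams

params = SamplingParams(
    temperature=0.8,
    repetition_penalty=1.2,  # > 1 discourages repeating prompt/output tokens
    presence_penalty=0.0,    # the additive penalties keep OpenAI semantics
    frequency_penalty=0.0,
)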