norm / vllm · Commits · de23687d

Unverified commit de23687d, authored Nov 23, 2023 by ljss; committed by GitHub, Nov 22, 2023
Fix repetition penalty aligned with huggingface (#1577)
parent 4cea74c7
Showing 2 changed files with 50 additions and 32 deletions (+50, -32):

  vllm/model_executor/layers/sampler.py   (+47, -29)
  vllm/sampling_params.py                 (+3, -3)
vllm/model_executor/layers/sampler.py (view file @ de23687d)
@@ -21,7 +21,7 @@ class Sampler(nn.Module):
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
+    3. Apply presence, frequency and repetition penalties.
     4. Apply temperature scaling.
     5. Apply top-p and top-k truncation.
     6. Sample the next tokens.
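The docstring's six steps run in order on the raw logits. As a rough standalone illustration (not vLLM's code; `temperature` and `top_k` are illustrative parameters), steps 4-6 amount to:

import torch

def sample_next_tokens(logits: torch.Tensor,
                       temperature: float = 1.0,
                       top_k: int = 50) -> torch.Tensor:
    # Step 4: temperature scaling sharpens (t < 1) or flattens (t > 1)
    # the distribution.
    logits = logits / temperature
    # Step 5: top-k truncation masks everything below the k-th largest logit.
    kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    logits = logits.masked_fill(logits < kth, float("-inf"))
    # Step 6: sample one token per sequence from the softmax distribution.
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)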
@@ -50,14 +50,12 @@ class Sampler(nn.Module):
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, input_metadata)
         # Apply presence and frequency penalties.
-        output_tokens = _get_output_tokens(input_metadata)
-        assert len(output_tokens) == logits.shape[0]
         presence_penalties, frequency_penalties, repetition_penalties = (
             _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties, repetition_penalties)
+        logits = _apply_penalties(logits, input_metadata, presence_penalties,
+                                  frequency_penalties, repetition_penalties)

         # Apply temperature scaling.
@@ -146,7 +144,10 @@ def _get_penalties(
     return presence_penalties, frequency_penalties, repetition_penalties


-def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
+def _get_prompt_and_output_tokens(
+        input_metadata: InputMetadata
+) -> Tuple[List[List[int]], List[List[int]]]:
+    prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
@@ -155,11 +156,39 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
             prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
+            prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
-    return output_tokens
+    return prompt_tokens, output_tokens
+
+
+def _get_bin_counts_and_mask(
+    logits: torch.Tensor,
+    tokens: List[List[int]],
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    max_len = max(len(tokens) for tokens in tokens)
+    padded_tokens = [
+        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
+    ]
+    tokens_tensor = torch.tensor(padded_tokens,
+                                 dtype=torch.long,
+                                 device=logits.device)
+
+    # Compute the bin counts for the tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask


 def _apply_logits_processors(logits: torch.Tensor,
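The new `_get_bin_counts_and_mask` helper counts per-sequence occurrences of every vocab id, padding ragged token lists with the out-of-range id `vocab_size` so the padding falls into a throwaway bin. A small self-contained illustration of the trick (toy values, not vLLM internals):

import torch

vocab_size = 5
token_lists = [[0, 2, 2], [4]]                   # ragged per-sequence ids
max_len = max(len(t) for t in token_lists)
padded = [t + [vocab_size] * (max_len - len(t)) for t in token_lists]
tokens = torch.tensor(padded, dtype=torch.long)  # shape [2, 3]
bin_counts = torch.zeros((2, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]          # drop the padding bin
# bin_counts is [[1, 0, 2, 0, 0], [0, 0, 0, 0, 1]]
mask = bin_counts > 0                            # which ids appeared at all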
@@ -186,15 +215,13 @@ def _apply_logits_processors(logits: torch.Tensor,

 def _apply_penalties(
     logits: torch.Tensor,
-    output_tokens: List[List[int]],
+    input_metadata: InputMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
-        if not output_tokens[i]:
-            continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
         r = repetition_penalties[i]
@@ -206,24 +233,15 @@ def _apply_penalties(
         # Return early if all sequences have zero penalties.
         return logits

-    max_output_len = max(len(tokens) for tokens in output_tokens)
-    padded_output_tokens = [
-        tokens + [vocab_size] * (max_output_len - len(tokens))
-        for tokens in output_tokens
-    ]
-    output_tokens_tensor = torch.tensor(padded_output_tokens,
-                                        dtype=torch.long,
-                                        device=logits.device)
-
-    # Compute the bin counts for the output tokens.
-    # vocab_size + 1 for padding.
-    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
-                             dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, output_tokens_tensor,
-                            torch.ones_like(output_tokens_tensor))
-    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
-    mask = bin_counts > 0
+    prompt_tokens, output_tokens = (
+        _get_prompt_and_output_tokens(input_metadata))
+    assert len(prompt_tokens) == logits.shape[0]
+    assert len(output_tokens) == logits.shape[0]
+
+    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
+        logits, prompt_tokens, vocab_size, num_seqs)
+    output_bin_counts, output_mask = _get_bin_counts_and_mask(
+        logits, output_tokens, vocab_size, num_seqs)

     repetition_penalties = torch.tensor(repetition_penalties,
                                         dtype=logits.dtype,
@@ -236,14 +254,14 @@ def _apply_penalties(
...
@@ -236,14 +254,14 @@ def _apply_penalties(
device
=
logits
.
device
)
device
=
logits
.
device
)
repetition_penalties
=
repetition_penalties
[:,
None
].
repeat
(
1
,
vocab_size
)
repetition_penalties
=
repetition_penalties
[:,
None
].
repeat
(
1
,
vocab_size
)
repetition_penalties
[
~
mask
]
=
1.0
repetition_penalties
[
~
(
prompt_mask
|
output_mask
)
]
=
1.0
logits
=
torch
.
where
(
logits
>
0
,
logits
/
repetition_penalties
,
logits
=
torch
.
where
(
logits
>
0
,
logits
/
repetition_penalties
,
logits
*
repetition_penalties
)
logits
*
repetition_penalties
)
# We follow the definition in OpenAI API.
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits
-=
frequency_penalties
.
unsqueeze
(
dim
=
1
)
*
bin_counts
logits
-=
frequency_penalties
.
unsqueeze
(
dim
=
1
)
*
output_
bin_counts
logits
-=
presence_penalties
.
unsqueeze
(
dim
=
1
)
*
mask
logits
-=
presence_penalties
.
unsqueeze
(
dim
=
1
)
*
output_
mask
return
logits
return
logits
...
...
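This hunk is the substance of the fix: the repetition penalty's mask now covers tokens from the prompt as well as the output, which is how HuggingFace's `RepetitionPenaltyLogitsProcessor` behaves (it penalizes every id in `input_ids`, prompt included). A sketch of the per-token rule on made-up tensors, not vLLM's internals:

import torch

def apply_repetition_penalty(logits: torch.Tensor,
                             seen_ids: torch.Tensor,
                             penalty: float) -> torch.Tensor:
    # Gather the logits of every already-seen token id (prompt + output).
    scores = torch.gather(logits, 1, seen_ids)
    # penalty > 1 always makes a seen token less likely: positive logits
    # shrink (divide), negative logits grow in magnitude (multiply).
    scores = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits.scatter(1, seen_ids, scores)

Note that the frequency and presence penalties keep using only the output-token counts (`output_bin_counts`, `output_mask`), matching the OpenAI definition the comment links to; only the repetition penalty looks at the prompt.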
vllm/sampling_params.py (view file @ de23687d)
@@ -42,9 +42,9 @@ class SamplingParams:
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
         repetition_penalty: Float that penalizes new tokens based on whether
-            they appear in the generated text so far. Values > 1 encourage the
-            model to use new tokens, while values < 1 encourage the model to
-            repeat tokens.
+            they appear in the prompt and the generated text so far. Values > 1
+            encourage the model to use new tokens, while values < 1 encourage
+            the model to repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
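With the docstring updated, setting the parameter now also penalizes tokens that appear in the prompt. A brief usage sketch (illustrative values):

from vllm import SamplingParams

# repetition_penalty > 1 discourages repeating any token already seen in
# the prompt or in the generated text; 1.0 disables the penalty.
params = SamplingParams(temperature=0.8, repetition_penalty=1.2)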