Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
09b7c26b
"vscode:/vscode.git/clone" did not exist on "4f59723316bb2df8da19164a568c733be3ced4f0"
Unverified
Commit
09b7c26b
authored
Feb 08, 2024
by
OlivierDehaene
Committed by
GitHub
Feb 08, 2024
Browse files
feat(server): add frequency penalty (#1541)
parent
39af000c
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
100 additions
and
11 deletions
+100
-11
server/text_generation_server/utils/logits_process.py
server/text_generation_server/utils/logits_process.py
+56
-0
server/text_generation_server/utils/tokens.py
server/text_generation_server/utils/tokens.py
+44
-11
No files found.
server/text_generation_server/utils/logits_process.py
View file @
09b7c26b
...
...
@@ -118,6 +118,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
return
None
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    Frequency penalty as defined by OpenAI.

    Subtracts ``penalty * count(token)`` from the logit of every token already
    present in ``input_ids``, i.e. the OpenAI definition
    ``logit[j] -= penalty * c[j]`` where ``c[j]`` is the number of times token
    ``j`` appears in the context.

    Args:
        penalty (`float`):
            The parameter for frequency penalty. 0.0 means no penalty.
    """

    def __init__(self, penalty: float):
        self.penalty = penalty

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Count how many times each vocabulary token occurs in the context.
        token_counts = torch.zeros_like(scores)
        token_counts.scatter_add_(
            1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype)
        )
        # Additive, count-proportional penalty per the OpenAI spec.
        # NOTE: the previous multiplicative form (score * penalty for negative
        # logits, score / penalty otherwise) *increased* the probability of
        # already-seen tokens for 0 < penalty < 1, the opposite of a penalty.
        scores.sub_(token_counts * self.penalty)
        return scores
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    Frequency penalty as defined by OpenAI, with one penalty value per batch
    member.

    Subtracts ``penalty[i] * count(token)`` from each logit of row ``i``, i.e.
    the OpenAI definition ``logit[i, j] -= penalty[i] * c[i, j]`` where
    ``c[i, j]`` counts occurrences of token ``j`` in row ``i`` of ``input_ids``.

    Args:
        frequency_penalty (`List[float]`):
            The parameter for frequency penalty. 0.0 means no penalty.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        # Shape (batch_size, 1) so it broadcasts against the per-token counts.
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Per-row occurrence count of every vocabulary token in the context.
        token_counts = torch.zeros_like(scores)
        token_counts.scatter_add_(
            1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype)
        )
        # Additive, count-proportional penalty per the OpenAI spec.
        # NOTE: the previous multiplicative form (score * penalty for negative
        # logits, score / penalty otherwise) *increased* the probability of
        # already-seen tokens for 0 < penalty < 1, the opposite of a penalty.
        scores.sub_(token_counts * self.penalty_tensor)
        return scores

    def filter(self, indices):
        """Keep only the batch members in `indices`.

        Returns self with the penalty list/tensor narrowed, or None when no
        remaining member has a non-zero penalty (callers drop the processor).
        """
        self.penalty = [self.penalty[i] for i in indices]
        if any([x != 0.0 for x in self.penalty]):
            self.penalty_tensor = self.penalty_tensor[indices]
            return self
        return None
class
HeterogeneousTemperatureLogitsWarper
:
r
"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
...
...
server/text_generation_server/utils/tokens.py
View file @
09b7c26b
import
re
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
text_generation_server.pb
import
generate_pb2
from
text_generation_server.pb.generate_pb2
import
FinishReason
from
text_generation_server.utils.logits_process
import
(
FrequencyPenaltyLogitsProcessor
,
HeterogeneousProcessorWrapper
,
HeterogeneousRepetitionPenaltyLogitsProcessor
,
HeterogeneousFrequencyPenaltyLogitsProcessor
,
HeterogeneousTemperatureLogitsWarper
,
HeterogeneousTopKLogitsWarper
,
HeterogeneousTopPLogitsWarper
,
...
...
@@ -23,6 +25,7 @@ class NextTokenChooser:
watermark
=
False
,
temperature
=
1.0
,
repetition_penalty
=
1.0
,
frequency_penalty
=
0.0
,
top_k
=
None
,
top_p
=
None
,
typical_p
=
None
,
...
...
@@ -35,7 +38,12 @@ class NextTokenChooser:
)
self
.
repetition_processor
=
(
RepetitionPenaltyLogitsProcessor
(
penalty
=
repetition_penalty
)
if
repetition_penalty
if
repetition_penalty
and
repetition_penalty
!=
1.0
else
None
)
self
.
frequency_processor
=
(
FrequencyPenaltyLogitsProcessor
(
penalty
=
frequency_penalty
)
if
frequency_penalty
and
frequency_penalty
!=
0.0
else
None
)
...
...
@@ -60,6 +68,8 @@ class NextTokenChooser:
scores
=
self
.
watermark_processor
(
input_ids
,
scores
)
if
self
.
repetition_processor
is
not
None
:
scores
=
self
.
repetition_processor
(
input_ids
,
scores
)
if
self
.
frequency_processor
is
not
None
:
scores
=
self
.
frequency_processor
(
input_ids
,
scores
)
if
self
.
static_warper
is
None
:
next_logprob
=
torch
.
log_softmax
(
scores
,
-
1
)
...
...
@@ -80,6 +90,7 @@ class NextTokenChooser:
watermark
=
pb
.
watermark
,
temperature
=
pb
.
temperature
,
repetition_penalty
=
pb
.
repetition_penalty
,
frequency_penalty
=
pb
.
frequency_penalty
,
top_k
=
pb
.
top_k
,
top_p
=
pb
.
top_p
,
typical_p
=
pb
.
typical_p
,
...
...
@@ -184,6 +195,7 @@ class HeterogeneousNextTokenChooser:
watermark
:
List
[
bool
],
temperature
:
List
[
float
],
repetition_penalty
:
List
[
float
],
frequency_penalty
:
List
[
float
],
top_k
:
List
[
int
],
top_p
:
List
[
float
],
typical_p
:
List
[
float
],
...
...
@@ -212,6 +224,14 @@ class HeterogeneousNextTokenChooser:
else
None
)
self
.
frequency_processor
=
(
HeterogeneousFrequencyPenaltyLogitsProcessor
(
frequency_penalty
,
dtype
,
device
)
if
any
([
x
!=
0.0
for
x
in
frequency_penalty
])
else
None
)
if
any
([
x
!=
1.0
for
x
in
temperature
]):
do_sample
=
[
sample
or
x
!=
1.0
for
x
,
sample
in
zip
(
temperature
,
do_sample
)
...
...
@@ -269,6 +289,8 @@ class HeterogeneousNextTokenChooser:
_scores
=
self
.
watermark_processor
(
input_ids
,
_scores
)
if
self
.
repetition_processor
is
not
None
:
_scores
=
self
.
repetition_processor
(
input_ids
,
_scores
)
if
self
.
frequency_processor
is
not
None
:
_scores
=
self
.
frequency_processor
(
input_ids
,
_scores
)
for
warper
in
self
.
warpers
:
_scores
=
warper
(
input_ids
,
_scores
)
...
...
@@ -316,7 +338,6 @@ class HeterogeneousNextTokenChooser:
next_logprobs
=
torch
.
gather
(
logprobs
,
1
,
next_ids
.
view
(
-
1
,
1
)).
view
(
-
1
)
if
speculate
>
0
:
if
speculative_scores
is
not
None
:
# Medusa provided some scores
...
...
@@ -338,6 +359,9 @@ class HeterogeneousNextTokenChooser:
if
self
.
repetition_processor
is
not
None
:
self
.
repetition_processor
=
self
.
repetition_processor
.
filter
(
indices
)
if
self
.
frequency_processor
is
not
None
:
self
.
frequency_processor
=
self
.
frequency_processor
.
filter
(
indices
)
filtered_warpers
=
[]
for
warper
in
self
.
warpers
:
filtered_warper
=
warper
.
filter
(
indices
)
...
...
@@ -366,6 +390,7 @@ class HeterogeneousNextTokenChooser:
watermark
=
[
pb_
.
watermark
for
pb_
in
pb
],
temperature
=
[
pb_
.
temperature
for
pb_
in
pb
],
repetition_penalty
=
[
pb_
.
repetition_penalty
for
pb_
in
pb
],
frequency_penalty
=
[
pb_
.
frequency_penalty
for
pb_
in
pb
],
top_k
=
[
pb_
.
top_k
for
pb_
in
pb
],
top_p
=
[
pb_
.
top_p
for
pb_
in
pb
],
typical_p
=
[
pb_
.
typical_p
for
pb_
in
pb
],
...
...
@@ -438,7 +463,10 @@ class HeterogeneousSampling:
def
batch_top_tokens
(
top_n_tokens
:
List
[
int
],
top_n_tokens_tensor
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
accepted_ids
:
torch
.
Tensor
top_n_tokens
:
List
[
int
],
top_n_tokens_tensor
:
torch
.
Tensor
,
logprobs
:
torch
.
Tensor
,
accepted_ids
:
torch
.
Tensor
,
)
->
Tuple
[
List
[
List
[
List
[
int
]]],
List
[
List
[
List
[
float
]]]]:
"""Find the top n most likely tokens for a batch of generations.
...
...
@@ -450,12 +478,15 @@ def batch_top_tokens(
if
max_top_n
==
0
:
return
[[[]]]
*
len
(
top_n_tokens
),
[[[]]]
*
len
(
top_n_tokens
)
batch_size
=
accepted_ids
.
shape
[
0
]
speculate_size
=
logprobs
.
shape
[
0
]
//
batch_size
top_n_tokens_tensor
=
top_n_tokens_tensor
.
repeat_interleave
(
speculate_size
)
# Ensure top_n doesn't exceed vocab size
top_n_tokens
=
[
min
(
tok
,
logprobs
.
size
(
-
1
))
for
tok
in
top_n_tokens
for
_
in
range
(
speculate_size
)]
top_n_tokens
=
[
min
(
tok
,
logprobs
.
size
(
-
1
))
for
tok
in
top_n_tokens
for
_
in
range
(
speculate_size
)
]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
...
...
@@ -484,10 +515,10 @@ def batch_top_tokens(
for
i
,
n_accepted_ids
in
enumerate
(
accepted_ids_list
):
start
=
speculate_size
*
i
stop
=
speculate_size
*
(
i
+
1
)
_top_indices
=
top_indices
[
start
:
stop
]
_top_values
=
top_values
[
start
:
stop
]
_top_n_ishes
=
top_n_ishes
[
start
:
stop
]
_top_n_tokens
=
top_n_tokens
[
start
:
stop
]
_top_indices
=
top_indices
[
start
:
stop
]
_top_values
=
top_values
[
start
:
stop
]
_top_n_ishes
=
top_n_ishes
[
start
:
stop
]
_top_n_tokens
=
top_n_tokens
[
start
:
stop
]
_top_indices
=
_top_indices
[:
n_accepted_ids
]
_top_values
=
_top_values
[:
n_accepted_ids
]
...
...
@@ -497,7 +528,9 @@ def batch_top_tokens(
row_top_token_ids
=
[]
row_top_token_logprobs
=
[]
for
idxs
,
vals
,
n
,
req_n
in
zip
(
_top_indices
,
_top_values
,
_top_n_ishes
,
_top_n_tokens
):
for
idxs
,
vals
,
n
,
req_n
in
zip
(
_top_indices
,
_top_values
,
_top_n_ishes
,
_top_n_tokens
):
indices
=
idxs
[:
n
]
if
req_n
>
0
else
[]
values
=
vals
[:
n
]
if
req_n
>
0
else
[]
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment