OpenDAS / text-generation-inference · Commits · a6c18c39

Unverified commit a6c18c39, authored May 10, 2023 by OlivierDehaene, committed by GitHub on May 10, 2023
feat(server): use cuda graph in logits warping (#302)
Parent commit: 35ab6cfc
Showing 1 changed file with 94 additions and 33 deletions.
server/text_generation_server/utils/tokens.py  (+94, -33)
import re
import torch

from functools import lru_cache
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
...
@@ -25,8 +25,10 @@ class Sampling:
    def __call__(self, logits):
        probs = torch.nn.functional.softmax(logits, -1)
-       next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
-       return next_tokens
+       # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
+       q = torch.empty_like(probs).exponential_(1, generator=self.generator)
+       return probs.div_(q).argmax()


class Greedy:
...
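The change above replaces torch.multinomial with an exponential race, which sidesteps the synchronization torch.multinomial can trigger (see the linked PyTorch source). Below is a minimal sketch, not part of the diff, showing why argmax(p / q) with q ~ Exponential(1) draws from the categorical distribution p; the helper name sample_once is illustrative only.

import torch

def sample_once(probs: torch.Tensor, generator: torch.Generator) -> int:
    # Exponential race: for q_i ~ Exp(1), argmax_i(p_i / q_i) is a draw from p.
    q = torch.empty_like(probs).exponential_(1, generator=generator)
    return int(probs.div(q).argmax())

if __name__ == "__main__":
    gen = torch.Generator().manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.7])
    counts = torch.zeros(3)
    for _ in range(10_000):
        counts[sample_once(probs, gen)] += 1
    # Empirical frequencies should land close to [0.1, 0.2, 0.7].
    print(counts / counts.sum())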
@@ -34,6 +36,63 @@ class Greedy:
        return logits.argmax()


+class StaticWarper:
+    def __init__(
+        self,
+        temperature=1.0,
+        top_k=None,
+        top_p=None,
+        typical_p=None,
+    ):
+        self.warpers = []
+
+        if temperature is not None and temperature != 1.0:
+            temperature = float(temperature)
+            self.warpers.append(TemperatureLogitsWarper(temperature))
+        if top_k is not None and top_k != 0:
+            self.warpers.append(TopKLogitsWarper(top_k=top_k))
+        if top_p is not None and top_p < 1.0:
+            self.warpers.append(TopPLogitsWarper(top_p=top_p))
+        if typical_p is not None and typical_p < 1.0:
+            self.warpers.append(TypicalLogitsWarper(mass=typical_p))
+
+        self.cuda_graph = None
+        self.static_scores = None
+        self.static_warped_scores = None
+        self.static_next_logprob = None
+
+    def __call__(self, scores):
+        if self.cuda_graph is None:
+            self.static_scores = scores
+            self.cuda_graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.graph(self.cuda_graph):
+                for warper in self.warpers:
+                    self.static_warped_scores = warper(None, self.static_scores)
+
+                # Compute logprobs
+                self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1)
+
+        self.static_scores.copy_(scores)
+        self.cuda_graph.replay()
+
+        return self.static_warped_scores, self.static_next_logprob
+
+
+@lru_cache(10)
+def static_warper(
+    temperature: Optional[float],
+    top_k: Optional[int],
+    top_p: Optional[float],
+    typical_p: Optional[float],
+) -> StaticWarper:
+    return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)


class NextTokenChooser:
    def __init__(
        self,
...
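StaticWarper records the warper chain plus the final log-softmax into a CUDA graph on the first call, then feeds each new batch of logits by copying into the captured input tensor and replaying the graph. The sketch below is not from the repository; it only illustrates the generic capture-once/replay pattern, assumes a CUDA device, and uses made-up names (make_graphed, static_input, static_output).

import torch

def make_graphed(fn, example_input: torch.Tensor):
    # Capture fn once on example_input; later calls reuse the recorded kernels.
    static_input = example_input.clone()
    graph = torch.cuda.CUDAGraph()

    # Warm-up on a side stream before capture, as recommended by PyTorch.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(static_input)
    torch.cuda.current_stream().wait_stream(s)

    with torch.cuda.graph(graph):
        static_output = fn(static_input)

    def run(x: torch.Tensor) -> torch.Tensor:
        static_input.copy_(x)  # refresh the captured input in place
        graph.replay()         # re-run the recorded kernels
        return static_output

    return run

if __name__ == "__main__" and torch.cuda.is_available():
    warp = make_graphed(lambda t: torch.log_softmax(t / 0.7, -1),
                        torch.randn(1, 32000, device="cuda"))
    print(warp(torch.randn(1, 32000, device="cuda")).shape)

Because static_warper is wrapped in lru_cache(10), each distinct (temperature, top_k, top_p, typical_p) combination reuses one StaticWarper instance, so the graph is captured at most once per parameter set.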
@@ -47,43 +106,45 @@ class NextTokenChooser:
        seed=0,
        device="cpu",
    ):
-       warpers = LogitsProcessorList()
-       # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-       # all samplers can be found in `generation_utils_samplers.py`
-       sampling = do_sample
-
-       if watermark:
-           warpers.append(WatermarkLogitsProcessor(device=device))
-       if repetition_penalty is not None and repetition_penalty != 1.0:
-           warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
-       if temperature is not None and temperature != 1.0:
-           temperature = float(temperature)
-           warpers.append(TemperatureLogitsWarper(temperature))
-           sampling = True
-       if top_k is not None and top_k != 0:
-           warpers.append(TopKLogitsWarper(top_k=top_k))
-           sampling = True
-       if top_p is not None and top_p < 1.0:
-           warpers.append(TopPLogitsWarper(top_p=top_p))
-           sampling = True
-       if typical_p is not None and typical_p < 1.0:
-           warpers.append(TypicalLogitsWarper(mass=typical_p))
-           sampling = True
-
-       self.warpers = warpers
+       self.watermark_processor = (
+           WatermarkLogitsProcessor(device=device) if watermark else None
+       )
+       self.repetition_processor = (
+           RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+           if repetition_penalty
+           else None
+       )
+
+       has_warpers = (
+           (temperature is not None and temperature != 1.0)
+           or (top_k is not None and top_k != 0)
+           or (top_p is not None and top_p < 1.0)
+           or (typical_p is not None and typical_p < 1.0)
+       )
+       if has_warpers:
+           self.static_warper = static_warper(
+               temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
+           )
+       else:
+           self.static_warper = None
+
+       sampling = do_sample or has_warpers
        self.choice = Sampling(seed, device) if sampling else Greedy()

    def __call__(self, input_ids, scores):
-       # Warp logits
-       scores = self.warpers(input_ids, scores)
+       if self.watermark_processor:
+           scores = self.watermark_processor(input_ids, scores)
+       if self.repetition_processor:
+           scores = self.repetition_processor(input_ids, scores)

-       # Compute logprobs
-       logprobs = torch.log_softmax(scores, -1)
+       if self.static_warper is None:
+           next_logprob = torch.log_softmax(scores, -1)
+       else:
+           scores, next_logprob = self.static_warper(scores)

-       # Choose tokens
-       next_id = self.choice(scores[-1])
+       next_id = self.choice(scores[-1]).view(1, 1)

-       return next_id.view(1, 1), logprobs
+       return next_id, next_logprob

    @classmethod
    def from_pb(
...
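For orientation, a hypothetical call-site sketch for the reworked NextTokenChooser (not taken from the repository): the keyword names follow the constructor parameters visible in this hunk, and a CUDA device is assumed because any active warper routes the scores through the captured graph.

import torch
from text_generation_server.utils.tokens import NextTokenChooser

if torch.cuda.is_available():
    chooser = NextTokenChooser(
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        seed=42,
        device="cuda",
    )
    input_ids = torch.tensor([[1, 2, 3]], device="cuda")
    scores = torch.randn(1, 32000, device="cuda")  # [num_positions, vocab_size]
    next_id, next_logprob = chooser(input_ids, scores)
    print(next_id.shape)  # torch.Size([1, 1])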