Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
03bdf182
Unverified
Commit
03bdf182
authored
Jan 31, 2023
by
OlivierDehaene
Committed by
GitHub
Jan 31, 2023
Browse files
fix(server): fix seeding on gpu (#42)
parent
4f9ac67c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
8 deletions
+17
-8
server/text_generation/models/causal_lm.py
server/text_generation/models/causal_lm.py
+1
-1
server/text_generation/models/galactica.py
server/text_generation/models/galactica.py
+1
-1
server/text_generation/models/seq2seq_lm.py
server/text_generation/models/seq2seq_lm.py
+1
-1
server/text_generation/utils.py
server/text_generation/utils.py
+14
-5
No files found.
server/text_generation/models/causal_lm.py
View file @
03bdf182
...
...
@@ -63,7 +63,7 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
...
...
server/text_generation/models/galactica.py
View file @
03bdf182
...
...
@@ -102,7 +102,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
...
...
server/text_generation/models/seq2seq_lm.py
View file @
03bdf182
...
...
@@ -73,7 +73,7 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
...
...
server/text_generation/utils.py
View file @
03bdf182
...
...
@@ -24,8 +24,8 @@ from text_generation.pb import generate_pb2
 class Sampling:
-    def __init__(self, seed: Optional[int] = None):
-        self.generator = torch.Generator()
+    def __init__(self, seed: Optional[int] = None, device: str = "cpu"):
+        self.generator = torch.Generator(device)
         if seed is not None:
             self.generator.manual_seed(seed)
         else:
else
:
...
...
@@ -50,7 +50,13 @@ class Greedy:
 class NextTokenChooser:
-    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None):
+    def __init__(
+        self,
+        temperature=1.0,
+        top_k=None,
+        top_p=None,
+        do_sample=False,
+        seed=None,
+        device="cpu",
+    ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
...
...
@@ -68,7 +74,7 @@ class NextTokenChooser:
             sampling = True
         self.warpers = warpers
-        self.choice = Sampling(seed) if sampling else Greedy()
+        self.choice = Sampling(seed, device) if sampling else Greedy()

     def __call__(self, input_ids, scores):
         # Warp logits
...
...
@@ -80,7 +86,9 @@ class NextTokenChooser:
         return next_ids, logprobs

     @classmethod
-    def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
+    def from_pb(
+        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
+    ) -> "NextTokenChooser":
         # handle protobuf making default values 0
         seed = pb.seed if pb.HasField("seed") else None
         return NextTokenChooser(
...
@@ -89,6 +97,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=seed,
+            device=str(device),
         )
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment