OpenDAS / text-generation-inference / Commits

Commit 941cd42e (unverified), authored Mar 08, 2023 by OlivierDehaene, committed via GitHub on Mar 08, 2023
fix(server): fix index out of range for watermarking (#110)
parent 2c5df5d2

Showing 5 changed files with 8 additions and 13 deletions
server/text_generation_server/models/causal_lm.py   +1 -1
server/text_generation_server/models/galactica.py   +1 -1
server/text_generation_server/models/seq2seq_lm.py  +1 -1
server/text_generation_server/utils/tokens.py       +1 -4
server/text_generation_server/utils/watermark.py    +4 -6
server/text_generation_server/models/causal_lm.py

@@ -73,7 +73,7 @@ class CausalLMBatch(Batch):
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
             next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+                NextTokenChooser.from_pb(r.parameters, device)
             )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
server/text_generation_server/models/galactica.py

@@ -103,7 +103,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
             next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+                NextTokenChooser.from_pb(r.parameters, device)
             )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
server/text_generation_server/models/seq2seq_lm.py

@@ -83,7 +83,7 @@ class Seq2SeqLMBatch(Batch):
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             next_token_choosers.append(
-                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+                NextTokenChooser.from_pb(r.parameters, device)
             )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
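All three batch builders above make the same one-line change: NextTokenChooser.from_pb drops the len(tokenizer) argument, because the watermark processor now reads the vocabulary bound off the logits on every call. A hedged note on why the old argument was fragile, with toy numbers (nothing below comes from the repo): a tokenizer's length and a model's logits width need not match, for instance when special tokens are added to the tokenizer without resizing the output layer.

# Toy numbers only: if len(tokenizer) exceeds the model's logits width,
# token ids permuted from the tokenizer range can overflow the scores tensor.
len_tokenizer = 32005  # hypothetical tokenizer length (added special tokens)
logits_width = 32000   # hypothetical scores.shape[-1] from the model
assert len_tokenizer > logits_width  # ids in [32000, 32005) would be out of range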
server/text_generation_server/utils/tokens.py

@@ -36,7 +36,6 @@ class Greedy:
 class NextTokenChooser:
     def __init__(
         self,
-        vocab_size,
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
@@ -52,7 +51,7 @@ class NextTokenChooser:
         sampling = do_sample

         if watermark:
-            warpers.append(WatermarkLogitsProcessor(vocab_size, device=device))
+            warpers.append(WatermarkLogitsProcessor(device=device))
         if repetition_penalty is not None and repetition_penalty != 1.0:
             warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
         if temperature is not None and temperature != 1.0:
@@ -85,11 +84,9 @@ class NextTokenChooser:
     def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
-        vocab_size: int,
         device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
-            vocab_size=vocab_size,
             watermark=pb.watermark,
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
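With vocab_size removed from NextTokenChooser entirely, the watermark branch constructs the processor from the device alone. A minimal sketch of the new warper wiring, using a toy parameter object in place of generate_pb2.NextTokenChooserParameters and assuming the server package is importable:

import torch
from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor

# Assumption: the server package is installed (e.g. pip install -e server/).
from text_generation_server.utils.watermark import WatermarkLogitsProcessor

class ToyParams:  # toy stand-in for the protobuf parameters message
    watermark = True
    repetition_penalty = 1.2

device = torch.device("cpu")
warpers = LogitsProcessorList()

if ToyParams.watermark:
    # New signature after this commit: no vocab size argument; the bound is
    # read from scores.shape[-1] on every __call__ instead.
    warpers.append(WatermarkLogitsProcessor(device=device))
if ToyParams.repetition_penalty is not None and ToyParams.repetition_penalty != 1.0:
    warpers.append(RepetitionPenaltyLogitsProcessor(penalty=ToyParams.repetition_penalty))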
server/text_generation_server/utils/watermark.py

@@ -25,14 +25,12 @@ DELTA = os.getenv("WATERMARK_DELTA", 2.0)
 class WatermarkLogitsProcessor(LogitsProcessor):
     def __init__(
         self,
-        vocab_size: int,
         gamma: float = GAMMA,
         delta: float = DELTA,
         hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
         device: str = "cpu",
     ):
         # watermarking parameters
-        self.vocab_size = vocab_size
         self.gamma = gamma
         self.delta = delta
         self.rng = torch.Generator(device=device)
@@ -45,13 +43,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)

-    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
+    def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
         # seed the rng using the previous tokens/prefix
         self._seed_rng(input_ids)

-        greenlist_size = int(self.vocab_size * self.gamma)
+        greenlist_size = int(max_value * self.gamma)
         vocab_permutation = torch.randperm(
-            self.vocab_size, device=input_ids.device, generator=self.rng
+            max_value, device=input_ids.device, generator=self.rng
         )
         greenlist_ids = vocab_permutation[:greenlist_size]
         return greenlist_ids
@@ -76,7 +74,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         assert len(input_ids) == 1
-        greenlist_ids = self._get_greenlist_ids(input_ids[0])
+        greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1])
         green_tokens_mask = self._calc_greenlist_mask(
             scores=scores, greenlist_token_ids=greenlist_ids
         )
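The watermark.py hunks carry the actual fix: _get_greenlist_ids takes the vocabulary bound as a max_value parameter, and __call__ passes scores.shape[-1], so the green-list permutation can never emit a token id outside the scores tensor. A runnable sketch of the failure mode and the corrected behaviour; the sizes, gamma, and variable names below are illustrative, not from the repo:

import torch

gamma = 0.5                     # illustrative; the repo reads WATERMARK_GAMMA from the env
scores = torch.zeros(1, 50257)  # toy logits: the last dim is the only safe vocab bound
stale_vocab_size = 50400        # toy stand-in for a len(tokenizer) that disagrees

rng = torch.Generator(device="cpu")
rng.manual_seed(15485863)       # seeded with the same large prime the processor uses as hash_key

# Old behaviour: permute over the size frozen at construction time. Some of
# these ids (almost surely) exceed scores.shape[-1], so indexing would raise
# the IndexError this commit fixes:
old_ids = torch.randperm(stale_vocab_size, generator=rng)[: int(stale_vocab_size * gamma)]
print(bool((old_ids >= scores.shape[-1]).any()))  # True
# scores[0, old_ids] += 2.0  # -> IndexError: index out of range

# New behaviour: permute over the logits' own last dimension, supplied per call.
max_value = scores.shape[-1]
new_ids = torch.randperm(max_value, generator=rng)[: int(max_value * gamma)]
scores[0, new_ids] += 2.0  # always in range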