OpenDAS / text-generation-inference · Commits · 9b205d33

Unverified commit 9b205d33, authored Mar 06, 2023 by OlivierDehaene, committed by GitHub on Mar 06, 2023

fix(server): fix generate_stream by forcing tokens to be decoded correctly (#100)

Parent: 1c19b093
Showing 6 changed files with 45 additions and 29 deletions (+45 -29).
launcher/tests/mt0_base.json                  +11 -11
server/tests/models/test_seq2seq_lm.py         +1  -1
server/text_generation/models/causal_lm.py     +1  -3
server/text_generation/models/model.py        +18  -0
server/text_generation/models/seq2seq_lm.py    +4  -4
server/text_generation/utils/watermark.py     +10 -10
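Background for the diffs below: with SentencePiece-based tokenizers such as mt0's, decoding one generated token at a time drops the leading space that the vocabulary stores as a "▁" prefix, so text streamed by generate_stream ran words together. A minimal repro sketch (the model name and input text are illustrative assumptions, not taken from the commit):

# Repro sketch of the bug this commit fixes. Assumption: "bigscience/mt0-base"
# stands in for the mt0 model behind launcher/tests/mt0_base.json.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")
ids = tokenizer("a test of the best", add_special_tokens=False).input_ids

# Decoding the whole sequence at once keeps the spaces between words...
print(repr(tokenizer.decode(ids)))
# ...but decoding token by token, as a streaming server must, strips the
# leading space from each piece, so the stream concatenates without spaces.
print(repr("".join(tokenizer.decode([i]) for i in ids)))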
launcher/tests/mt0_base.json  (view file @ 9b205d33)

@@ -14,7 +14,7 @@
     "tokens": [
       {
         "id": 259,
-        "text": "",
+        "text": " ",
         "logprob": -1.3656927,
         "special": false
       },
@@ -32,13 +32,13 @@
       },
       {
         "id": 287,
-        "text": "the",
+        "text": " the",
         "logprob": -1.2102449,
         "special": false
       },
       {
         "id": 259,
-        "text": "",
+        "text": " ",
         "logprob": -1.6057279,
         "special": false
       },
@@ -50,19 +50,19 @@
       },
       {
         "id": 304,
-        "text": "of",
+        "text": " of",
         "logprob": -0.5270343,
         "special": false
       },
       {
         "id": 287,
-        "text": "the",
+        "text": " the",
         "logprob": -0.62522805,
         "special": false
       },
       {
         "id": 259,
-        "text": "",
+        "text": " ",
         "logprob": -1.4069618,
         "special": false
       },
@@ -74,19 +74,19 @@
       },
       {
         "id": 304,
-        "text": "of",
+        "text": " of",
         "logprob": -1.3172221,
         "special": false
       },
       {
         "id": 287,
-        "text": "the",
+        "text": " the",
         "logprob": -0.3501925,
         "special": false
       },
       {
         "id": 259,
-        "text": "",
+        "text": " ",
         "logprob": -0.7219573,
         "special": false
       },
@@ -104,7 +104,7 @@
       },
       {
         "id": 259,
-        "text": "",
+        "text": " ",
         "logprob": -0.32933083,
         "special": false
       },
@@ -116,7 +116,7 @@
       },
       {
         "id": 2978,
-        "text": "test",
+        "text": " test",
         "logprob": -1.5846587,
         "special": false
       },
server/tests/models/test_seq2seq_lm.py  (view file @ 9b205d33)

@@ -148,7 +148,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
     assert all([generation.token_id.item() == 259 for generation in generations])
-    assert all([generation.token_text == "" for generation in generations])
+    assert all([generation.token_text == " " for generation in generations])
     assert generations[0].request_id == 0
server/text_generation/models/causal_lm.py  (view file @ 9b205d33)

@@ -385,10 +385,8 @@ class CausalLM(Model):
         # Generated token
         next_token_logprob = logprobs[-1, next_token_id]
         next_token_id_squeezed = next_token_id.squeeze()
-        next_token_text = self.tokenizer.decode(
+        next_token_text = self.decode_token(
             next_token_id_squeezed,
-            clean_up_tokenization_spaces=False,
-            skip_special_tokens=False,
         )

         # Evaluate stopping criteria
server/text_generation/models/model.py  (view file @ 9b205d33)

@@ -15,6 +15,15 @@ class Model(ABC):
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device

+        # see `decode_token` method
+        self.tokenizer.add_special_tokens(
+            {"additional_special_tokens": ["<decode-token>"]}
+        )
+        self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
+            "<decode-token>"
+        )
+        self.special_decode_token_length = len("<decode-token>")
+
     @property
     @abstractmethod
     def batch_type(self) -> Type[B]:
@@ -23,3 +32,12 @@ class Model(ABC):
     @abstractmethod
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError
+
+    def decode_token(self, token_id: int) -> str:
+        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        # append token to special decode token and decode both
+        result = self.tokenizer.decode(
+            [self.special_decode_token_id, token_id], skip_special_tokens=False
+        )
+        # slice to remove special decode token
+        return result[self.special_decode_token_length :]
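The <decode-token> trick above can be exercised on its own; a hedged standalone sketch of the same logic (the model name is an illustrative assumption, and the probe token "▁the" corresponds to id 287 in the mt0 test fixture above):

# Standalone sketch of the decode_token hack added in model.py.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")  # assumption
tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})
special_id = tokenizer.convert_tokens_to_ids("<decode-token>")
special_len = len("<decode-token>")

def decode_token(token_id: int) -> str:
    # Decode the special token and the real token together: the space between
    # two tokens survives decoding, unlike at the very start of a sequence...
    result = tokenizer.decode([special_id, token_id], skip_special_tokens=False)
    # ...then slice the special token's text off the front.
    return result[special_len:]

the_id = tokenizer.convert_tokens_to_ids("▁the")
print(repr(tokenizer.decode([the_id])))  # 'the'  -- leading space lost
print(repr(decode_token(the_id)))        # ' the' -- leading space kept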
server/text_generation/models/seq2seq_lm.py  (view file @ 9b205d33)

@@ -342,7 +342,9 @@ class Seq2SeqLM(Model):
         return Seq2SeqLMBatch

     def decode(self, decoder_ids: List[int]) -> str:
-        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
+        return self.tokenizer.decode(
+            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )

     def forward(
         self,
@@ -457,10 +459,8 @@ class Seq2SeqLM(Model):
         # Generated token
         next_token_logprob = logprobs[-1, next_token_id]
         next_token_id_squeezed = next_token_id.squeeze()
-        next_token_text = self.tokenizer.decode(
+        next_token_text = self.decode_token(
             next_token_id_squeezed,
-            clean_up_tokenization_spaces=False,
-            skip_special_tokens=False,
         )

         # Evaluate stopping criteria
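The decode() change above pairs with the streaming fix: with the default clean_up_tokenization_spaces=True, decode rewrites spacing around punctuation and contractions, so the final generated_text would not equal the concatenation of the streamed token texts. A hedged illustration (tokenizer and input text are assumptions, not taken from the commit):

# Effect of clean_up_tokenization_spaces on decode.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")  # assumption
ids = tokenizer("do n't stop", add_special_tokens=False).input_ids

# Cleanup (the default) rewrites " n't" to "n't", so the decoded string
# drifts from the raw pieces a stream has already emitted.
print(repr(tokenizer.decode(ids, clean_up_tokenization_spaces=True)))   # likely "don't stop"
print(repr(tokenizer.decode(ids, clean_up_tokenization_spaces=False)))  # "do n't stop"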
server/text_generation/utils/watermark.py  (view file @ 9b205d33)

In this view the old and new sides of every hunk are identical at the content level; the changes appear to be whitespace-only, so each hunk is shown once:

@@ -24,12 +24,12 @@ DELTA = os.getenv("WATERMARK_DELTA", 2.0)
 class WatermarkLogitsProcessor(LogitsProcessor):
     def __init__(
         self,
         vocab_size: int,
         gamma: float = GAMMA,
         delta: float = DELTA,
         hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
         device: str = "cpu",
     ):
         # watermarking parameters
         self.vocab_size = vocab_size
@@ -40,7 +40,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     def _seed_rng(self, input_ids: torch.LongTensor) -> None:
         assert (
             input_ids.shape[-1] >= 1
         ), "requires at least a 1 token prefix sequence to seed rng"
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
@@ -58,7 +58,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     @staticmethod
     def _calc_greenlist_mask(
         scores: torch.FloatTensor, greenlist_token_ids
     ) -> torch.BoolTensor:
         green_tokens_mask = torch.zeros_like(scores)
         green_tokens_mask[-1, greenlist_token_ids] = 1
@@ -67,13 +67,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     @staticmethod
     def _bias_greenlist_logits(
         scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
     ) -> torch.Tensor:
         scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
         return scores

     def __call__(
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         assert len(input_ids) == 1
         greenlist_ids = self._get_greenlist_ids(input_ids[0])
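Although this file's changes are formatting-only, a usage sketch of the processor may help (the import path mirrors the file location and requires the server package; the vocabulary size and token ids are illustrative assumptions):

# Hedged usage sketch of WatermarkLogitsProcessor.
import torch
from text_generation.utils.watermark import WatermarkLogitsProcessor

processor = WatermarkLogitsProcessor(vocab_size=250112, device="cpu")  # mt0-sized vocab, assumed

input_ids = torch.tensor([[259, 287, 304]])  # one sequence: batch of 1, as asserted in __call__
scores = torch.randn(1, 250112)              # next-token logits

# Seeds an rng from the last prefix token, samples a "greenlist" of ids,
# and adds the delta bias to those logits before sampling.
biased = processor(input_ids, scores)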