OpenDAS / text-generation-inference · Commits

Commit 9b8ea6a6 (unverified)
Authored Mar 02, 2023 by OlivierDehaene; committed via GitHub on Mar 02, 2023
Parent: f874c478

feat(server): add logits watermark (#90)

Showing 11 changed files with 141 additions and 7 deletions (+141, -7).
launcher/src/main.rs                          +20  -0
proto/generate.proto                           +2  -0
router/src/lib.rs                              +5  -1
router/src/queue.rs                            +1  -0
router/src/server.rs                           +1  -0
router/src/validation.rs                       +2  -0
server/text_generation/models/causal_lm.py     +3  -1
server/text_generation/models/galactica.py     +3  -1
server/text_generation/models/seq2seq_lm.py    +3  -1
server/text_generation/utils/tokens.py        +14  -3
server/text_generation/utils/watermark.py     +87  -0
launcher/src/main.rs

```diff
@@ -55,6 +55,10 @@ struct Args {
     otlp_endpoint: Option<String>,
     #[clap(long, env)]
     cors_allow_origin: Vec<String>,
+    #[clap(long, env)]
+    watermark_gamma: Option<f32>,
+    #[clap(long, env)]
+    watermark_delta: Option<f32>,
 }

 fn main() -> ExitCode {
@@ -88,6 +92,8 @@ fn main() -> ExitCode {
         json_output,
         otlp_endpoint,
         cors_allow_origin,
+        watermark_gamma,
+        watermark_delta,
     } = args;

     // Signal handler
@@ -243,6 +249,8 @@ fn main() -> ExitCode {
                 huggingface_hub_cache,
                 weights_cache_override,
                 disable_custom_kernels,
+                watermark_gamma,
+                watermark_delta,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -414,6 +422,8 @@ fn shard_manager(
     huggingface_hub_cache: Option<String>,
     weights_cache_override: Option<String>,
     disable_custom_kernels: bool,
+    watermark_gamma: Option<f32>,
+    watermark_delta: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<Mutex<bool>>,
@@ -494,6 +504,16 @@ fn shard_manager(
         env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
     }

+    // Watermark Gamma
+    if let Some(watermark_gamma) = watermark_gamma {
+        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
+    }
+
+    // Watermark Delta
+    if let Some(watermark_delta) = watermark_delta {
+        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
+    }
+
     // Start process
     tracing::info!("Starting shard {rank}");
     let mut p = match Popen::create(
```
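The two new flags are pure plumbing: the launcher does not interpret them, it forwards them to every shard process as environment variables. A minimal sketch of what the shard-side Python process then sees (the invocation `--watermark-gamma 0.25 --watermark-delta 2.0` is an assumed example; the variable names come from the diff above):

```python
import os

# Emulate what shard_manager does per the diff above: numeric flag
# values arrive in the shard process as environment *strings*.
os.environ["WATERMARK_GAMMA"] = "0.25"
os.environ["WATERMARK_DELTA"] = "2.0"

gamma = float(os.environ["WATERMARK_GAMMA"])
delta = float(os.environ["WATERMARK_DELTA"])
print(gamma, delta)  # 0.25 2.0
```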
proto/generate.proto

```diff
@@ -40,6 +40,8 @@ message NextTokenChooserParameters {
     uint64 seed = 5;
     /// repetition penalty
     float repetition_penalty = 6;
+    /// token watermarking using "A Watermark for Large Language Models"
+    bool watermark = 7;
 }

 message StoppingCriteriaParameters {
```
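On the Python side this becomes a new field on the generated message. A minimal sketch of constructing the parameters with watermarking enabled (field names are taken from the proto and from the `tokens.py` diff below; the import path mirrors the server's own `from text_generation.pb import generate_pb2`):

```python
from text_generation.pb import generate_pb2

# watermark is field 7, added by this commit; left unset it defaults to false.
params = generate_pb2.NextTokenChooserParameters(
    temperature=0.9,
    do_sample=True,
    seed=42,
    repetition_penalty=1.1,
    watermark=True,
)
print(params.watermark)  # True
```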
router/src/lib.rs

```diff
@@ -53,6 +53,9 @@ pub(crate) struct GenerateParameters {
     #[schema(inline, max_items = 4, example = json!(["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
+    #[schema(default = "false", example = true)]
+    pub watermark: bool,
+    #[serde(default)]
     #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
@@ -72,7 +75,8 @@ fn default_parameters() -> GenerateParameters {
         do_sample: false,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
-        stop: vec![],
+        stop: Vec::new(),
+        watermark: false,
         details: false,
         seed: None,
     }
```
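Because `GenerateParameters` is the router's request schema, clients can now opt in per request. A hedged example call (the `/generate` route and port are assumptions about the deployment; the `watermark` and `max_new_tokens` parameter names come from the schema above):

```python
import requests

response = requests.post(
    "http://localhost:3000/generate",  # assumed router address
    json={
        "inputs": "What is deep learning?",
        "parameters": {"max_new_tokens": 64, "watermark": True},
    },
)
print(response.json())
```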
router/src/queue.rs

```diff
@@ -234,6 +234,7 @@ mod tests {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
+                watermark: false,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 max_new_tokens: 0,
```
router/src/server.rs

```diff
@@ -72,6 +72,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
             max_new_tokens: 1,
             return_full_text: None,
             stop: Vec::new(),
+            watermark: false,
             details: false,
             seed: None,
         },
```
router/src/validation.rs

```diff
@@ -157,6 +157,7 @@ fn validate(
         max_new_tokens,
         stop: stop_sequences,
         seed,
+        watermark,
         ..
     } = request.parameters;
@@ -232,6 +233,7 @@ fn validate(
         top_p,
         do_sample,
         seed,
+        watermark,
     };
     let stopping_parameters = StoppingCriteriaParameters {
         max_new_tokens,
```
server/text_generation/models/causal_lm.py

```diff
@@ -67,7 +67,9 @@ class CausalLMBatch(Batch):
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
```
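The batch now passes `len(tokenizer)` into the chooser because, for Hugging Face tokenizers, `len()` counts the full vocabulary including added special tokens, which is the size the watermark's vocabulary permutation must cover. For example ("gpt2" here is only an illustration):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(len(tokenizer))  # 50257, the full vocabulary the logits span
```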
server/text_generation/models/galactica.py

```diff
@@ -100,7 +100,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+            )
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
```
server/text_generation/models/seq2seq_lm.py

```diff
@@ -77,7 +77,9 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
```
server/text_generation/utils/tokens.py

```diff
@@ -13,6 +13,7 @@ from typing import List, Tuple, Optional

 from text_generation.pb import generate_pb2
 from text_generation.pb.generate_pb2 import FinishReason
+from text_generation.utils.watermark import WatermarkLogitsProcessor


 class Sampling:
@@ -35,6 +36,8 @@ class Greedy:
 class NextTokenChooser:
     def __init__(
         self,
+        vocab_size,
+        watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
         top_k=None,
@@ -47,6 +50,11 @@ class NextTokenChooser:
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
         sampling = do_sample
+
+        if watermark:
+            warpers.append(WatermarkLogitsProcessor(vocab_size, device=device))
+        if repetition_penalty is not None and repetition_penalty != 1.0:
+            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
         if temperature is not None and temperature != 1.0:
             temperature = float(temperature)
             warpers.append(TemperatureLogitsWarper(temperature))
@@ -57,8 +65,6 @@ class NextTokenChooser:
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p))
             sampling = True
-        if repetition_penalty is not None and repetition_penalty != 1.0:
-            warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

         self.warpers = warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
@@ -77,9 +83,14 @@ class NextTokenChooser:
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.NextTokenChooserParameters, device: torch.device
+        cls,
+        pb: generate_pb2.NextTokenChooserParameters,
+        vocab_size: int,
+        device: torch.device,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
+            vocab_size=vocab_size,
+            watermark=pb.watermark,
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
             top_k=pb.top_k,
```
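Note the reordering: the watermark processor and the repetition penalty now run before the temperature/top-k/top-p warpers, so the green-list bias is applied to the raw logits. A minimal sketch of the resulting pipeline (toy vocabulary; gamma/delta passed explicitly so the sketch does not depend on environment variables):

```python
import torch
from transformers import (
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)
from text_generation.utils.watermark import WatermarkLogitsProcessor

warpers = [
    WatermarkLogitsProcessor(vocab_size=32, gamma=0.5, delta=2.0),  # bias raw logits first
    RepetitionPenaltyLogitsProcessor(penalty=1.1),
    TemperatureLogitsWarper(0.7),
    TopPLogitsWarper(top_p=0.9),
]

input_ids = torch.tensor([[1, 5, 7]])  # batch of one prefix, as the processor's assert requires
scores = torch.randn(1, 32)            # logits over the toy 32-token vocabulary
for warper in warpers:
    scores = warper(input_ids, scores)
```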
server/text_generation/utils/watermark.py (new file, mode 100644)

```python
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch

from transformers import LogitsProcessor

GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
DELTA = os.getenv("WATERMARK_DELTA", 2.0)


class WatermarkLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        vocab_size: int,
        gamma: float = GAMMA,
        delta: float = DELTA,
        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
        device: str = "cpu",
    ):
        # watermarking parameters
        self.vocab_size = vocab_size
        self.gamma = gamma
        self.delta = delta
        self.rng = torch.Generator(device=device)
        self.hash_key = hash_key

    def _seed_rng(self, input_ids: torch.LongTensor) -> None:
        assert (
            input_ids.shape[-1] >= 1
        ), "requires at least a 1 token prefix sequence to seed rng"
        prev_token = input_ids[-1].item()
        self.rng.manual_seed(self.hash_key * prev_token)

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
        # seed the rng using the previous tokens/prefix
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        vocab_permutation = torch.randperm(
            self.vocab_size, device=input_ids.device, generator=self.rng
        )
        greenlist_ids = vocab_permutation[:greenlist_size]
        return greenlist_ids

    @staticmethod
    def _calc_greenlist_mask(
        scores: torch.FloatTensor, greenlist_token_ids
    ) -> torch.BoolTensor:
        green_tokens_mask = torch.zeros_like(scores)
        green_tokens_mask[-1, greenlist_token_ids] = 1
        final_mask = green_tokens_mask.bool()
        return final_mask

    @staticmethod
    def _bias_greenlist_logits(
        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
    ) -> torch.Tensor:
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        assert len(input_ids) == 1
        greenlist_ids = self._get_greenlist_ids(input_ids[0])
        green_tokens_mask = self._calc_greenlist_mask(
            scores=scores, greenlist_token_ids=greenlist_ids
        )

        scores = self._bias_greenlist_logits(
            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
        )
        return scores
```
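A quick sanity check of the processor in isolation: with `gamma=0.5` over a 10-token vocabulary, exactly half of the logits (the green list, derived deterministically from the last prefix token) receive the `delta` bias. Gamma and delta are passed explicitly here because the module-level `os.getenv` defaults come back as strings whenever the environment variables are actually set:

```python
import torch

processor = WatermarkLogitsProcessor(vocab_size=10, gamma=0.5, delta=2.0)

input_ids = torch.tensor([[3, 1, 4]])  # one prefix; the rng is seeded from token 4
scores = torch.zeros(1, 10)            # start from uniform logits
biased = processor(input_ids, scores)
print(biased)  # five of the ten entries are now 2.0, the rest stay 0.0
```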