"docs/vscode:/vscode.git/clone" did not exist on "28247e78819ab9756b81f8df39611c333d099400"
Unverified Commit 5ad960f1 authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Add Watermarking LogitsProcessor and WatermarkDetector (#29676)



* add watermarking processor

* remove the other hashing (context width=1 always)

* make style

* Update src/transformers/generation/logits_process.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/configuration_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* update watermarking process

* add detector

* update tests to use detector

* fix failing tests

* rename `input_seq`

* make style

* doc for processor

* minor fixes

* docs

* make quality

* Update src/transformers/generation/configuration_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/logits_process.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add PR suggestions

* let's use lru_cache's default max size (128)

* import processor if torch available

* maybe like this

* let's move the config to a torch-independent file

* add docs

* tiny docs fix to make the test happy

* Update src/transformers/generation/configuration_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/watermarking.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* PR suggestions

* add docs

* fix test

* fix docs

* address pr comments

* style

* Revert "style"

This reverts commit 7f33cc34ff08b414f8e7f90060889877606b43b2.

* correct style

* make doctest green

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent 65ea1904
......@@ -173,6 +173,55 @@ your screen, one word at a time:
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```
## Watermarking
`generate()` supports watermarking the generated text by randomly marking a portion of the vocabulary as "green" tokens.
When generating, a small bias is added to the "green" tokens' logits, so they have a higher chance of being sampled.
The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how
statistically likely it is to obtain that many "green" tokens in human-generated text. This watermarking strategy was proposed in the paper
["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more on
the inner workings of watermarking, we recommend referring to the paper.
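Concretely, for a text with $T$ scored tokens, of which $g$ are "green", and a configured green-list ratio $\gamma$, the detector computes a one-proportion z-statistic (this is the statistic implemented in the detector's `_compute_z_score` below):

$$z = \frac{g - \gamma T}{\sqrt{T\,\gamma\,(1 - \gamma)}}$$

Text is flagged as machine-generated when $z$ exceeds a threshold (3.0 by default).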
Watermarking can be used with any generative model in `transformers` and does not require an extra classification model
to detect the watermarked text. To trigger watermarking, pass a [`WatermarkingConfig`] with the needed arguments directly to the
`.generate()` method, or add it to the [`GenerationConfig`] as sketched below. Watermarked text can later be detected with a [`WatermarkDetector`].
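For instance, a minimal sketch of the `GenerationConfig` route, equivalent to passing the config straight to `generate()` (here `model` and `inputs` are assumed to be prepared as in the snippet further below):

```python
from transformers import GenerationConfig, WatermarkingConfig

generation_config = GenerationConfig(
    max_length=20,
    do_sample=False,
    watermarking_config=WatermarkingConfig(bias=2.5, seeding_scheme="selfhash"),
)
# out = model.generate(**inputs, generation_config=generation_config)
```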
<Tip warning={true}>
The `WatermarkDetector` internally relies on the proportion of "green" tokens and on whether the generated text follows the coloring pattern.
That is why it is recommended to strip off the prompt text if it is much longer than the generated text.
Padding can have a similar effect when one sequence in the batch is much longer than the others, causing the remaining rows to be padded.
Additionally, the detector **must** be initialized with the same watermarking configuration arguments that were used during generation.
</Tip>
Let's generate some text with watermarking enabled. In the code snippet below, we set the bias to 2.5, the value that
will be added to the "green" tokens' logits. After generating watermarked text, we can pass it directly to the `WatermarkDetector`
to check whether the text is machine-generated (it outputs `True` for machine-generated text and `False` otherwise).
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> tok.pad_token_id = tok.eos_token_id
>>> tok.padding_side = "left"
>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
>>> input_len = inputs["input_ids"].shape[-1]
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
>>> out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
>>> detection_out = detector(out, return_dict=True)
>>> detection_out.prediction
array([True, True])
```
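As the tip above advises, detection is more reliable when the prompt is stripped before scoring. Since the prompt length is already stored in `input_len`, this is a one-line change to the detector call (a sketch of the same example):

```python
detection_out = detector(out[:, input_len:], return_dict=True)
```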
## Decoding strategies
Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
......
......@@ -209,6 +209,10 @@ generation.
[[autodoc]] WhisperTimeStampLogitsProcessor
- __call__
[[autodoc]] WatermarkLogitsProcessor
- __call__
### TensorFlow
[[autodoc]] TFForcedBOSTokenLogitsProcessor
......@@ -372,3 +376,10 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- update
- get_seq_length
- reorder_cache
## Watermark Utils
[[autodoc]] WatermarkDetector
- __call__
......@@ -41,6 +41,8 @@ like token streaming.
- validate
- get_generation_mode
[[autodoc]] generation.WatermarkingConfig
## GenerationMixin
[[autodoc]] generation.GenerationMixin
......
......@@ -117,7 +117,12 @@ _import_structure = {
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
"generation": [
"GenerationConfig",
"TextIteratorStreamer",
"TextStreamer",
"WatermarkingConfig",
],
"hf_argparser": ["HfArgumentParser"],
"hyperparameter_search": [],
"image_transforms": [],
......@@ -1232,6 +1237,8 @@ else:
"TopPLogitsWarper",
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WatermarkDetector",
"WatermarkLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
]
)
......@@ -4617,7 +4624,7 @@ if TYPE_CHECKING:
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
# Generation
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
from .hf_argparser import HfArgumentParser
# Integrations
......@@ -5797,6 +5804,8 @@ if TYPE_CHECKING:
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkDetector,
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .modeling_utils import PreTrainedModel
......
......@@ -18,7 +18,7 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_availab
_import_structure = {
"configuration_utils": ["GenerationConfig", "GenerationMode"],
"configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"],
"streamers": ["TextIteratorStreamer", "TextStreamer"],
}
......@@ -78,6 +78,7 @@ else:
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
"WatermarkLogitsProcessor",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
......@@ -106,6 +107,10 @@ else:
"GenerateDecoderOnlyOutput",
"GenerateEncoderDecoderOutput",
]
_import_structure["watermarking"] = [
"WatermarkDetector",
"WatermarkDetectorOutput",
]
try:
if not is_tf_available():
......@@ -174,7 +179,7 @@ else:
]
if TYPE_CHECKING:
from .configuration_utils import GenerationConfig, GenerationMode
from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig
from .streamers import TextIteratorStreamer, TextStreamer
try:
......@@ -218,6 +223,7 @@ if TYPE_CHECKING:
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .stopping_criteria import (
......@@ -247,6 +253,10 @@ if TYPE_CHECKING:
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
)
from .watermarking import (
WatermarkDetector,
WatermarkDetectorOutput,
)
try:
if not is_tf_available():
......
......@@ -18,6 +18,7 @@ import copy
import json
import os
import warnings
from dataclasses import dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from .. import __version__
......@@ -221,6 +222,23 @@ class GenerationConfig(PushToHubMixin):
low_memory (`bool`, *optional*):
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
Used with beam search and contrastive search.
watermarking_config (Union[`WatermarkingConfig`, `dict`], *optional*):
Arguments used to watermark the model outputs by adding a small bias to a randomly selected set of "green" tokens.
If passed as a `dict`, it will be converted to a `WatermarkingConfig` internally.
See [this paper](https://arxiv.org/abs/2306.04634) for more details. Accepts the following keys:
- greenlist_ratio (`float`):
Used for watermarking. The ratio of "green" tokens to the vocabulary size. Defaults to 0.25.
- bias (`float`):
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
- hashing_key (`int`):
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
- seeding_scheme (`str`):
Algorithm to use for watermarking. Accepts values:
- "lefthash" (default): "green" token selection depends on the last token (Algorithm 2 from the paper)
- "selfhash": "green" token selection depends on the current token itself (Algorithm 3 from the paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
- context_width (`int`):
The number of previous tokens to use in seeding. A higher context width makes watermarking more robust.
> Parameters that define the output variables of generate
......@@ -333,6 +351,13 @@ class GenerationConfig(PushToHubMixin):
self.sequence_bias = kwargs.pop("sequence_bias", None)
self.guidance_scale = kwargs.pop("guidance_scale", None)
self.low_memory = kwargs.pop("low_memory", None)
watermarking_config = kwargs.pop("watermarking_config", None)
if watermarking_config is None:
self.watermarking_config = None
elif isinstance(watermarking_config, WatermarkingConfig):
self.watermarking_config = watermarking_config
else:
self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
# Parameters that define the output variables of `generate`
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
......@@ -613,6 +638,12 @@ class GenerationConfig(PushToHubMixin):
f"({self.num_beams})."
)
# check watermarking arguments
if self.watermarking_config is not None:
if not isinstance(self.watermarking_config, WatermarkingConfig):
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 5. check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
......@@ -1021,7 +1052,16 @@ class GenerationConfig(PushToHubMixin):
else:
return obj
def convert_dataclass_to_dict(obj):
if isinstance(obj, dict):
return {key: convert_dataclass_to_dict(value) for key, value in obj.items()}
elif is_dataclass(obj):
return obj.to_dict()
else:
return obj
config_dict = convert_keys_to_string(config_dict)
config_dict = convert_dataclass_to_dict(config_dict)
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
......@@ -1093,3 +1133,142 @@ class GenerationConfig(PushToHubMixin):
# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs
@dataclass
class WatermarkingConfig:
"""
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
Accepts the following keys:
- greenlist_ratio (`float`):
Used for watermarking. The ratio of "green" tokens to the vocabulary size. Defaults to 0.25.
- bias (`float`):
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
- hashing_key (`int`):
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
- seeding_scheme (`str`):
Algorithm to use for watermarking. Accepts values:
- "lefthash" (default): "green" token selection depends on the last token (Algorithm 2 from the paper)
- "selfhash": "green" token selection depends on the current token itself (Algorithm 3 from the paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
- context_width (`int`):
The number of previous tokens to use in seeding. A higher context width makes watermarking more robust.
"""
def __init__(
self,
greenlist_ratio: Optional[float] = 0.25,
bias: Optional[float] = 2.0,
hashing_key: Optional[int] = 15485863,
seeding_scheme: Optional[str] = "lefthash",
context_width: Optional[int] = 1,
):
self.greenlist_ratio = greenlist_ratio
self.bias = bias
self.hashing_key = hashing_key
self.seeding_scheme = seeding_scheme
self.context_width = context_width
@classmethod
def from_dict(cls, config_dict, **kwargs):
"""
Constructs a WatermarkingConfig instance from a dictionary of parameters.
Args:
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
**kwargs: Additional keyword arguments to override dictionary values.
Returns:
WatermarkingConfig: Instance of WatermarkingConfig constructed from the dictionary.
"""
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
return config
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
config_dict = self.to_dict()
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
writer.write(json_string)
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
Returns:
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
return output
def __iter__(self):
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def to_json_string(self):
"""
Serializes this instance to a JSON formatted string.
Returns:
str: JSON formatted string representing the configuration instance.
"""
return json.dumps(self.__dict__, indent=2) + "\n"
def update(self, **kwargs):
"""
Update the configuration attributes with new values.
Args:
**kwargs: Keyword arguments representing configuration attributes and their new values.
"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
def validate(self):
watermark_missing_arg_msg = (
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)
if self.seeding_scheme not in ["selfhash", "lefthash"]:
raise ValueError(
watermark_missing_arg_msg.format(
key="seeding_scheme",
correct_value="[`selfhash`, `lefthash`]",
found_value=self.seeding_scheme,
),
)
if not 0.0 <= self.greenlist_ratio <= 1.0:
raise ValueError(
watermark_missing_arg_msg.format(
key="greenlist_ratio",
correct_value="in range between 0.0 and 1.0",
found_value=self.greenlist_ratio,
),
)
if not self.context_width >= 1:
raise ValueError(
watermark_missing_arg_msg.format(
key="context_width",
correct_value="a positive integer",
found_value=self.context_width,
),
)
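Since a plain `dict` is converted into a `WatermarkingConfig` internally (via `from_dict` above), both forms below are interchangeable — a minimal sketch, assuming a prepared `model` and tokenized `inputs`:

```python
from transformers import WatermarkingConfig

config = WatermarkingConfig(bias=2.0, seeding_scheme="lefthash")
as_dict = {"bias": 2.0, "seeding_scheme": "lefthash"}

# both calls apply an equivalent watermark:
# out_a = model.generate(**inputs, watermarking_config=config)
# out_b = model.generate(**inputs, watermarking_config=as_dict)
```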
......@@ -2321,3 +2321,144 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
scores_processed = torch.where(do_early_stop, early_stop_scores, scores)
return scores_processed
class WatermarkLogitsProcessor(LogitsProcessor):
r"""
Logits processor for watermarking generated text. The processor modifies model output scores by adding a small bias to a
randomized set of "green" tokens before generating the next token. The "green" token selection process depends on the
`seeding_scheme` used. The code is based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
The text generated by this `LogitsProcessor` can be detected using `WatermarkDetector`. See [`~WatermarkDetector.__call__`] for details.
See [the paper](https://arxiv.org/abs/2306.04634) for more information.
Args:
vocab_size (`int`):
The model tokenizer's vocab_size. Used to calculate "green" tokens ratio.
device (`str`):
The device where the model is allocated.
greenlist_ratio (`float`, *optional*, defaults to 0.25):
The ratio of "green" tokens to the vocabulary size. Defaults to 0.25.
bias (`float`, *optional*, defaults to 2.0):
The bias added to the selected "green" tokens' logits. Consider lowering the
`bias` if the text generation quality degrades. Recommended values are in the
range of [0.5, 2.0]. Defaults to 2.0.
hashing_key (`int`, *optional*, defaults to 15485863):
Key used for hashing. If you deploy this watermark, we advise using another private key.
Defaults to 15485863 (the millionth prime).
seeding_scheme (`str`, *optional*, defaults to `"lefthash"`):
The seeding scheme used for selecting "green" tokens. Accepts values:
- "lefthash" (default): "green" token selection depends on the last token (Algorithm 2 from the paper)
- "selfhash": "green" token selection depends on the current token itself (Algorithm 3 from the paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
context_width (`int`, *optional*, defaults to 1):
The number of previous tokens to use when setting the seed. A higher context width makes watermarking more robust.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkingConfig
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(["Alice and Bob are"], return_tensors="pt")
>>> # normal generation
>>> out = model.generate(inputs["input_ids"], max_length=20, do_sample=False)
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
'Alice and Bob are both in the same room.\n\n"I\'m not sure if you\'re'
>>> # watermarked generation
>>> watermarking_config = WatermarkingConfig(bias=2.5, context_width=2, seeding_scheme="selfhash")
>>> out = model.generate(inputs["input_ids"], watermarking_config=watermarking_config, max_length=20, do_sample=False)
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
'Alice and Bob are both still alive and well and the story is pretty much a one-hour adventure'
>>> # to detect watermarked text use the WatermarkDetector class
>>> from transformers import WatermarkDetector
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
>>> detection_preds = detector(out)
>>> detection_preds
array([ True])
```
"""
def __init__(
self,
vocab_size,
device,
greenlist_ratio: float = 0.25,
bias: float = 2.0,
hashing_key: int = 15485863,
seeding_scheme: str = "lefthash",
context_width: int = 1,
):
if seeding_scheme not in ["selfhash", "lefthash"]:
raise ValueError(f"seeding_scheme has to be one of [`selfhash`, `lefthash`], but found {seeding_scheme}")
if greenlist_ratio >= 1.0 or greenlist_ratio <= 0.0:
raise ValueError(
f"greenlist_ratio has be in range between 0.0 and 1.0, exclusively. but found {greenlist_ratio}"
)
self.vocab_size = vocab_size
self.greenlist_size = int(self.vocab_size * greenlist_ratio)
self.bias = bias
self.seeding_scheme = seeding_scheme
self.rng = torch.Generator(device=device)
self.hash_key = hashing_key
self.context_width = context_width
self.rng.manual_seed(hashing_key)
self.table_size = 1_000_003
self.fixed_table = torch.randperm(self.table_size, generator=self.rng, device=device)
def set_seed(self, input_seq: torch.LongTensor):
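# seeding uses only the last `context_width` tokens of the sequence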
input_seq = input_seq[-self.context_width :]
if self.seeding_scheme == "selfhash":
a = self.fixed_table[input_seq % self.table_size] + 1
b = self.fixed_table[input_seq[-1] % self.table_size] + 1
seed = (self.hash_key * a * b).min().item()
else:
seed = self.hash_key * input_seq[-1].item()
self.rng.manual_seed(seed % (2**64 - 1))
def _get_greenlist_ids(self, input_seq: torch.LongTensor) -> torch.LongTensor:
self.set_seed(input_seq)
vocab_permutation = torch.randperm(self.vocab_size, device=input_seq.device, generator=self.rng)
greenlist_ids = vocab_permutation[: self.greenlist_size]
return greenlist_ids
def _score_rejection_sampling(self, input_seq: torch.LongTensor, scores: torch.FloatTensor) -> torch.LongTensor:
"""
Generate greenlist based on current candidate next token. Reject and move on if necessary.
Runs for a fixed number of steps only for efficiency, since the method is not batched.
"""
final_greenlist = []
_, greedy_predictions = scores.sort(dim=-1, descending=True)
# 40 is an arbitrary number chosen to save compute and not run for long (taken from the original repo)
for i in range(40):
greenlist_ids = self._get_greenlist_ids(torch.cat([input_seq, greedy_predictions[i, None]], dim=-1))
if greedy_predictions[i] in greenlist_ids:
final_greenlist.append(greedy_predictions[i])
return torch.tensor(final_greenlist, device=input_seq.device)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[-1] < self.context_width:
logger.warning(
f"`input_ids` should have at least `{self.context_width}` tokens but has {input_ids.shape[-1]}. "
"The seeding will be skipped for this generation step!"
)
return scores
scores_processed = scores.clone()
for b_idx, input_seq in enumerate(input_ids):
if self.seeding_scheme == "selfhash":
greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx])
else:
greenlist_ids = self._get_greenlist_ids(input_seq)
scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias
return scores_processed
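Like any other logits processor, `WatermarkLogitsProcessor` can also be composed manually instead of going through `watermarking_config` — a minimal sketch, assuming a decoder-only `model` and tokenized `inputs` are already prepared:

```python
from transformers import LogitsProcessorList, WatermarkLogitsProcessor

watermark_processor = WatermarkLogitsProcessor(
    vocab_size=model.config.vocab_size,
    device=model.device,
    bias=2.5,
)
out = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([watermark_processor]),
    max_length=20,
    do_sample=False,
)
```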
......@@ -74,6 +74,7 @@ from .logits_process import (
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from .stopping_criteria import (
EosTokenCriteria,
......@@ -763,6 +764,7 @@ class GenerationMixin:
encoder_input_ids: torch.LongTensor,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
logits_processor: Optional[LogitsProcessorList],
device: str = None,
model_kwargs: Optional[Dict[str, Any]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
......@@ -879,6 +881,18 @@ class GenerationMixin:
FutureWarning,
)
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids, _has_warned=True))
if generation_config.watermarking_config is not None:
processors.append(
WatermarkLogitsProcessor(
vocab_size=self.config.vocab_size,
device=device,
greenlist_ratio=generation_config.watermarking_config.greenlist_ratio,
bias=generation_config.watermarking_config.bias,
hashing_key=generation_config.watermarking_config.hashing_key,
seeding_scheme=generation_config.watermarking_config.seeding_scheme,
context_width=generation_config.watermarking_config.context_width,
)
)
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
......@@ -1632,6 +1646,7 @@ class GenerationMixin:
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=inputs_tensor.device,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
......
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, Optional, Union
import numpy as np
from ..configuration_utils import PretrainedConfig
from ..utils import is_torch_available, logging
from .configuration_utils import WatermarkingConfig
if is_torch_available():
import torch
from .logits_process import WatermarkLogitsProcessor
logger = logging.get_logger(__name__)
@dataclass
class WatermarkDetectorOutput:
"""
Outputs of a watermark detector.
Args:
num_tokens_scored (np.array of shape (batch_size)):
Array containing the number of tokens scored for each element in the batch.
num_green_tokens (np.array of shape (batch_size)):
Array containing the number of green tokens for each element in the batch.
green_fraction (np.array of shape (batch_size)):
Array containing the fraction of green tokens for each element in the batch.
z_score (np.array of shape (batch_size)):
Array containing the z-score for each element in the batch. The z-score here indicates
how many standard deviations away the "green" token count in the input text is
from its expected count for non-watermarked, human-generated text.
p_value (np.array of shape (batch_size)):
Array containing the p-value for each element in the batch, obtained from the z-score.
prediction (np.array of shape (batch_size), *optional*):
Array containing boolean predictions of whether a text is machine-generated for each element in the batch.
confidence (np.array of shape (batch_size), *optional*):
Array containing confidence scores of a text being machine-generated for each element in the batch.
"""
num_tokens_scored: np.array = None
num_green_tokens: np.array = None
green_fraction: np.array = None
z_score: np.array = None
p_value: np.array = None
prediction: Optional[np.array] = None
confidence: Optional[np.array] = None
class WatermarkDetector:
r"""
Detector for watermarked generated text. The detector must be given the exact same settings that were
used during text generation in order to replicate the watermark greenlist and thus detect the watermark. This includes
the correct device that was used during text generation, the correct watermarking arguments, and the correct tokenizer vocab size.
The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).
See [the paper](https://arxiv.org/abs/2306.04634) for more information.
Args:
model_config (`PretrainedConfig`):
The model config that will be used to get model specific arguments used when generating.
device (`str`):
The device which was used during watermarked text generation.
watermarking_config (Union[`WatermarkingConfig`, `Dict`]):
The exact same watermarking config and arguments used when generating text.
ignore_repeated_ngrams (`bool`, *optional*, defaults to `False`):
Whether to count every unique ngram only once or not.
max_cache_size (`int`, *optional*, defaults to 128):
The max size to be used for LRU caching of seeding/sampling algorithms called for every token.
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig
>>> model_id = "openai-community/gpt2"
>>> model = AutoModelForCausalLM.from_pretrained(model_id)
>>> tok = AutoTokenizer.from_pretrained(model_id)
>>> tok.pad_token_id = tok.eos_token_id
>>> tok.padding_side = "left"
>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
>>> input_len = inputs["input_ids"].shape[-1]
>>> # first generate text with and without a watermark
>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
>>> out_watermarked = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
>>> out = model.generate(**inputs, do_sample=False, max_length=20)
>>> # now we can instantiate the detector and check the generated text
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
>>> detection_out_watermarked = detector(out_watermarked, return_dict=True)
>>> detection_out = detector(out, return_dict=True)
>>> detection_out_watermarked.prediction
array([ True, True])
>>> detection_out.prediction
array([False, False])
```
"""
def __init__(
self,
model_config: PretrainedConfig,
device: str,
watermarking_config: Union[WatermarkingConfig, Dict],
ignore_repeated_ngrams: bool = False,
max_cache_size: int = 128,
):
if isinstance(watermarking_config, WatermarkingConfig):
watermarking_config = watermarking_config.to_dict()
self.bos_token_id = (
model_config.bos_token_id if not model_config.is_encoder_decoder else model_config.decoder_start_token_id
)
self.greenlist_ratio = watermarking_config["greenlist_ratio"]
self.ignore_repeated_ngrams = ignore_repeated_ngrams
self.processor = WatermarkLogitsProcessor(
vocab_size=model_config.vocab_size, device=device, **watermarking_config
)
# Expensive re-seeding and sampling is cached.
self._get_ngram_score_cached = lru_cache(maxsize=max_cache_size)(self._get_ngram_score)
def _get_ngram_score(self, prefix: torch.LongTensor, target: int):
greenlist_ids = self.processor._get_greenlist_ids(prefix)
return target in greenlist_ids
def _score_ngrams_in_passage(self, input_ids: torch.LongTensor):
batch_size, seq_length = input_ids.shape
selfhash = int(self.processor.seeding_scheme == "selfhash")
n = self.processor.context_width + 1 - selfhash
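# build all sliding windows of length n over each sequence; `indices` has shape (seq_length - n + 1, n)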
indices = torch.arange(n).unsqueeze(0) + torch.arange(seq_length - n + 1).unsqueeze(1)
ngram_tensors = input_ids[:, indices]
num_tokens_scored_batch = np.zeros(batch_size)
green_token_count_batch = np.zeros(batch_size)
for batch_idx in range(ngram_tensors.shape[0]):
frequencies_table = collections.Counter(ngram_tensors[batch_idx])
ngram_to_watermark_lookup = {}
for ngram_example in frequencies_table.keys():
prefix = ngram_example if selfhash else ngram_example[:-1]
target = ngram_example[-1]
ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)
if self.ignore_repeated_ngrams:
# counts a green/red hit once per unique ngram.
# num total tokens scored becomes the number unique ngrams.
num_tokens_scored_batch[batch_idx] = len(frequencies_table.keys())
green_token_count_batch[batch_idx] = sum(ngram_to_watermark_lookup.values())
else:
num_tokens_scored_batch[batch_idx] = sum(frequencies_table.values())
green_token_count_batch[batch_idx] = sum(
freq * outcome
for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values())
)
return num_tokens_scored_batch, green_token_count_batch
def _compute_z_score(self, green_token_count: np.array, total_num_tokens: np.array) -> np.array:
expected_count = self.greenlist_ratio
numer = green_token_count - expected_count * total_num_tokens
denom = np.sqrt(total_num_tokens * expected_count * (1 - expected_count))
z = numer / denom
return z
def _compute_pval(self, x, loc=0, scale=1):
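# closed-form approximation of the one-sided p-value 1 - Phi(z) for the standard normal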
z = (x - loc) / scale
return 1 - (0.5 * (1 + np.sign(z) * (1 - np.exp(-2 * z**2 / np.pi))))
def __call__(
self,
input_ids: torch.LongTensor,
z_threshold: float = 3.0,
return_dict: bool = False,
) -> Union[WatermarkDetectorOutput, np.array]:
"""
Args:
input_ids (`torch.LongTensor`):
The watermark generated text. It is advised to remove the prompt, which can affect the detection.
z_threshold (`float`, *optional*, defaults to `3.0`):
Changing this threshold will change the sensitivity of the detector. A higher z-threshold makes the
detector less sensitive, and vice versa.
return_dict (`bool`, *optional*, defaults to `False`):
Whether to return a [`~generation.WatermarkDetectorOutput`]. If `False`, boolean predictions are returned instead.
Return:
[`~generation.WatermarkDetectorOutput`] or `np.array`: A [`~generation.WatermarkDetectorOutput`]
if `return_dict=True` otherwise a `np.array`.
"""
# Let's assume that if one sequence in the batch starts with `bos`, all of them do
if input_ids[0, 0] == self.bos_token_id:
input_ids = input_ids[:, 1:]
if input_ids.shape[-1] - self.processor.context_width < 1:
raise ValueError(
f"Must have at least `1` token to score after the first "
f"min_prefix_len={self.processor.context_width} tokens required by the seeding scheme."
)
num_tokens_scored, green_token_count = self._score_ngrams_in_passage(input_ids)
z_score = self._compute_z_score(green_token_count, num_tokens_scored)
prediction = z_score > z_threshold
if return_dict:
p_value = self._compute_pval(z_score)
confidence = 1 - p_value
return WatermarkDetectorOutput(
num_tokens_scored=num_tokens_scored,
num_green_tokens=green_token_count,
green_fraction=green_token_count / num_tokens_scored,
z_score=z_score,
p_value=p_value,
prediction=prediction,
confidence=confidence,
)
return prediction
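The `z_threshold` can be tuned per call: a higher threshold makes detection more conservative (fewer false positives), a lower one more sensitive. A short sketch, reusing the `detector` and `out` from the docstring example above:

```python
strict_preds = detector(out, z_threshold=4.0)   # conservative: fewer false positives
lenient_preds = detector(out, z_threshold=2.0)  # sensitive: more detections
```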
......@@ -422,6 +422,20 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"])
class WatermarkDetector(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class WatermarkLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -53,6 +53,7 @@ if is_torch_available():
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
......@@ -949,3 +950,26 @@ class LogitsProcessorTest(unittest.TestCase):
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)
def test_watermarking_processor(self):
batch_size = 3
vocab_size = 20
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
# raise error if incorrect seeding_scheme is passed
with self.assertRaises(ValueError):
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", seeding_scheme="hash")
# raise error if the greenlist_ratio is not in range (0.0, 1.0)
with self.assertRaises(ValueError):
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", greenlist_ratio=1.2)
watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device)
# use a fixed id for the last token, needed for reproducibility and tests
input_ids[:, -1] = 10
scores_wo_bias = scores[:, -1].clone()
out = watermark(input_ids=input_ids, scores=scores)
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
......@@ -75,6 +75,8 @@ if is_torch_available():
SampleEncoderDecoderOutput,
StoppingCriteria,
StoppingCriteriaList,
WatermarkDetector,
WatermarkingConfig,
)
from transformers.generation.utils import _speculative_sampling
......@@ -2098,6 +2100,44 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@slow
def test_watermark_generation(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
input_len = model_inputs["input_ids"].shape[-1]
# generation should work with both input types: WatermarkingConfig or Dict, so let's check it here :)
watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
_ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)
args = {
"bias": 2.0,
"context_width": 1,
"seeding_scheme": "selfhash",
"greenlist_ratio": 0.25,
"hashing_key": 15485863,
}
output = model.generate(**model_inputs, do_sample=False, max_length=15)
output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)
# check that the watermarked text is generating what it should
self.assertListEqual(
output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]]
)
self.assertListEqual(
output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]]
)
detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
detection_out = detector(output[:, input_len:], return_dict=True)
# check that the detector is detecting watermarked text
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
self.assertListEqual(detection_out.prediction.tolist(), [False])
@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer
......