Unverified commit 779bc360, authored by Raushan Turganbay and committed by GitHub
Browse files

Watermark: fix tests (#30961)



* fix tests

* style

* Update tests/generation/test_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent a3c7b59e
...@@ -2148,6 +2148,8 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ...@@ -2148,6 +2148,8 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash") watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
_ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15) _ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)
# We will not check watermarked text, since we check it in `logits_processors` tests
# Checking if generated ids are as expected fails on different hardware
args = { args = {
"bias": 2.0, "bias": 2.0,
"context_width": 1, "context_width": 1,
...@@ -2158,19 +2160,11 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ...@@ -2158,19 +2160,11 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
output = model.generate(**model_inputs, do_sample=False, max_length=15) output = model.generate(**model_inputs, do_sample=False, max_length=15)
output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15) output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)
# check that the watermarked text is generating what is should # Check that the detector is detecting watermarked text
self.assertListEqual(
output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]]
)
self.assertListEqual(
output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]]
)
detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args) detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True) detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
detection_out = detector(output[:, input_len:], return_dict=True) detection_out = detector(output[:, input_len:], return_dict=True)
# check that the detector is detecting watermarked text
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True]) self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
self.assertListEqual(detection_out.prediction.tolist(), [False]) self.assertListEqual(detection_out.prediction.tolist(), [False])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment