Unverified Commit 130b9878 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[XGLM] run sampling test on CPU to be deterministic (#15892)

* run sampling test on CPU to be deterministic

* input_ids on CPU
parent baab5e7c
...@@ -418,15 +418,14 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase): ...@@ -418,15 +418,14 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
def test_xglm_sample(self): def test_xglm_sample(self):
tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M") tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
tokenized = tokenizer("Today is a nice day and", return_tensors="pt") tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
input_ids = tokenized.input_ids.to(torch_device) input_ids = tokenized.input_ids
output_ids = model.generate(input_ids, do_sample=True, num_beams=1) output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
EXPECTED_OUTPUT_STR = "Today is a nice day and I am happy to show you all about a recent project for my" EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy"
self.assertEqual(output_str, EXPECTED_OUTPUT_STR) self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
@slow @slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment