Unverified Commit 18546378 authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

[tests] make 2 tests device-agnostic (#30008)

add torch device
parent bb76f81e
...@@ -992,7 +992,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase): ...@@ -992,7 +992,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
# prepare image # prepare image
image = prepare_img() image = prepare_img()
inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16) inputs = processor(images=image, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)
predictions = model.generate(**inputs) predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
...@@ -1003,7 +1003,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase): ...@@ -1003,7 +1003,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
# image and context # image and context
prompt = "Question: which city is this? Answer:" prompt = "Question: which city is this? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16) inputs = processor(images=image, text=prompt, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)
predictions = model.generate(**inputs) predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
......
...@@ -776,7 +776,7 @@ class ModelUtilsTest(TestCasePlus): ...@@ -776,7 +776,7 @@ class ModelUtilsTest(TestCasePlus):
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
inputs = tokenizer("Hello, my name is", return_tensors="pt") inputs = tokenizer("Hello, my name is", return_tensors="pt")
output = model.generate(inputs["input_ids"].to(0)) output = model.generate(inputs["input_ids"].to(f"{torch_device}:0"))
text_output = tokenizer.decode(output[0].tolist()) text_output = tokenizer.decode(output[0].tolist())
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm") self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment