Unverified Commit 31acba56 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `PvtModelIntegrationTest::test_inference_fp16` (#25106)



update
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent ee63520a
......@@ -317,14 +317,13 @@ class PvtModelIntegrationTest(unittest.TestCase):
r"""
A small test to make sure that inference work in half precision without any problem.
"""
model = PvtForImageClassification.from_pretrained(
"Zetatech/pvt-tiny-224", torch_dtype=torch.float16, device_map="auto"
)
model = PvtForImageClassification.from_pretrained("Zetatech/pvt-tiny-224", torch_dtype=torch.float16)
model.to(torch_device)
image_processor = PvtImageProcessor(size=224)
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device).astype(torch.float16)
pixel_values = inputs.pixel_values.to(torch_device, dtype=torch.float16)
# forward pass to make sure inference works in fp16
with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment