Unverified Commit 8891193e authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[Pipeline] fix failing bloom `pipeline` test (#20778)

fix failing `pipeline` test
parent b9b70b0e
...@@ -284,10 +284,10 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM ...@@ -284,10 +284,10 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
], ],
) )
# torch_dtype not necessary # torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto") pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto")
self.assertEqual(pipe.model.device, torch.device(0)) self.assertEqual(pipe.model.device, torch.device(0))
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16) self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32)
out = pipe("This is a test") out = pipe("This is a test")
self.assertEqual( self.assertEqual(
out, out,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment