Unverified Commit fa7f3cf3 authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

[tests] enable test_pipeline_accelerate_top_p on XPU (#29309)



* use torch_device

* Update tests/pipelines/test_pipelines_text_generation.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* fix style

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent ebccb091
...@@ -450,7 +450,9 @@ class TextGenerationPipelineTests(unittest.TestCase): ...@@ -450,7 +450,9 @@ class TextGenerationPipelineTests(unittest.TestCase):
def test_pipeline_accelerate_top_p(self): def test_pipeline_accelerate_top_p(self):
import torch import torch
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16) pipe = pipeline(
model="hf-internal-testing/tiny-random-bloom", device_map=torch_device, torch_dtype=torch.float16
)
pipe("This is a test", do_sample=True, top_p=0.5) pipe("This is a test", do_sample=True, top_p=0.5)
def test_pipeline_length_setting_warning(self): def test_pipeline_length_setting_warning(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment