Unverified Commit d4dc4d76 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[chore] misc changes in the bnb tests for consistency. (#11355)

misc changes in the bnb tests for consistency.
parent 3a31b291
......@@ -526,7 +526,7 @@ class SlowBnb4BitTests(Base4bitTests):
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
def test_pipeline_device_placement_works_with_nf4(self):
def test_pipeline_cuda_placement_works_with_nf4(self):
transformer_nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
......@@ -560,7 +560,7 @@ class SlowBnb4BitTests(Base4bitTests):
).to(torch_device)
# Check if inference works.
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
_ = pipeline_4bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
del pipeline_4bit
......
......@@ -492,7 +492,7 @@ class SlowBnb8bitTests(Base8bitTests):
self.assertTrue(max_diff < 1e-2)
# 8bit models cannot be offloaded to CPU.
self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda")
self.assertTrue(self.pipeline_8bit.transformer.device.type == torch_device)
# calling it again shouldn't be a problem
_ = self.pipeline_8bit(
prompt=self.prompt,
......@@ -534,7 +534,7 @@ class SlowBnb8bitTests(Base8bitTests):
).to(device)
# Check if inference works.
_ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
_ = pipeline_8bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
del pipeline_8bit
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment