"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "51e43b6143969ea6570d5873cdcf3e430ac9b73e"
Unverified Commit d4dc4d76 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[chore] misc changes in the bnb tests for consistency. (#11355)

misc changes in the bnb tests for consistency.
parent 3a31b291
...@@ -526,7 +526,7 @@ class SlowBnb4BitTests(Base4bitTests): ...@@ -526,7 +526,7 @@ class SlowBnb4BitTests(Base4bitTests):
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.", reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True, strict=True,
) )
def test_pipeline_device_placement_works_with_nf4(self): def test_pipeline_cuda_placement_works_with_nf4(self):
transformer_nf4_config = BitsAndBytesConfig( transformer_nf4_config = BitsAndBytesConfig(
load_in_4bit=True, load_in_4bit=True,
bnb_4bit_quant_type="nf4", bnb_4bit_quant_type="nf4",
...@@ -560,7 +560,7 @@ class SlowBnb4BitTests(Base4bitTests): ...@@ -560,7 +560,7 @@ class SlowBnb4BitTests(Base4bitTests):
).to(torch_device) ).to(torch_device)
# Check if inference works. # Check if inference works.
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2) _ = pipeline_4bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
del pipeline_4bit del pipeline_4bit
......
...@@ -492,7 +492,7 @@ class SlowBnb8bitTests(Base8bitTests): ...@@ -492,7 +492,7 @@ class SlowBnb8bitTests(Base8bitTests):
self.assertTrue(max_diff < 1e-2) self.assertTrue(max_diff < 1e-2)
# 8bit models cannot be offloaded to CPU. # 8bit models cannot be offloaded to CPU.
self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda") self.assertTrue(self.pipeline_8bit.transformer.device.type == torch_device)
# calling it again shouldn't be a problem # calling it again shouldn't be a problem
_ = self.pipeline_8bit( _ = self.pipeline_8bit(
prompt=self.prompt, prompt=self.prompt,
...@@ -534,7 +534,7 @@ class SlowBnb8bitTests(Base8bitTests): ...@@ -534,7 +534,7 @@ class SlowBnb8bitTests(Base8bitTests):
).to(device) ).to(device)
# Check if inference works. # Check if inference works.
_ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2) _ = pipeline_8bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
del pipeline_8bit del pipeline_8bit
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment