Unverified Commit 9c772ac8 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

Quantization / HQQ: Fix HQQ tests on our runner (#30668)

Update test_hqq.py
parent a45c5148
......@@ -35,7 +35,7 @@ if is_hqq_available():
class HQQLLMRunner:
def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir):
def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir=None):
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=compute_dtype,
......@@ -118,7 +118,7 @@ class HQQTest(unittest.TestCase):
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
def test_bfp16_quantized_model_with_offloading(self):
def test_f16_quantized_model_with_offloading(self):
"""
Simple LLM model testing bfp16 with meta-data offloading
"""
......@@ -137,7 +137,7 @@ class HQQTest(unittest.TestCase):
)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment