Unverified commit 9c772ac8 authored by Younes Belkada, committed by GitHub

Quantization / HQQ: Fix HQQ tests on our runner (#30668)

Update test_hqq.py
parent a45c5148
@@ -35,7 +35,7 @@ if is_hqq_available():
 class HQQLLMRunner:
-    def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir):
+    def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir=None):
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=compute_dtype,
@@ -118,7 +118,7 @@ class HQQTest(unittest.TestCase):
         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
         check_forward(self, hqq_runner.model)

-    def test_bfp16_quantized_model_with_offloading(self):
+    def test_f16_quantized_model_with_offloading(self):
         """
         Simple LLM model testing bfp16 with meta-data offloading
         """
@@ -137,7 +137,7 @@ class HQQTest(unittest.TestCase):
         )

         hqq_runner = HQQLLMRunner(
-            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
         )

         check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
...
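
For context, a minimal sketch (not the repository's test code) of how the patched runner ends up being used, assuming the hqq package and a CUDA runner are available. The from_pretrained kwargs beyond those visible in the diff context, the HqqConfig settings, and the model id are illustrative assumptions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

class HQQLLMRunner:
    # Mirrors the patched signature from this diff: cache_dir now defaults to None.
    def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir=None):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=compute_dtype,
            device_map=device,                 # assumed kwarg, below the visible diff context
            quantization_config=quant_config,  # assumed kwarg, below the visible diff context
            cache_dir=cache_dir,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)

# Usage matching the updated offloading test:
quant_config = HqqConfig(nbits=4, group_size=64)  # illustrative HQQ settings
hqq_runner = HQQLLMRunner(
    model_id="meta-llama/Llama-2-7b-chat-hf",     # placeholder; the test file defines its own MODEL_ID
    quant_config=quant_config,
    compute_dtype=torch.float16,                  # was torch.bfloat16 before this commit
    device="cuda",
)

The commit message doesn't elaborate, but both changes read as runner-compatibility fixes: making cache_dir optional removes the need for a dedicated cache path on CI, and torch.float16 runs on GPUs that lack bfloat16 support.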