import json
import os
import tempfile
import time
import unittest

from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.quantization import CHECKPOINT_FORMAT, CHECKPOINT_FORMAT_FIELD, QUANT_CONFIG_FILENAME
from auto_gptq.quantization.config import QUANT_METHOD, BaseQuantizeConfig


class TestSerialization(unittest.TestCase):
    MODEL_ID = "habanoz/TinyLlama-1.1B-Chat-v0.3-GPTQ"

    def setUp(self):
        # Remove any cached Marlin-repacked checkpoint so every test starts cold.
        dummy_config = BaseQuantizeConfig(
            model_name_or_path=self.MODEL_ID,
            quant_method=QUANT_METHOD.GPTQ,
            checkpoint_format=CHECKPOINT_FORMAT.MARLIN,
        )
        model_cache_path, is_cached = dummy_config.get_cache_file_path()
        if is_cached:
            os.remove(model_cache_path)

    def test_marlin_local_serialization(self):
        # First load repacks the GPTQ checkpoint into the Marlin format.
        start = time.time()
        model = AutoGPTQForCausalLM.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True)
        end = time.time()
        first_load_time = end - start

        with tempfile.TemporaryDirectory() as tmpdir:
            # Saving should write the repacked weights directly, not a cache file.
            model.save_pretrained(tmpdir)

            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "model.safetensors")))

            model_cache_path, is_cached = model.quantize_config.get_cache_file_path()
            self.assertFalse(os.path.isfile(os.path.join(tmpdir, model_cache_path)))

            # The saved quantization config must record the Marlin checkpoint format.
            with open(os.path.join(tmpdir, QUANT_CONFIG_FILENAME), "r") as config_file:
                config = json.load(config_file)
            self.assertEqual(config[CHECKPOINT_FORMAT_FIELD], CHECKPOINT_FORMAT.MARLIN)

            start = time.time()
            model = AutoGPTQForCausalLM.from_quantized(tmpdir, device="cuda:0", use_marlin=True)
            end = time.time()
            second_load_time = end - start

            # The first load is already fast because a CUDA kernel repacks the weights,
            # but reloading the serialized Marlin checkpoint should still be faster.
            self.assertTrue(second_load_time < first_load_time)

    def test_marlin_hf_cache_serialization(self):
        start = time.time()
        model = AutoGPTQForCausalLM.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True)
        self.assertEqual(model.quantize_config.checkpoint_format, CHECKPOINT_FORMAT.MARLIN)
        end = time.time()
        first_load_time = end - start

        # The repacked checkpoint should land in the Hugging Face cache ("assets" directory).
        model_cache_path, is_cached = model.quantize_config.get_cache_file_path()
        self.assertIn("assets", model_cache_path)
        self.assertTrue(is_cached)

        start = time.time()
        model = AutoGPTQForCausalLM.from_quantized(self.MODEL_ID, device="cuda:0", use_marlin=True)
        self.assertEqual(model.quantize_config.checkpoint_format, CHECKPOINT_FORMAT.MARLIN)
        end = time.time()
        second_load_time = end - start

        # The first load is already fast because a CUDA kernel repacks the weights,
        # but hitting the cache on the second load should still be faster.
        self.assertTrue(second_load_time < first_load_time)
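
if __name__ == "__main__":
    # Standard unittest entry point so this module can also be run directly,
    # e.g. `python test_serialization.py` (filename assumed for illustration).
    unittest.main()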