"vscode:/vscode.git/clone" did not exist on "13f6630a9ea78bee4bd80bb6e842e55e374eec9a"
Unverified Commit d910816c authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[Bugfix] Automatically Detect SparseML models (#5119)

parent 87d41c84
......@@ -156,6 +156,17 @@ class ModelConfig:
self.embedding_mode = any(
ModelRegistry.is_embedding_model(arch) for arch in architectures)
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None:
# SparseML uses a "compression_config" with a "quantization_config".
compression_cfg = getattr(self.hf_config, "compression_config",
None)
if compression_cfg is not None:
quant_cfg = compression_cfg.get("quantization_config", None)
return quant_cfg
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"]
......@@ -163,12 +174,13 @@ class ModelConfig:
self.quantization = self.quantization.lower()
# Parse quantization method from the HF model config, if available.
quant_cfg = getattr(self.hf_config, "quantization_config", None)
quant_cfg = self._parse_quant_hf_config()
if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()
# Detect which checkpoint is it
for name, method in QUANTIZATION_METHODS.items():
for _, method in QUANTIZATION_METHODS.items():
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization)
if quantization_override:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment