Unverified Commit 172bcf01 authored by Yineng Zhang, committed by GitHub

Revert "Refactor kv_cache_scheme handling for quantization (#10132)" (#10935)

parent 37158f20
@@ -140,21 +140,11 @@ class ModelOptFp8Config(QuantizationConfig):
             # Flat format (config.json quantization_config)
             # For kv_cache, check if kv_cache_scheme exists and extract algo
             kv_cache_scheme = config.get("kv_cache_scheme")
-            kv_cache_type = None
-            kv_cache_bits = None
-            if isinstance(kv_cache_scheme, dict):
-                # Handles the expected format: {"type": "float", "num_bits": 8}
-                kv_cache_type = kv_cache_scheme.get("type")
-                kv_cache_bits = kv_cache_scheme.get("num_bits")
-            elif isinstance(kv_cache_scheme, str):
-                # Handles the shorthand format: "FP8"
-                if kv_cache_scheme.upper() == "FP8":
-                    kv_cache_type = "float"
-                    kv_cache_bits = 8
-            # Now, safely use the extracted values
-            if kv_cache_type == "float" and kv_cache_bits == 8:
+            if (
+                kv_cache_scheme
+                and kv_cache_scheme.get("type") == "float"
+                and kv_cache_scheme.get("num_bits") == 8
+            ):
                 kv_cache_quant_method = "FP8"
             # Map 'ignore' field to 'exclude_modules'
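
Note on the restored FP8 path: after this revert, only the dict form of kv_cache_scheme ({"type": "float", "num_bits": 8}) is recognized; the string shorthand handled by the removed lines is no longer special-cased. The following is a minimal, self-contained sketch of the restored check, not the actual class method; the helper name detect_kv_cache_quant and the sample configs are illustrative only.

from typing import Any, Optional


def detect_kv_cache_quant(config: dict[str, Any]) -> Optional[str]:
    """Illustrative stand-in for the flat-format check restored by this revert."""
    kv_cache_scheme = config.get("kv_cache_scheme")
    if (
        kv_cache_scheme
        and kv_cache_scheme.get("type") == "float"
        and kv_cache_scheme.get("num_bits") == 8
    ):
        return "FP8"
    return None


# Dict form from a config.json quantization_config -> "FP8".
print(detect_kv_cache_quant({"kv_cache_scheme": {"type": "float", "num_bits": 8}}))
# No kv_cache_scheme at all -> None (no FP8 KV cache detected).
print(detect_kv_cache_quant({}))
# A bare string such as "FP8" would be truthy but has no .get, so the restored
# check raises AttributeError rather than recognizing the shorthand.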
@@ -604,22 +594,11 @@ class ModelOptFp4Config(QuantizationConfig):
             if not kv_cache_quant_algo:
                 # For config.json format, derive from kv_cache_scheme if available
                 kv_cache_scheme = config.get("kv_cache_scheme")
-                kv_cache_type = None
-                kv_cache_bits = None
-                if isinstance(kv_cache_scheme, dict):
-                    # Handles the expected format: {"type": "float", "num_bits": 8}
-                    kv_cache_type = kv_cache_scheme.get("type")
-                    kv_cache_bits = kv_cache_scheme.get("num_bits")
-                elif isinstance(kv_cache_scheme, str):
-                    # Handles the shorthand format: "FP8"
-                    # We can infer the properties from the string.
-                    if kv_cache_scheme.upper() == "FP8":
-                        kv_cache_type = "float"
-                        kv_cache_bits = 8
-                # Now, safely use the extracted values in the original logic
-                if kv_cache_type == "float" and kv_cache_bits == 8:
+                if (
+                    kv_cache_scheme
+                    and kv_cache_scheme.get("type") == "float"
+                    and kv_cache_scheme.get("num_bits") == 8
+                ):
                     kv_cache_quant_algo = "FP8"
             else:
                 kv_cache_quant_algo = "auto"
...
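
Note on the restored FP4 path: this hunk reinstates the same dict-only check, with one difference visible in the context lines above: when no explicit kv_cache_quant_algo is given and the scheme does not match, the value falls back to "auto". The sketch below shows that fallback in isolation; the function name derive_kv_cache_quant_algo and the sample configs are illustrative, not part of the patch.

from typing import Any, Optional


def derive_kv_cache_quant_algo(
    config: dict[str, Any], kv_cache_quant_algo: Optional[str] = None
) -> str:
    """Illustrative stand-in for the restored FP4 fallback logic."""
    if not kv_cache_quant_algo:
        # For config.json format, derive from kv_cache_scheme if available.
        kv_cache_scheme = config.get("kv_cache_scheme")
        if (
            kv_cache_scheme
            and kv_cache_scheme.get("type") == "float"
            and kv_cache_scheme.get("num_bits") == 8
        ):
            kv_cache_quant_algo = "FP8"
        else:
            kv_cache_quant_algo = "auto"
    return kv_cache_quant_algo


print(derive_kv_cache_quant_algo({"kv_cache_scheme": {"type": "float", "num_bits": 8}}))  # FP8
print(derive_kv_cache_quant_algo({}))                                                     # auto
print(derive_kv_cache_quant_algo({}, kv_cache_quant_algo="FP8"))                          # FP8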