Unverified Commit b8ab989f authored by lukec's avatar lukec Committed by GitHub
Browse files

Fix the FP8 E4M3 parsing offline scales failure bug (#3045)

parent b3393e94
...@@ -27,6 +27,7 @@ import huggingface_hub.constants ...@@ -27,6 +27,7 @@ import huggingface_hub.constants
import numpy as np import numpy as np
import torch import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
from safetensors.torch import load_file, safe_open, save_file from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -650,6 +651,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: ...@@ -650,6 +651,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
return name return name
# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
class KVCacheQuantSchema(BaseModel):
    """Schema for offline KV cache scaling factors loaded from a JSON file.

    Validated with pydantic ``mode="after"`` model validators, which run in
    definition order after field parsing; a failing ``assert`` inside a
    validator is surfaced by pydantic as a ValidationError.  The TP-size /
    layer-count / current-rank checks only run when a validation ``context``
    dict is supplied by the caller (see ``info.context`` below).
    """

    # KV cache dtype string; only "float8_e4m3fn" is accepted (checked below).
    dtype: str
    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
    # layer indices to their per-tensor KV cache scaling factor.
    # TODO: Consider pulling this and its validation methods out into its
    # own schema class (tricky as its members are variable)
    scaling_factor: Dict[int, Dict[int, float]]

    @model_validator(mode="after")
    def check_is_fp8(self) -> "KVCacheQuantSchema":
        """Reject scaling factors produced for any dtype other than FP8 E4M3."""
        assert self.dtype == "float8_e4m3fn", (
            "Loaded scaling factors intended for KV cache dtype = "
            f"{self.dtype} rather than float8_e4m3fn!"
        )
        return self

    @model_validator(mode="after")
    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check the scales map covers every TP rank with the right layer count.

        Skipped entirely when no validation context is provided.
        """
        context = info.context
        if context:
            tp_size = context["tp_size"]
            num_hidden_layers = context["num_hidden_layers"]
            # The file must describe exactly as many ranks as the engine runs.
            assert len(self.scaling_factor) == tp_size, (
                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
                f"but LLM engine is currently running with TP size {tp_size}."
            )
            # Every rank's map must cover every hidden layer.
            for tp_rank, layer_maps in self.scaling_factor.items():
                assert len(layer_maps) == num_hidden_layers, (
                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
                    f"Expected {num_hidden_layers} layers, got "
                    f"{len(layer_maps)}."
                )
            # Length matching alone is not enough: the keys must be exactly
            # the ranks 0..tp_size-1, not arbitrary integers.
            for i in range(tp_size):
                assert (
                    i in self.scaling_factor
                ), f"KV cache scales map for TP rank {i} not found."
        return self

    @model_validator(mode="after")
    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
        """Check this process's own TP rank has a scale for every layer.

        Skipped entirely when no validation context is provided.
        """
        context = info.context
        if context:
            tp_rank = context["tp_rank"]
            num_hidden_layers = context["num_hidden_layers"]
            layer_scales_map = self.scaling_factor[tp_rank]
            for i in range(num_hidden_layers):
                assert i in layer_scales_map, (
                    f"Could not find KV cache scales for layer {i} in "
                    f"TP rank {tp_rank}."
                )
        return self
class QuantParamSchema(BaseModel):
    """Top-level schema for an offline quantization-parameters JSON file.

    Currently only carries KV cache scaling factors; the model-type check
    runs solely when the caller supplies a validation context containing a
    ``model_type`` entry.
    """

    # TODO: Generalize and extend with more fields
    # (e.g. weights/activations params) once functionality is enabled
    # protected_namespaces=() lets the "model_type" field name through
    # pydantic's reserved "model_" prefix check.
    model_config = ConfigDict(protected_namespaces=())
    model_type: Optional[str]
    kv_cache: KVCacheQuantSchema

    @model_validator(mode="after")
    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
        """Reject scales whose recorded model type differs from the engine's."""
        context = info.context
        if not context:
            return self
        model_type = context.get("model_type", None)
        if model_type is None:
            return self
        assert model_type == self.model_type, (
            f"Model type is {model_type} but loaded "
            f"scaling factors belonging to different "
            f"model type {self.model_type}!"
        )
        return self
def kv_cache_scales_loader( def kv_cache_scales_loader(
filename: str, filename: str,
tp_rank: int, tp_rank: int,
...@@ -681,7 +757,7 @@ def kv_cache_scales_loader( ...@@ -681,7 +757,7 @@ def kv_cache_scales_loader(
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error("Error decoding JSON in file '%s'.", filename) logger.error("Error decoding JSON in file '%s'.", filename)
except Exception: except Exception:
logger.exception("An error occurred while reading '%s'.", filename) logger.error("An error occurred while reading '%s'.", filename)
# This section is reached if and only if any of the excepts are hit # This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded # Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales # which ultimately defaults to 1.0 scales
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment