Unverified Commit d366ccc4 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[RFC] [Mistral] FP8 format (#10130)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
parent 870c3748
...@@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
mistral_mapping = { mistral_mapping = {
"layers": "model.layers", "layers": "model.layers",
"attention": "self_attn", "attention": "self_attn",
"qscale_act": "input_scale",
"qscale_weight": "weight_scale",
"kv_fake_quantizer.qscale_act": "kv_scale",
"wq": "q_proj", "wq": "q_proj",
"wk": "k_proj", "wk": "k_proj",
"wv": "v_proj", "wv": "v_proj",
...@@ -590,15 +593,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -590,15 +593,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
modules = name.split(".") modules = name.split(".")
# rotary embeds should be sliced # rotary embeds should be sliced
if "wk" in modules: if "wk" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight, loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads) self.config.num_key_value_heads)
elif "wq" in modules: elif "wq" in modules and modules[-1] == "weight":
loaded_weight = permute(loaded_weight, loaded_weight = permute(loaded_weight,
self.config.num_attention_heads) self.config.num_attention_heads)
for item in modules: num_modules = len(modules)
if item in mapping and mapping[item] not in name: for i in range(num_modules):
item = modules[i]
next_item = modules[i + 1] if i < num_modules - 1 else None
combined_item = (f"{item}.{next_item}"
if next_item is not None else None)
if combined_item in mapping:
name = name.replace(combined_item, mapping[combined_item])
elif item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item]) name = name.replace(item, mapping[item])
return name, loaded_weight return name, loaded_weight
...@@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext): ...@@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext):
tokenizer_mode=ctx.model_config.tokenizer_mode) tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder mm_encoder = tokenizer.instruct.mm_encoder
max_image_size = mm_encoder.mm_config.max_image_size image_config = mm_encoder.mm_config if hasattr(
image_patch_size = mm_encoder.mm_config.image_patch_size mm_encoder, "mm_config") else mm_encoder.image_config
max_image_size = image_config.max_image_size
image_patch_size = image_config.image_patch_size
return ((max_image_size // image_patch_size)**2) return ((max_image_size // image_patch_size)**2)
......
...@@ -4,7 +4,7 @@ import enum ...@@ -4,7 +4,7 @@ import enum
import json import json
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Type, Union from typing import Any, Dict, Literal, Optional, Type, Union
import huggingface_hub import huggingface_hub
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files, from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
...@@ -554,7 +554,8 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], ...@@ -554,7 +554,8 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
for key, value in elem.items(): for key, value in elem.items():
key = config_mapping.get(key, key) key = config_mapping.get(key, key)
config_dict[key] = recurse_elems(value) config_dict[key] = recurse_elems(value)
return PretrainedConfig(**config_dict)
return config_dict
else: else:
return elem return elem
...@@ -566,12 +567,30 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], ...@@ -566,12 +567,30 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
config_dict["max_position_embeddings"] = config_dict.get( config_dict["max_position_embeddings"] = config_dict.get(
"max_position_embeddings", 128_000) "max_position_embeddings", 128_000)
if config_dict.get("quantization") is not None:
quantization = config_dict.get("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
# This maps to the FP8 static per-tensor quantization scheme
quantization_config = {
"quant_method": "fp8",
"activation_scheme": "static"
}
else:
raise ValueError(
f"Found unknown quantization='{quantization}' in config")
config_dict["quantization_config"] = quantization_config
config_type: Literal["text",
"multimodal"] = "multimodal" if config_dict.get(
"vision_encoder") is not None else "text"
if config_dict.get("moe") is not None: if config_dict.get("moe") is not None:
config_dict["architectures"] = ["MixtralForCausalLM"] config_dict["architectures"] = ["MixtralForCausalLM"]
else: else:
config_dict["architectures"] = ["MistralForCausalLM"] config_dict["architectures"] = ["MistralForCausalLM"]
if config_dict.get("vision_encoder") is not None: if config_type == "multimodal":
multimodal_config = config_dict.pop("vision_encoder") multimodal_config = config_dict.pop("vision_encoder")
config_dict = { config_dict = {
...@@ -583,8 +602,16 @@ def load_params_config(model: Union[str, Path], revision: Optional[str], ...@@ -583,8 +602,16 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
config_dict.update(kwargs) config_dict.update(kwargs)
config = recurse_elems(config_dict) config_dict = recurse_elems(config_dict)
return config
# transform to HF config format
if config_type == "multimodal":
config_dict["text_config"] = PretrainedConfig(
**config_dict["text_config"])
config_dict["vision_config"] = PretrainedConfig(
**config_dict["vision_config"])
return PretrainedConfig(**config_dict)
def get_hf_image_processor_config( def get_hf_image_processor_config(
......
...@@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: ...@@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
def find_tokenizer_file(files: List[str]): def find_tokenizer_file(files: List[str]):
file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$") file_pattern = re.compile(
r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
matched_files = [file for file in files if file_pattern.match(file)] matched_files = [file for file in files if file_pattern.match(file)]
if len(matched_files) > 1: if len(matched_files) > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment