Unverified commit 5eeabc2a, authored by Tristan Leclercq, committed by GitHub
Browse files

[Bugfix] Fix bnb quantization for models with both HF-format and Mistral-format weights (#14950)

parent 18551e82
...@@ -15,6 +15,8 @@ from ..utils import compare_two_settings, create_new_process_for_each_test ...@@ -15,6 +15,8 @@ from ..utils import compare_two_settings, create_new_process_for_each_test
models_4bit_to_test = [ models_4bit_to_test = [
("facebook/opt-125m", "quantize opt model inflight"), ("facebook/opt-125m", "quantize opt model inflight"),
("mistralai/Mistral-7B-Instruct-v0.3",
"quantize inflight model with both HF and Mistral format weights")
] ]
models_pre_qaunt_4bit_to_test = [ models_pre_qaunt_4bit_to_test = [
......
...@@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model_name_or_path: str, model_name_or_path: str,
allowed_patterns: List[str], allowed_patterns: List[str],
revision: Optional[str] = None, revision: Optional[str] = None,
) -> Tuple[List[str], str]: ) -> Tuple[str, List[str], str]:
"""Retrieve weight files. Download the files if necessary. """Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern.""" Return the weight files and the file pattern."""
...@@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
weight_files = glob.glob( weight_files = glob.glob(
os.path.join(model_name_or_path, pattern)) os.path.join(model_name_or_path, pattern))
if weight_files: if weight_files:
return weight_files, pattern return model_name_or_path, weight_files, pattern
else: else:
hf_api = HfApi() hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
...@@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
revision, revision,
ignore_patterns=self.load_config.ignore_patterns, ignore_patterns=self.load_config.ignore_patterns,
) )
return glob.glob(os.path.join(hf_folder, pattern)), pattern return hf_folder, glob.glob(
os.path.join(hf_folder, pattern)), pattern
raise RuntimeError( raise RuntimeError(
f"No model weights found in: `{model_name_or_path}`") f"No model weights found in: `{model_name_or_path}`")
...@@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
hf_weights_files, matched_pattern = self._get_weight_files( hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
model_name_or_path, allowed_patterns, revision) model_name_or_path, allowed_patterns, revision)
if matched_pattern != "*.safetensors": use_safetensors = matched_pattern == "*.safetensors"
is_local = os.path.isdir(model_name_or_path)
index_file = SAFE_WEIGHTS_INDEX_NAME
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file)
else:
hf_weights_files = filter_files_not_needed_for_inference( hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files) hf_weights_files)
...@@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
raise RuntimeError( raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`") f"Cannot find any model weights with `{model_name_or_path}`")
return hf_weights_files, matched_pattern == "*.safetensors" return hf_weights_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors: if use_safetensors:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment