"integration-tests/vscode:/vscode.git/clone" did not exist on "bab02ff2bcc1a1aee105cf6bce8ab34f62a8a16f"
Unverified Commit a81334e3 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] add an error message when dealing with _best_guess_weight_name ofline (#6184)

* add an error message when dealing with _best_guess_weight_name ofline

* simplify condition
parent d704a730
...@@ -18,6 +18,7 @@ from typing import Callable, Dict, List, Optional, Union ...@@ -18,6 +18,7 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors import safetensors
import torch import torch
from huggingface_hub import model_info from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from packaging import version from packaging import version
from torch import nn from torch import nn
...@@ -229,7 +230,9 @@ class LoraLoaderMixin: ...@@ -229,7 +230,9 @@ class LoraLoaderMixin:
# determine `weight_name`. # determine `weight_name`.
if weight_name is None: if weight_name is None:
weight_name = cls._best_guess_weight_name( weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors" pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
) )
model_file = _get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
...@@ -255,7 +258,7 @@ class LoraLoaderMixin: ...@@ -255,7 +258,7 @@ class LoraLoaderMixin:
if model_file is None: if model_file is None:
if weight_name is None: if weight_name is None:
weight_name = cls._best_guess_weight_name( weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin" pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
) )
model_file = _get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
...@@ -294,7 +297,12 @@ class LoraLoaderMixin: ...@@ -294,7 +297,12 @@ class LoraLoaderMixin:
return state_dict, network_alphas return state_dict, network_alphas
@classmethod @classmethod
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"): def _best_guess_weight_name(
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
targeted_files = [] targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict): if os.path.isfile(pretrained_model_name_or_path_or_dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment