Unverified Commit 6b3f3f7e authored by Younes B's avatar Younes B Committed by GitHub
Browse files

feat / fix: Properly make use of `subfolder` from HF models (#3072)



* add subfolder

* lint

* change it to empty string

* fix typehints

---------
Co-authored-by: Baber <baber@hey.com>
parent 0f63d4f5
...@@ -61,7 +61,7 @@ class HFLM(TemplateLM): ...@@ -61,7 +61,7 @@ class HFLM(TemplateLM):
backend: Literal["default", "causal", "seq2seq"] = "default", backend: Literal["default", "causal", "seq2seq"] = "default",
# override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq) # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
revision: Optional[str] = "main", revision: Optional[str] = "main",
subfolder: Optional[str] = None, subfolder: str = "",
tokenizer: Optional[ tokenizer: Optional[
Union[ Union[
str, str,
...@@ -162,14 +162,13 @@ class HFLM(TemplateLM): ...@@ -162,14 +162,13 @@ class HFLM(TemplateLM):
) )
revision = str(revision) # cast to string if not already one revision = str(revision) # cast to string if not already one
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
self._get_config( self._get_config(
pretrained, pretrained,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
gguf_file=gguf_file, gguf_file=gguf_file,
subfolder=subfolder,
) )
# determine which of 'causal' and 'seq2seq' backends to use for HF models # determine which of 'causal' and 'seq2seq' backends to use for HF models
...@@ -182,6 +181,7 @@ class HFLM(TemplateLM): ...@@ -182,6 +181,7 @@ class HFLM(TemplateLM):
pretrained, pretrained,
tokenizer, tokenizer,
revision=revision, revision=revision,
subfolder=subfolder,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_fast_tokenizer=use_fast_tokenizer, use_fast_tokenizer=use_fast_tokenizer,
gguf_file=gguf_file, gguf_file=gguf_file,
...@@ -206,6 +206,7 @@ class HFLM(TemplateLM): ...@@ -206,6 +206,7 @@ class HFLM(TemplateLM):
gptqmodel=gptqmodel, gptqmodel=gptqmodel,
gguf_file=gguf_file, gguf_file=gguf_file,
quantization_config=getattr(self.config, "quantization_config", None), quantization_config=getattr(self.config, "quantization_config", None),
subfolder=subfolder,
**kwargs, **kwargs,
) )
...@@ -522,6 +523,7 @@ class HFLM(TemplateLM): ...@@ -522,6 +523,7 @@ class HFLM(TemplateLM):
revision: str = "main", revision: str = "main",
trust_remote_code: bool = False, trust_remote_code: bool = False,
gguf_file: Optional[str] = None, gguf_file: Optional[str] = None,
subfolder: str = "",
) -> None: ) -> None:
"""Return the model config for HuggingFace models""" """Return the model config for HuggingFace models"""
self._config = transformers.AutoConfig.from_pretrained( self._config = transformers.AutoConfig.from_pretrained(
...@@ -529,6 +531,7 @@ class HFLM(TemplateLM): ...@@ -529,6 +531,7 @@ class HFLM(TemplateLM):
revision=revision, revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
gguf_file=gguf_file, gguf_file=gguf_file,
subfolder=subfolder,
) )
def _create_model( def _create_model(
...@@ -552,6 +555,7 @@ class HFLM(TemplateLM): ...@@ -552,6 +555,7 @@ class HFLM(TemplateLM):
gptqmodel: Optional[bool] = False, gptqmodel: Optional[bool] = False,
gguf_file: Optional[str] = None, gguf_file: Optional[str] = None,
quantization_config: Optional[Dict[str, Any]] = None, quantization_config: Optional[Dict[str, Any]] = None,
subfolder: str = "",
**kwargs, **kwargs,
) -> None: ) -> None:
""" """
...@@ -598,6 +602,7 @@ class HFLM(TemplateLM): ...@@ -598,6 +602,7 @@ class HFLM(TemplateLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
gguf_file=gguf_file, gguf_file=gguf_file,
quantization_config=quantization_config, quantization_config=quantization_config,
subfolder=subfolder,
**model_kwargs, **model_kwargs,
) )
else: else:
...@@ -697,6 +702,7 @@ class HFLM(TemplateLM): ...@@ -697,6 +702,7 @@ class HFLM(TemplateLM):
use_fast_tokenizer: Optional[bool] = True, use_fast_tokenizer: Optional[bool] = True,
gguf_file: Optional[str] = None, gguf_file: Optional[str] = None,
add_bos_token: Optional[bool] = False, add_bos_token: Optional[bool] = False,
subfolder: Optional[str] = "",
) -> None: ) -> None:
""" """
Helper method during initialization. Helper method during initialization.
...@@ -718,6 +724,9 @@ class HFLM(TemplateLM): ...@@ -718,6 +724,9 @@ class HFLM(TemplateLM):
if add_bos_token: if add_bos_token:
kwargs["add_bos_token"] = True kwargs["add_bos_token"] = True
if subfolder:
kwargs["subfolder"] = subfolder
if tokenizer: if tokenizer:
if isinstance(tokenizer, str): if isinstance(tokenizer, str):
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment