Unverified commit 441e6ac1, authored by Sam Passaglia, committed by GitHub

changed quantized type signature and default val (#532)

parent 095d8406
```diff
@@ -70,7 +70,7 @@ class HuggingFaceAutoLM(BaseLM):
     def __init__(
         self,
         pretrained: str,
-        quantized: Optional[Union[True, str]] = None,
+        quantized: Optional[Union[bool, str]] = False,
         tokenizer: Optional[str] = None,
         subfolder: Optional[str] = None,
         revision: Optional[str] = "main",
@@ -96,7 +96,7 @@ class HuggingFaceAutoLM(BaseLM):
             The HuggingFace Hub model ID name or the path to a pre-trained
             model to load. This is effectively the `pretrained_model_name_or_path`
             argument of `from_pretrained` in the HuggingFace `transformers` API.
-        quantized (str or True, optional, defaults to None):
+        quantized (str or bool, optional, defaults to False):
             File name of a GPTQ quantized model to load. Set to `True` to use the
             default name of the quantized model.
         add_special_tokens (bool, optional, defaults to True):
@@ -234,7 +234,7 @@ class HuggingFaceAutoLM(BaseLM):
         self,
         *,
         pretrained: str,
-        quantized: Optional[Union[True, str]] = None,
+        quantized: Optional[Union[bool, str]] = False,
         revision: str,
         subfolder: str,
         device_map: Optional[Union[str, _DeviceMapping]] = None,
@@ -246,7 +246,7 @@ class HuggingFaceAutoLM(BaseLM):
         gptq_use_triton: Optional[bool] = False,
     ) -> transformers.AutoModel:
        """Returns a pre-trained pytorch model from a pre-trained model configuration."""
-        if quantized is None:
+        if not quantized:
            model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision + ("/" + subfolder if subfolder is not None else ""),
```
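In short, the commit replaces the annotation `Optional[Union[True, str]]` with `Optional[Union[bool, str]]` (the literal `True` is not a valid `Union` argument, so `typing` rejects it when the annotation is evaluated) and changes the default from `None` to `False`. The loader's guard moves from `if quantized is None:` to `if not quantized:`, so `False`, `None`, and an empty string all take the unquantized path. Below is a minimal sketch of the resulting behavior, not the harness's actual loader; the helpers `load_unquantized` and `load_quantized` are hypothetical placeholders.

```python
# Minimal sketch of how the new `quantized` parameter is interpreted.
# `load_unquantized` / `load_quantized` are hypothetical stand-ins, not real harness functions.
from typing import Optional, Union


def create_auto_model(
    pretrained: str,
    quantized: Optional[Union[bool, str]] = False,  # new default is False, not None
) -> str:
    # New guard: `if not quantized` treats False, None, and "" as "no GPTQ model".
    if not quantized:
        return f"load_unquantized({pretrained!r})"
    # `True` means "use the default quantized file name"; a string names the file explicitly.
    model_file = None if quantized is True else quantized
    return f"load_quantized({pretrained!r}, model_file={model_file!r})"


print(create_auto_model("facebook/opt-125m"))                               # unquantized path
print(create_auto_model("some/gptq-model", quantized=True))                 # default file name
print(create_auto_model("some/gptq-model", quantized="model.safetensors"))  # explicit file
```

A practical consequence of the old signature, as shown in the diff: a caller passing `quantized=False` explicitly would have been routed down the GPTQ branch, since `False is not None`; the boolean default plus the truthiness check removes that mismatch.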