"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1e51bb717c04ca4b01a05a7a548e6b550be38628"
Unverified Commit 3fd7eee1 authored by Philipp Schmid's avatar Philipp Schmid Committed by GitHub
Browse files

Adds use_auth_token with pipelines (#11123)

* added model_kwargs to infer_framework_from_model

* added model_kwargs to tokenizer

* added use_auth_token as named parameter

* added dynamic get for use_auth_token
parent 1c151283
@@ -246,6 +246,7 @@ def pipeline(
    framework: Optional[str] = None,
    revision: Optional[str] = None,
    use_fast: bool = True,
    use_auth_token: Optional[Union[str, bool]] = None,
    model_kwargs: Dict[str, Any] = {},
    **kwargs
) -> Pipeline:
@@ -308,6 +309,10 @@ def pipeline(
        artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
    use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
        Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
    use_auth_token (:obj:`str` or `bool`, `optional`):
        The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
        generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
    revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
    model_kwargs:
        Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
        **model_kwargs)` function.
@@ -367,6 +372,9 @@ def pipeline(
    task_class, model_class = targeted_task["impl"], targeted_task[framework]

    # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
    model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)

    # Instantiate tokenizer if needed
    if isinstance(tokenizer, (str, tuple)):
        if isinstance(tokenizer, tuple):
@@ -377,12 +385,12 @@ def pipeline(
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task, **model_kwargs
            )

    # Instantiate config if needed
    if isinstance(config, str):
        config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)

    # Instantiate modelcard if needed
    if isinstance(modelcard, str):
...
@@ -48,7 +48,7 @@ logger = logging.get_logger(__name__)
def infer_framework_from_model(
    model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs
):
    """
    Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
@@ -65,10 +65,11 @@ def infer_framework_from_model(
            from.
        model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
            A mapping framework to class.
        task (:obj:`str`):
            The task defining which pipeline will be returned.
        model_kwargs:
            Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
            **model_kwargs)` function.

    Returns:
        :obj:`Tuple`: A tuple framework, model.
@@ -80,19 +81,20 @@ def infer_framework_from_model(
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    if isinstance(model, str):
        model_kwargs["_from_pipeline"] = task
        if is_torch_available() and not is_tf_available():
            model_class = model_classes.get("pt", AutoModel)
            model = model_class.from_pretrained(model, **model_kwargs)
        elif is_tf_available() and not is_torch_available():
            model_class = model_classes.get("tf", TFAutoModel)
            model = model_class.from_pretrained(model, **model_kwargs)
        else:
            try:
                model_class = model_classes.get("pt", AutoModel)
                model = model_class.from_pretrained(model, **model_kwargs)
            except OSError:
                model_class = model_classes.get("tf", TFAutoModel)
                model = model_class.from_pretrained(model, **model_kwargs)

    framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    return framework, model
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.