Unverified Commit 33357243 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Test composition (#23214)



* Remove nestedness in tool config

* Really do it

* Use remote tools descriptions

* Work

* Clean up eval

* Changes

* Tools

* Tools

* tool

* Fix everything

* Use last result/assign for evaluation

* Prompt

* Remove hardcoded selection

* Evaluation for chat agents

* correct some spelling

* Small fixes

* Change summarization model (#23172)

* Fix link displayed

* Update description of the tool

* Fixes in chat prompt

* Custom tools, custom prompt

* Tool clean up

* save_pretrained and push_to_hub for tool

* Fix init

* Tests

* Fix tests

* Tool save/from_hub/push_to_hub and tool->load_tool

* Clean push_to_hub and add app file

* Custom inference API for endpoints too

* Clean up

* old remote tool and new remote tool

* Make a requirements

* return_code adds tool creation

* Avoid redundancy between global variables

* Remote tools can be loaded

* Tests

* Text summarization tests

* Quality

* Properly mark tests

* Test the python interpreter

* And the CI shall be green.

* fix loading of additional tools

* Work on RemoteTool and fix tests

* General clean up

* Guard imports

* Fix tools

* docs: Fix broken link in 'How to add a model...'  (#23216)

fix link

* Get default endpoint from the Hub

* Add guide

* Simplify tool config

* Docs

* Some fixes

* Docs

* Docs

* Docs

* Fix code returned by agent

* Try this

* Match args with signature in remote tool

* Should fix python interpreter for Python 3.8

* Fix push_to_hub for tools

* Other fixes to push_to_hub

* Add API doc page

* Docs

* Docs

* Custom tools

* Pin tensorflow-probability (#23220)

* Pin tensorflow-probability

* [all-test]

* [all-test] Fix syntax for bash

* PoC for some chaining API

* Text to speech

* J'ai pris des libertés

* Rename

* Basic python interpreter

* Add agents

* Quality

* Add translation tool

* temp

* GenQA + LID + S2T

* Quality + word missing in translation

* Add open assistance, support f-strings in evaluate

* captioning + s2t fixes

* Style

* Refactor descriptions and remove chain

* Support errors and rename OpenAssistantAgent

* Add setup

* Deal with typos + example of inference API

* Some rename + README

* Fixes

* Update prompt

* Unwanted change

* Make sure everyone has a default

* One prompt to rule them all.

* SD

* Description

* Clean up remote tools

* More remote tools

* Add option to return code and update doc

* Image segmentation

* ControlNet

* Gradio demo

* Diffusers protection

* Lib protection

* ControlNet description

* Cleanup

* Style

* Remove accelerate and try to be reproducible

* No randomness

* Male Basic optional in token

* Clean description

* Better prompts

* Fix args eval in interpreter

* Add tool wrapper

* Tool on the Hub

* Style post-rebase

* Big refactor of descriptions, batch generation and evaluation for agents

* Make problems easier - interface to debug

* More problems, add python primitives

* Back to one prompt

* Remove dict for translation

* Be consistent

* Add prompts

* New version of the agent

* Evaluate new agents

* New endpoints agents

* Make all tools a dict variable

* Typo

* Add problems

* Add to big prompt

* Harmonize

* Add tools

* New evaluation

* Add more tools

* Build prompt with tools descriptions

* Tools on the Hub

* Let's chat!

* Cleanup

* Temporary bs4 safeguard

* Cache agents and clean up

* Blank init

* Fix evaluation for agents

* New format for tools on the Hub

* Add method to reset state

* Remove nestedness in tool config

* Really do it

* Use remote tools descriptions

* Work

* Clean up eval

* Changes

* Tools

* Tools

* tool

* Fix everything

* Use last result/assign for evaluation

* Prompt

* Remove hardcoded selection

* Evaluation for chat agents

* correct some spelling

* Small fixes

* Change summarization model (#23172)

* Fix link displayed

* Update description of the tool

* Fixes in chat prompt

* Custom tools, custom prompt

* Tool clean up

* save_pretrained and push_to_hub for tool

* Fix init

* Tests

* Fix tests

* Tool save/from_hub/push_to_hub and tool->load_tool

* Clean push_to_hub and add app file

* Custom inference API for endpoints too

* Clean up

* old remote tool and new remote tool

* Make a requirements

* return_code adds tool creation

* Avoid redundancy between global variables

* Remote tools can be loaded

* Tests

* Text summarization tests

* Quality

* Properly mark tests

* Test the python interpreter

* And the CI shall be green.

* Work on RemoteTool and fix tests

* fix loading of additional tools

* General clean up

* Guard imports

* Fix tools

* Get default endpoint from the Hub

* Simplify tool config

* Add guide

* Docs

* Some fixes

* Docs

* Docs

* Fix code returned by agent

* Try this

* Docs

* Match args with signature in remote tool

* Should fix python interpreter for Python 3.8

* Fix push_to_hub for tools

* Other fixes to push_to_hub

* Add API doc page

* Fixes

* Doc fixes

* Docs

* Fix audio

* Custom tools

* Audio fix

* Improve custom tools docstring

* Docstrings

* Trigger CI

* Mode docstrings

* More docstrings

* Improve custom tools

* Fix for remote tools

* Style

* Fix repo consistency

* Quality

* Tip

* Cleanup on doc

* Cleanup toc

* Add disclaimer for starcoder vs openai

* Remove disclaimer

* Small fixed in the prompts

* 4.29

* Update src/transformers/tools/agents.py
Co-authored-by: default avatarLysandre Debut <lysandre.debut@reseau.eseo.fr>

* Complete documentation

* Small fixes

* Agent evaluation

* Note about gradio-tools & LC

* Clean up agents and prompt

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Note about gradio-tools & LC

* Add copyrights and address review comments

* Quality

* Add all language codes

* Add remote tool tests

* Move custom prompts to other docs

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* TTS tests

* Quality

---------
Co-authored-by: default avatarLysandre <hi@lyand.re>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPhilipp Schmid <32632186+philschmid@users.noreply.github.com>
Co-authored-by: default avatarConnor Henderson <connor.henderson@talkiatry.com>
Co-authored-by: default avatarLysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: default avatarLysandre <lysandre@huggingface.co>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 366a8ca0
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer
from .base import PipelineTool
class TextClassificationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TextClassificationTool
classifier = TextClassificationTool()
classifier("This is a super nice API!", labels=["positive", "negative"])
```
"""
default_checkpoint = "facebook/bart-large-mnli"
description = (
"This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which "
"should be the text to classify, and `labels`, which should be the list of labels to use for classification. "
"It returns the most likely label in the list of provided `labels` for the input text."
)
name = "text_classifier"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSequenceClassification
inputs = ["text", ["text"]]
outputs = ["text"]
def setup(self):
super().setup()
config = self.model.config
self.entailment_id = -1
for idx, label in config.id2label.items():
if label.lower().startswith("entail"):
self.entailment_id = int(idx)
if self.entailment_id == -1:
raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")
def encode(self, text, labels):
self._labels = labels
return self.pre_processor(
[text] * len(labels),
[f"This example is {label}" for label in labels],
return_tensors="pt",
padding="max_length",
)
def decode(self, outputs):
logits = outputs.logits
label_id = torch.argmax(logits[:, 2]).item()
return self._labels[label_id]
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.
Can you answer this question about the text: '{question}'"""
class TextQuestionAnsweringTool(PipelineTool):
default_checkpoint = "google/flan-t5-base"
description = (
"This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
"text where to find the answer, and `question`, which is the question, and returns the answer to the question."
)
name = "text_qa"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text", "text"]
outputs = ["text"]
def encode(self, text: str, question: str):
prompt = QA_PROMPT.format(text=text, question=question)
return self.pre_processor(prompt, return_tensors="pt")
def forward(self, inputs):
output_ids = self.model.generate(**inputs)
in_b, _ = inputs["input_ids"].shape
out_b = output_ids.shape[0]
return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
class TextSummarizationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TextSummarizationTool
summarizer = TextSummarizationTool()
summarizer(long_text)
```
"""
default_checkpoint = "philschmid/bart-large-cnn-samsum"
description = (
"This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, "
"and returns a summary of the text."
)
name = "summarizer"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text"]
outputs = ["text"]
def encode(self, text):
return self.pre_processor(text, return_tensors="pt", truncation=True)
def forward(self, inputs):
return self.model.generate(**inputs)[0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from ..utils import is_datasets_available
from .base import PipelineTool
if is_datasets_available():
from datasets import load_dataset
class TextToSpeechTool(PipelineTool):
default_checkpoint = "microsoft/speecht5_tts"
description = (
"This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
"text to read (in English) and returns a waveform object containing the sound."
)
name = "text_reader"
pre_processor_class = SpeechT5Processor
model_class = SpeechT5ForTextToSpeech
post_processor_class = SpeechT5HifiGan
inputs = ["text"]
outputs = ["audio"]
def setup(self):
if self.post_processor is None:
self.post_processor = "microsoft/speecht5_hifigan"
super().setup()
def encode(self, text, speaker_embeddings=None):
inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
if speaker_embeddings is None:
if not is_datasets_available():
raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
def forward(self, inputs):
with torch.no_grad():
return self.model.generate_speech(**inputs)
def decode(self, outputs):
with torch.no_grad():
return self.post_processor(outputs).cpu().detach()
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
LANGUAGE_CODES = {
"Acehnese Arabic": "ace_Arab",
"Acehnese Latin": "ace_Latn",
"Mesopotamian Arabic": "acm_Arab",
"Ta'izzi-Adeni Arabic": "acq_Arab",
"Tunisian Arabic": "aeb_Arab",
"Afrikaans": "afr_Latn",
"South Levantine Arabic": "ajp_Arab",
"Akan": "aka_Latn",
"Amharic": "amh_Ethi",
"North Levantine Arabic": "apc_Arab",
"Modern Standard Arabic": "arb_Arab",
"Modern Standard Arabic Romanized": "arb_Latn",
"Najdi Arabic": "ars_Arab",
"Moroccan Arabic": "ary_Arab",
"Egyptian Arabic": "arz_Arab",
"Assamese": "asm_Beng",
"Asturian": "ast_Latn",
"Awadhi": "awa_Deva",
"Central Aymara": "ayr_Latn",
"South Azerbaijani": "azb_Arab",
"North Azerbaijani": "azj_Latn",
"Bashkir": "bak_Cyrl",
"Bambara": "bam_Latn",
"Balinese": "ban_Latn",
"Belarusian": "bel_Cyrl",
"Bemba": "bem_Latn",
"Bengali": "ben_Beng",
"Bhojpuri": "bho_Deva",
"Banjar Arabic": "bjn_Arab",
"Banjar Latin": "bjn_Latn",
"Standard Tibetan": "bod_Tibt",
"Bosnian": "bos_Latn",
"Buginese": "bug_Latn",
"Bulgarian": "bul_Cyrl",
"Catalan": "cat_Latn",
"Cebuano": "ceb_Latn",
"Czech": "ces_Latn",
"Chokwe": "cjk_Latn",
"Central Kurdish": "ckb_Arab",
"Crimean Tatar": "crh_Latn",
"Welsh": "cym_Latn",
"Danish": "dan_Latn",
"German": "deu_Latn",
"Southwestern Dinka": "dik_Latn",
"Dyula": "dyu_Latn",
"Dzongkha": "dzo_Tibt",
"Greek": "ell_Grek",
"English": "eng_Latn",
"Esperanto": "epo_Latn",
"Estonian": "est_Latn",
"Basque": "eus_Latn",
"Ewe": "ewe_Latn",
"Faroese": "fao_Latn",
"Fijian": "fij_Latn",
"Finnish": "fin_Latn",
"Fon": "fon_Latn",
"French": "fra_Latn",
"Friulian": "fur_Latn",
"Nigerian Fulfulde": "fuv_Latn",
"Scottish Gaelic": "gla_Latn",
"Irish": "gle_Latn",
"Galician": "glg_Latn",
"Guarani": "grn_Latn",
"Gujarati": "guj_Gujr",
"Haitian Creole": "hat_Latn",
"Hausa": "hau_Latn",
"Hebrew": "heb_Hebr",
"Hindi": "hin_Deva",
"Chhattisgarhi": "hne_Deva",
"Croatian": "hrv_Latn",
"Hungarian": "hun_Latn",
"Armenian": "hye_Armn",
"Igbo": "ibo_Latn",
"Ilocano": "ilo_Latn",
"Indonesian": "ind_Latn",
"Icelandic": "isl_Latn",
"Italian": "ita_Latn",
"Javanese": "jav_Latn",
"Japanese": "jpn_Jpan",
"Kabyle": "kab_Latn",
"Jingpho": "kac_Latn",
"Kamba": "kam_Latn",
"Kannada": "kan_Knda",
"Kashmiri Arabic": "kas_Arab",
"Kashmiri Devanagari": "kas_Deva",
"Georgian": "kat_Geor",
"Central Kanuri Arabic": "knc_Arab",
"Central Kanuri Latin": "knc_Latn",
"Kazakh": "kaz_Cyrl",
"Kabiyè": "kbp_Latn",
"Kabuverdianu": "kea_Latn",
"Khmer": "khm_Khmr",
"Kikuyu": "kik_Latn",
"Kinyarwanda": "kin_Latn",
"Kyrgyz": "kir_Cyrl",
"Kimbundu": "kmb_Latn",
"Northern Kurdish": "kmr_Latn",
"Kikongo": "kon_Latn",
"Korean": "kor_Hang",
"Lao": "lao_Laoo",
"Ligurian": "lij_Latn",
"Limburgish": "lim_Latn",
"Lingala": "lin_Latn",
"Lithuanian": "lit_Latn",
"Lombard": "lmo_Latn",
"Latgalian": "ltg_Latn",
"Luxembourgish": "ltz_Latn",
"Luba-Kasai": "lua_Latn",
"Ganda": "lug_Latn",
"Luo": "luo_Latn",
"Mizo": "lus_Latn",
"Standard Latvian": "lvs_Latn",
"Magahi": "mag_Deva",
"Maithili": "mai_Deva",
"Malayalam": "mal_Mlym",
"Marathi": "mar_Deva",
"Minangkabau Arabic ": "min_Arab",
"Minangkabau Latin": "min_Latn",
"Macedonian": "mkd_Cyrl",
"Plateau Malagasy": "plt_Latn",
"Maltese": "mlt_Latn",
"Meitei Bengali": "mni_Beng",
"Halh Mongolian": "khk_Cyrl",
"Mossi": "mos_Latn",
"Maori": "mri_Latn",
"Burmese": "mya_Mymr",
"Dutch": "nld_Latn",
"Norwegian Nynorsk": "nno_Latn",
"Norwegian Bokmål": "nob_Latn",
"Nepali": "npi_Deva",
"Northern Sotho": "nso_Latn",
"Nuer": "nus_Latn",
"Nyanja": "nya_Latn",
"Occitan": "oci_Latn",
"West Central Oromo": "gaz_Latn",
"Odia": "ory_Orya",
"Pangasinan": "pag_Latn",
"Eastern Panjabi": "pan_Guru",
"Papiamento": "pap_Latn",
"Western Persian": "pes_Arab",
"Polish": "pol_Latn",
"Portuguese": "por_Latn",
"Dari": "prs_Arab",
"Southern Pashto": "pbt_Arab",
"Ayacucho Quechua": "quy_Latn",
"Romanian": "ron_Latn",
"Rundi": "run_Latn",
"Russian": "rus_Cyrl",
"Sango": "sag_Latn",
"Sanskrit": "san_Deva",
"Santali": "sat_Olck",
"Sicilian": "scn_Latn",
"Shan": "shn_Mymr",
"Sinhala": "sin_Sinh",
"Slovak": "slk_Latn",
"Slovenian": "slv_Latn",
"Samoan": "smo_Latn",
"Shona": "sna_Latn",
"Sindhi": "snd_Arab",
"Somali": "som_Latn",
"Southern Sotho": "sot_Latn",
"Spanish": "spa_Latn",
"Tosk Albanian": "als_Latn",
"Sardinian": "srd_Latn",
"Serbian": "srp_Cyrl",
"Swati": "ssw_Latn",
"Sundanese": "sun_Latn",
"Swedish": "swe_Latn",
"Swahili": "swh_Latn",
"Silesian": "szl_Latn",
"Tamil": "tam_Taml",
"Tatar": "tat_Cyrl",
"Telugu": "tel_Telu",
"Tajik": "tgk_Cyrl",
"Tagalog": "tgl_Latn",
"Thai": "tha_Thai",
"Tigrinya": "tir_Ethi",
"Tamasheq Latin": "taq_Latn",
"Tamasheq Tifinagh": "taq_Tfng",
"Tok Pisin": "tpi_Latn",
"Tswana": "tsn_Latn",
"Tsonga": "tso_Latn",
"Turkmen": "tuk_Latn",
"Tumbuka": "tum_Latn",
"Turkish": "tur_Latn",
"Twi": "twi_Latn",
"Central Atlas Tamazight": "tzm_Tfng",
"Uyghur": "uig_Arab",
"Ukrainian": "ukr_Cyrl",
"Umbundu": "umb_Latn",
"Urdu": "urd_Arab",
"Northern Uzbek": "uzn_Latn",
"Venetian": "vec_Latn",
"Vietnamese": "vie_Latn",
"Waray": "war_Latn",
"Wolof": "wol_Latn",
"Xhosa": "xho_Latn",
"Eastern Yiddish": "ydd_Hebr",
"Yoruba": "yor_Latn",
"Yue Chinese": "yue_Hant",
"Chinese Simplified": "zho_Hans",
"Chinese Traditional": "zho_Hant",
"Standard Malay": "zsm_Latn",
"Zulu": "zul_Latn",
}
class TranslationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TranslationTool
translator = TranslationTool()
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
```
"""
default_checkpoint = "facebook/nllb-200-distilled-600M"
description = (
"This is a tool that translates text from a language to another. It takes three inputs: `text`, which should "
"be the text to translate, `src_lang`, which should be the language of the text to translate and `tgt_lang`, "
"which should be the language for the desired ouput language. Both `src_lang` and `tgt_lang` are written in "
"plain English, such as 'Romanian', or 'Albanian'. It returns the text translated in `tgt_lang`."
)
name = "translator"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
lang_to_code = LANGUAGE_CODES
inputs = ["text", "text", "text"]
outputs = ["text"]
def encode(self, text, src_lang, tgt_lang):
if src_lang not in self.lang_to_code:
raise ValueError(f"{src_lang} is not a supported language.")
if tgt_lang not in self.lang_to_code:
raise ValueError(f"{tgt_lang} is not a supported language.")
src_lang = self.lang_to_code[src_lang]
tgt_lang = self.lang_to_code[tgt_lang]
return self.pre_processor._build_translation_inputs(
text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
)
def forward(self, inputs):
return self.model.generate(**inputs)
def decode(self, outputs):
return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
......@@ -121,6 +121,7 @@ from .import_utils import (
is_natten_available,
is_ninja_available,
is_onnx_available,
is_openai_available,
is_optimum_available,
is_pandas_available,
is_peft_available,
......
......@@ -235,6 +235,7 @@ def try_to_load_from_cache(
filename: str,
cache_dir: Union[str, Path, None] = None,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
) -> Optional[str]:
"""
Explores the cache to return the latest cached file for a given revision if found.
......@@ -251,6 +252,8 @@ def try_to_load_from_cache(
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
repo_type (`str`, *optional*):
The type of the repo.
Returns:
`Optional[str]` or `_CACHED_NO_EXIST`:
......@@ -266,7 +269,9 @@ def try_to_load_from_cache(
cache_dir = TRANSFORMERS_CACHE
object_id = repo_id.replace("/", "--")
repo_cache = os.path.join(cache_dir, f"models--{object_id}")
if repo_type is None:
repo_type = "model"
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
if not os.path.isdir(repo_cache):
# No cache for this model
return None
......@@ -303,6 +308,7 @@ def cached_file(
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
......@@ -342,6 +348,8 @@ def cached_file(
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a space for instance).
<Tip>
......@@ -393,7 +401,7 @@ def cached_file(
if _commit_hash is not None and not force_download:
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
)
if resolved_file is not None:
if resolved_file is not _CACHED_NO_EXIST:
......@@ -410,6 +418,7 @@ def cached_file(
path_or_repo_id,
filename,
subfolder=None if len(subfolder) == 0 else subfolder,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
......
......@@ -125,6 +125,14 @@ except importlib_metadata.PackageNotFoundError:
_datasets_available = False
_diffusers_available = importlib.util.find_spec("diffusers") is not None
try:
_diffusers_version = importlib_metadata.version("diffusers")
logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
except importlib_metadata.PackageNotFoundError:
_diffusers_available = False
_detectron2_available = importlib.util.find_spec("detectron2") is not None
try:
_detectron2_version = importlib_metadata.version("detectron2")
......@@ -185,6 +193,9 @@ except importlib_metadata.PackageNotFoundError:
_onnx_available = False
_opencv_available = importlib.util.find_spec("cv2") is not None
_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
try:
_pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")
......@@ -431,6 +442,10 @@ def is_onnx_available():
return _onnx_available
def is_openai_available():
return importlib.util.find_spec("openai") is not None
def is_flax_available():
return _flax_available
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from datasets import load_dataset
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("document-question-answering")
self.tool.setup()
self.remote_tool = load_tool("document-question-answering", remote=True)
def test_exact_match_arg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.tool(image, "When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_arg_remote(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.remote_tool(image, "When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_kwarg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.tool(image=image, question="When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_kwarg_remote(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.remote_tool(image=image, question="When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir
from .test_tools_common import ToolTesterMixin
if is_vision_available():
from PIL import Image
class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("image-captioning")
self.tool.setup()
self.remote_tool = load_tool("image-captioning", remote=True)
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image)
self.assertEqual(result, "two cats sleeping on a couch")
def test_exact_match_arg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.remote_tool(image)
self.assertEqual(result, "two cats sleeping on a couch")
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image=image)
self.assertEqual(result, "two cats sleeping on a couch")
def test_exact_match_kwarg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.remote_tool(image=image)
self.assertEqual(result, "two cats sleeping on a couch")
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir
from .test_tools_common import ToolTesterMixin
if is_vision_available():
from PIL import Image
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("image-question-answering")
self.tool.setup()
self.remote_tool = load_tool("image-question-answering", remote=True)
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image, "How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
def test_exact_match_arg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.remote_tool(image, "How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
def test_exact_match_kwarg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.remote_tool(image=image, question="How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir
from .test_tools_common import ToolTesterMixin
if is_vision_available():
from PIL import Image
class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("image-segmentation")
self.tool.setup()
self.remote_tool = load_tool("image-segmentation", remote=True)
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.tool(image, "cat")
self.assertTrue(isinstance(result, Image.Image))
def test_exact_match_arg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.remote_tool(image, "cat")
self.assertTrue(isinstance(result, Image.Image))
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.tool(image=image, prompt="cat")
self.assertTrue(isinstance(result, Image.Image))
def test_exact_match_kwarg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.remote_tool(image=image, prompt="cat")
self.assertTrue(isinstance(result, Image.Image))
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import CaptureStdout
from transformers.tools.python_interpreter import evaluate
# Fake function we will use as tool
def add_two(x):
return x + 2
class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_assign(self):
code = "x = 3"
state = {}
result = evaluate(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3})
code = "x = y"
state = {"y": 5}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 5, "y": 5})
def test_evaluate_call(self):
code = "y = add_two(x)"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5})
# Won't work without the tool
with CaptureStdout() as out:
result = evaluate(code, {}, state=state)
assert result is None
assert "tried to execute add_two" in out.out
def test_evaluate_constant(self):
code = "x = 3"
state = {}
result = evaluate(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3})
def test_evaluate_dict(self):
code = "test_dict = {'x': x, 'y': add_two(x)}"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
def test_evaluate_expression(self):
code = "x = 3\ny = 5"
state = {}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5})
def test_evaluate_f_string(self):
code = "text = f'This is x: {x}.'"
state = {"x": 3}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == "This is x: 3."
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3."})
def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
state = {"x": 3}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 2
self.assertDictEqual(state, {"x": 3, "y": 2})
state = {"x": 8}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 8, "y": 5})
def test_evaluate_list(self):
code = "test_list = [x, add_two(x)]"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
self.assertListEqual(result, [3, 5])
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
def test_evaluate_name(self):
code = "y = x"
state = {"x": 3}
result = evaluate(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3, "y": 3})
def test_evaluate_subscript(self):
code = "test_list = [x, add_two(x)]\ntest_list[1]"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available, load_tool
from .test_tools_common import ToolTesterMixin
if is_torch_available():
import torch
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("speech-to-text")
self.tool.setup()
def test_exact_match_arg(self):
result = self.tool(torch.ones(3000))
self.assertEqual(result, " you")
def test_exact_match_kwarg(self):
result = self.tool(audio=torch.ones(3000))
self.assertEqual(result, " you")
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class TextClassificationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("text-classification")
self.tool.setup()
self.remote_tool = load_tool("text-classification", remote=True)
def test_exact_match_arg(self):
result = self.tool("That's quite cool", ["positive", "negative"])
self.assertEqual(result, "positive")
def test_exact_match_arg_remote(self):
result = self.remote_tool("That's quite cool", ["positive", "negative"])
self.assertEqual(result, "positive")
def test_exact_match_kwarg(self):
result = self.tool(text="That's quite cool", labels=["positive", "negative"])
self.assertEqual(result, "positive")
def test_exact_match_kwarg_remote(self):
result = self.remote_tool(text="That's quite cool", labels=["positive", "negative"])
self.assertEqual(result, "positive")
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""
class TextQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("text-question-answering")
self.tool.setup()
self.remote_tool = load_tool("text-question-answering", remote=True)
def test_exact_match_arg(self):
result = self.tool(TEXT, "What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")
def test_exact_match_arg_remote(self):
result = self.remote_tool(TEXT, "What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")
def test_exact_match_kwarg(self):
result = self.tool(text=TEXT, question="What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")
def test_exact_match_kwarg_remote(self):
result = self.remote_tool(text=TEXT, question="What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""
class TextSummarizationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("summarization")
self.tool.setup()
self.remote_tool = load_tool("summarization", remote=True)
def test_exact_match_arg(self):
result = self.tool(TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)
def test_exact_match_arg_remote(self):
result = self.remote_tool(TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)
def test_exact_match_kwarg(self):
result = self.tool(text=TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)
def test_exact_match_kwarg_remote(self):
result = self.remote_tool(text=TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from transformers.utils import is_torch_available
if is_torch_available():
import torch
from transformers.testing_utils import require_torch
from .test_tools_common import ToolTesterMixin
@require_torch
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("text-to-speech")
self.tool.setup()
def test_exact_match_arg(self):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
self.assertTrue(
torch.allclose(
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
)
)
def test_exact_match_kwarg(self):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
self.assertTrue(
torch.allclose(
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
)
)
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import List
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir, is_tool_test
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
authorized_types = ["text", "image", "audio"]
def create_inputs(input_types: List[str]):
inputs = []
for input_type in input_types:
if input_type == "text":
inputs.append("Text input")
elif input_type == "image":
inputs.append(
Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
)
elif input_type == "audio":
inputs.append(torch.ones(3000))
elif isinstance(input_type, list):
inputs.append(create_inputs(input_type))
else:
raise ValueError(f"Invalid type requested: {input_type}")
return inputs
def output_types(outputs: List):
output_types = []
for output in outputs:
if isinstance(output, str):
output_types.append("text")
elif isinstance(output, Image.Image):
output_types.append("image")
elif isinstance(output, torch.Tensor):
output_types.append("audio")
else:
raise ValueError(f"Invalid output: {output}")
return output_types
@is_tool_test
class ToolTesterMixin:
def test_inputs_outputs(self):
self.assertTrue(hasattr(self.tool, "inputs"))
self.assertTrue(hasattr(self.tool, "outputs"))
inputs = self.tool.inputs
for _input in inputs:
if isinstance(_input, list):
for __input in _input:
self.assertTrue(__input in authorized_types)
else:
self.assertTrue(_input in authorized_types)
outputs = self.tool.outputs
for _output in outputs:
self.assertTrue(_output in authorized_types)
def test_call(self):
inputs = create_inputs(self.tool.inputs)
outputs = self.tool(*inputs)
# There is a single output
if len(self.tool.outputs) == 1:
outputs = [outputs]
self.assertListEqual(output_types(outputs), self.tool.outputs)
def test_common_attributes(self):
self.assertTrue(hasattr(self.tool, "description"))
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
self.assertTrue(self.tool.description.startswith("This is a tool that"))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment