Unverified commit 651408a0, authored by Arthur, committed by GitHub
Browse files

[`Styling`] stylify using ruff (#27144)



* try to stylify using ruff

* might need to remove these changes?

* use ruff format and ruff check

* use isinstance instead of type comparison

* use # fmt: skip

* use # fmt: skip

* nits

* some styling changes

* update ci job

* nits isinstance

* more files update

* nits

* more nits

* small nits

* check and format

* revert wrong changes

* actually use formatter instead of checker

* nits

* well docbuilder is overwriting this commit

* revert notebook changes

* try to nuke docbuilder

* style

* fix feature extraction test

* remove `indent-width = 4`

* fixup

* more nits

* update the ruff version that we use

* style

* nuke docbuilder styling

* leave the print for detected changes

* nits

* Remove file I/O
Co-authored-by: charliermarsh <charlie.r.marsh@gmail.com>

* style

* nits

* revert notebook changes

* Add # fmt skip when possible

* Add # fmt skip when possible

...
parent acb5b4af
[tool.black]
line-length = 119
target-version = ['py37']
[tool.ruff]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741"]
ignore = ["C901", "E501", "E741", "F402", "F823" ]
select = ["C", "E", "F", "I", "W"]
line-length = 119
......@@ -18,6 +14,19 @@ line-length = 119
lines-after-imports = 2
known-first-party = ["transformers"]
[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
[tool.pytest.ini_options]
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
doctest_glob="**/*.md"
......
from collections import Counter
import datasets
import transformers
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from transformers.utils import logging
logging.set_verbosity_info()
TOKENIZER_CLASSES = {
......@@ -101,8 +103,8 @@ def check_details(line, spm_ids, tok_ids, slow, fast):
except Exception:
pass
ok_start = fast.decode(spm_ids[:first])
ok_end = fast.decode(spm_ids[last:])
fast.decode(spm_ids[:first])
fast.decode(spm_ids[last:])
wrong = fast.decode(spm_ids[first:last])
print()
print(wrong)
......
......@@ -24,18 +24,19 @@
#
# It will be used then as "stas/tiny-wmt19-en-ru"
from pathlib import Path
import json
import tempfile
from pathlib import Path
from transformers import FSMTTokenizer, FSMTConfig, FSMTForConditionalGeneration
from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTTokenizer
from transformers.models.fsmt.tokenization_fsmt import VOCAB_FILES_NAMES
mname_tiny = "tiny-wmt19-en-ru"
# Build
# borrowed from a test
# borrowed from a test
vocab = [ "l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "w</w>", "r</w>", "t</w>", "lo", "low", "er</w>", "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>", ]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
......@@ -57,7 +58,7 @@ with tempfile.TemporaryDirectory() as tmpdirname:
tgt_vocab_file=tgt_vocab_file,
merges_file=merges_file,
)
config = FSMTConfig(
langs=['ru', 'en'],
src_vocab_size=1000, tgt_vocab_size=1000,
......
......@@ -27,16 +27,18 @@
# It will be used then as "stas/tiny-wmt19-en-de"
# Build
from transformers import FSMTTokenizer, FSMTConfig, FSMTForConditionalGeneration
from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTTokenizer
mname = "facebook/wmt19-en-de"
tokenizer = FSMTTokenizer.from_pretrained(mname)
# get the correct vocab sizes, etc. from the master model
config = FSMTConfig.from_pretrained(mname)
config.update(dict(
d_model=4,
encoder_layers=1, decoder_layers=1,
encoder_ffn_dim=4, decoder_ffn_dim=4,
encoder_attention_heads=1, decoder_attention_heads=1))
config.update({
"d_model": 4,
"encoder_layers": 1, "decoder_layers": 1,
"encoder_ffn_dim": 4, "decoder_ffn_dim": 4,
"encoder_attention_heads": 1, "decoder_attention_heads": 1})
tiny_model = FSMTForConditionalGeneration(config)
print(f"num of params {tiny_model.num_parameters()}")
......
......@@ -19,6 +19,7 @@
import os
from pathlib import Path
def write_model_card(model_card_dir, src_lang, tgt_lang, model_name):
texts = {
......
......@@ -19,6 +19,7 @@
import os
from pathlib import Path
def write_model_card(model_card_dir, src_lang, tgt_lang, model_name):
texts = {
......
......@@ -19,6 +19,7 @@
import os
from pathlib import Path
def write_model_card(model_card_dir, src_lang, tgt_lang):
texts = {
......@@ -39,7 +40,7 @@ def write_model_card(model_card_dir, src_lang, tgt_lang):
readme = f"""
---
language:
language:
- {src_lang}
- {tgt_lang}
thumbnail:
......
......@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# this script builds a small sample spm file tests/fixtures/test_sentencepiece_no_bos.model, with features needed by pegasus
# this script builds a small sample spm file tests/fixtures/test_sentencepiece_no_bos.model, with features needed by pegasus
# 1. pip install sentencepiece
#
#
# 2. wget https://raw.githubusercontent.com/google/sentencepiece/master/data/botchan.txt
# 3. build
import sentencepiece as spm
# pegasus:
# 1. no bos
# 2. eos_id is 1
......
......@@ -15,8 +15,8 @@
Script to close stale issue. Taken in part from the AllenNLP repository.
https://github.com/allenai/allennlp.
"""
from datetime import datetime as dt
import os
from datetime import datetime as dt
import github.GithubException
from github import Github
......@@ -39,7 +39,7 @@ def main():
for i, issue in enumerate(open_issues):
print(i, issue)
comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True)
comments = sorted(list(issue.get_comments()), key=lambda i: i.created_at, reverse=True)
last_comment = comments[0] if len(comments) > 0 else None
if (
last_comment is not None and last_comment.user.login == "github-actions[bot]"
......
......@@ -99,7 +99,6 @@ _deps = [
"accelerate>=0.20.3",
"av==9.2.0", # Latest version of PyAV (10.0.0) has issues with audio stream.
"beautifulsoup4",
"black~=23.1",
"codecarbon==1.2.0",
"cookiecutter==1.7.3",
"dataclasses",
......@@ -156,7 +155,7 @@ _deps = [
"rhoknp>=1.1.0,<1.3.1",
"rjieba",
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff>=0.0.241,<=0.0.259",
"ruff>=0.1.5,<=0.2",
"sacrebleu>=1.4.12,<2.0.0",
"sacremoses",
"safetensors>=0.3.1",
......@@ -310,7 +309,7 @@ extras["testing"] = (
"dill",
"evaluate",
"pytest-timeout",
"black",
"ruff",
"sacrebleu",
"rouge-score",
"nltk",
......@@ -329,7 +328,7 @@ extras["testing"] = (
extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"] + extras["sentencepiece"]
extras["quality"] = deps_list("black", "datasets", "isort", "ruff", "GitPython", "hf-doc-builder", "urllib3")
extras["quality"] = deps_list("datasets", "isort", "ruff", "GitPython", "hf-doc-builder", "urllib3")
extras["all"] = (
extras["tf"]
......
......@@ -246,6 +246,7 @@ class PretrainedConfig(PushToHubMixin):
not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
v5.
"""
model_type: str = ""
is_composition: bool = False
attribute_map: Dict[str, str] = {}
......
......@@ -724,9 +724,7 @@ class MBart50Converter(SpmConverter):
("<unk>", 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
# fmt: off
vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)]
# fmt: on
vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)] # fmt: skip
vocab += [("<mask>", 0.0)]
return vocab
......@@ -753,11 +751,7 @@ class NllbConverter(SpmConverter):
("<unk>", 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
vocab += [
# fmt: off
('ace_Arab', 0.0), ('ace_Latn', 0.0), ('acm_Arab', 0.0), ('acq_Arab', 0.0), ('aeb_Arab', 0.0), ('afr_Latn', 0.0), ('ajp_Arab', 0.0), ('aka_Latn', 0.0), ('amh_Ethi', 0.0), ('apc_Arab', 0.0), ('arb_Arab', 0.0), ('ars_Arab', 0.0), ('ary_Arab', 0.0), ('arz_Arab', 0.0), ('asm_Beng', 0.0), ('ast_Latn', 0.0), ('awa_Deva', 0.0), ('ayr_Latn', 0.0), ('azb_Arab', 0.0), ('azj_Latn', 0.0), ('bak_Cyrl', 0.0), ('bam_Latn', 0.0), ('ban_Latn', 0.0), ('bel_Cyrl', 0.0), ('bem_Latn', 0.0), ('ben_Beng', 0.0), ('bho_Deva', 0.0), ('bjn_Arab', 0.0), ('bjn_Latn', 0.0), ('bod_Tibt', 0.0), ('bos_Latn', 0.0), ('bug_Latn', 0.0), ('bul_Cyrl', 0.0), ('cat_Latn', 0.0), ('ceb_Latn', 0.0), ('ces_Latn', 0.0), ('cjk_Latn', 0.0), ('ckb_Arab', 0.0), ('crh_Latn', 0.0), ('cym_Latn', 0.0), ('dan_Latn', 0.0), ('deu_Latn', 0.0), ('dik_Latn', 0.0), ('dyu_Latn', 0.0), ('dzo_Tibt', 0.0), ('ell_Grek', 0.0), ('eng_Latn', 0.0), ('epo_Latn', 0.0), ('est_Latn', 0.0), ('eus_Latn', 0.0), ('ewe_Latn', 0.0), ('fao_Latn', 0.0), ('pes_Arab', 0.0), ('fij_Latn', 0.0), ('fin_Latn', 0.0), ('fon_Latn', 0.0), ('fra_Latn', 0.0), ('fur_Latn', 0.0), ('fuv_Latn', 0.0), ('gla_Latn', 0.0), ('gle_Latn', 0.0), ('glg_Latn', 0.0), ('grn_Latn', 0.0), ('guj_Gujr', 0.0), ('hat_Latn', 0.0), ('hau_Latn', 0.0), ('heb_Hebr', 0.0), ('hin_Deva', 0.0), ('hne_Deva', 0.0), ('hrv_Latn', 0.0), ('hun_Latn', 0.0), ('hye_Armn', 0.0), ('ibo_Latn', 0.0), ('ilo_Latn', 0.0), ('ind_Latn', 0.0), ('isl_Latn', 0.0), ('ita_Latn', 0.0), ('jav_Latn', 0.0), ('jpn_Jpan', 0.0), ('kab_Latn', 0.0), ('kac_Latn', 0.0), ('kam_Latn', 0.0), ('kan_Knda', 0.0), ('kas_Arab', 0.0), ('kas_Deva', 0.0), ('kat_Geor', 0.0), ('knc_Arab', 0.0), ('knc_Latn', 0.0), ('kaz_Cyrl', 0.0), ('kbp_Latn', 0.0), ('kea_Latn', 0.0), ('khm_Khmr', 0.0), ('kik_Latn', 0.0), ('kin_Latn', 0.0), ('kir_Cyrl', 0.0), ('kmb_Latn', 0.0), ('kon_Latn', 0.0), ('kor_Hang', 0.0), ('kmr_Latn', 0.0), ('lao_Laoo', 0.0), ('lvs_Latn', 0.0), ('lij_Latn', 0.0), ('lim_Latn', 0.0), ('lin_Latn', 0.0), ('lit_Latn', 0.0), 
('lmo_Latn', 0.0), ('ltg_Latn', 0.0), ('ltz_Latn', 0.0), ('lua_Latn', 0.0), ('lug_Latn', 0.0), ('luo_Latn', 0.0), ('lus_Latn', 0.0), ('mag_Deva', 0.0), ('mai_Deva', 0.0), ('mal_Mlym', 0.0), ('mar_Deva', 0.0), ('min_Latn', 0.0), ('mkd_Cyrl', 0.0), ('plt_Latn', 0.0), ('mlt_Latn', 0.0), ('mni_Beng', 0.0), ('khk_Cyrl', 0.0), ('mos_Latn', 0.0), ('mri_Latn', 0.0), ('zsm_Latn', 0.0), ('mya_Mymr', 0.0), ('nld_Latn', 0.0), ('nno_Latn', 0.0), ('nob_Latn', 0.0), ('npi_Deva', 0.0), ('nso_Latn', 0.0), ('nus_Latn', 0.0), ('nya_Latn', 0.0), ('oci_Latn', 0.0), ('gaz_Latn', 0.0), ('ory_Orya', 0.0), ('pag_Latn', 0.0), ('pan_Guru', 0.0), ('pap_Latn', 0.0), ('pol_Latn', 0.0), ('por_Latn', 0.0), ('prs_Arab', 0.0), ('pbt_Arab', 0.0), ('quy_Latn', 0.0), ('ron_Latn', 0.0), ('run_Latn', 0.0), ('rus_Cyrl', 0.0), ('sag_Latn', 0.0), ('san_Deva', 0.0), ('sat_Beng', 0.0), ('scn_Latn', 0.0), ('shn_Mymr', 0.0), ('sin_Sinh', 0.0), ('slk_Latn', 0.0), ('slv_Latn', 0.0), ('smo_Latn', 0.0), ('sna_Latn', 0.0), ('snd_Arab', 0.0), ('som_Latn', 0.0), ('sot_Latn', 0.0), ('spa_Latn', 0.0), ('als_Latn', 0.0), ('srd_Latn', 0.0), ('srp_Cyrl', 0.0), ('ssw_Latn', 0.0), ('sun_Latn', 0.0), ('swe_Latn', 0.0), ('swh_Latn', 0.0), ('szl_Latn', 0.0), ('tam_Taml', 0.0), ('tat_Cyrl', 0.0), ('tel_Telu', 0.0), ('tgk_Cyrl', 0.0), ('tgl_Latn', 0.0), ('tha_Thai', 0.0), ('tir_Ethi', 0.0), ('taq_Latn', 0.0), ('taq_Tfng', 0.0), ('tpi_Latn', 0.0), ('tsn_Latn', 0.0), ('tso_Latn', 0.0), ('tuk_Latn', 0.0), ('tum_Latn', 0.0), ('tur_Latn', 0.0), ('twi_Latn', 0.0), ('tzm_Tfng', 0.0), ('uig_Arab', 0.0), ('ukr_Cyrl', 0.0), ('umb_Latn', 0.0), ('urd_Arab', 0.0), ('uzn_Latn', 0.0), ('vec_Latn', 0.0), ('vie_Latn', 0.0), ('war_Latn', 0.0), ('wol_Latn', 0.0), ('xho_Latn', 0.0), ('ydd_Hebr', 0.0), ('yor_Latn', 0.0), ('yue_Hant', 0.0), ('zho_Hans', 0.0), ('zho_Hant', 0.0), ('zul_Latn', 0.0)
# fmt: on
]
vocab += [('ace_Arab', 0.0), ('ace_Latn', 0.0), ('acm_Arab', 0.0), ('acq_Arab', 0.0), ('aeb_Arab', 0.0), ('afr_Latn', 0.0), ('ajp_Arab', 0.0), ('aka_Latn', 0.0), ('amh_Ethi', 0.0), ('apc_Arab', 0.0), ('arb_Arab', 0.0), ('ars_Arab', 0.0), ('ary_Arab', 0.0), ('arz_Arab', 0.0), ('asm_Beng', 0.0), ('ast_Latn', 0.0), ('awa_Deva', 0.0), ('ayr_Latn', 0.0), ('azb_Arab', 0.0), ('azj_Latn', 0.0), ('bak_Cyrl', 0.0), ('bam_Latn', 0.0), ('ban_Latn', 0.0), ('bel_Cyrl', 0.0), ('bem_Latn', 0.0), ('ben_Beng', 0.0), ('bho_Deva', 0.0), ('bjn_Arab', 0.0), ('bjn_Latn', 0.0), ('bod_Tibt', 0.0), ('bos_Latn', 0.0), ('bug_Latn', 0.0), ('bul_Cyrl', 0.0), ('cat_Latn', 0.0), ('ceb_Latn', 0.0), ('ces_Latn', 0.0), ('cjk_Latn', 0.0), ('ckb_Arab', 0.0), ('crh_Latn', 0.0), ('cym_Latn', 0.0), ('dan_Latn', 0.0), ('deu_Latn', 0.0), ('dik_Latn', 0.0), ('dyu_Latn', 0.0), ('dzo_Tibt', 0.0), ('ell_Grek', 0.0), ('eng_Latn', 0.0), ('epo_Latn', 0.0), ('est_Latn', 0.0), ('eus_Latn', 0.0), ('ewe_Latn', 0.0), ('fao_Latn', 0.0), ('pes_Arab', 0.0), ('fij_Latn', 0.0), ('fin_Latn', 0.0), ('fon_Latn', 0.0), ('fra_Latn', 0.0), ('fur_Latn', 0.0), ('fuv_Latn', 0.0), ('gla_Latn', 0.0), ('gle_Latn', 0.0), ('glg_Latn', 0.0), ('grn_Latn', 0.0), ('guj_Gujr', 0.0), ('hat_Latn', 0.0), ('hau_Latn', 0.0), ('heb_Hebr', 0.0), ('hin_Deva', 0.0), ('hne_Deva', 0.0), ('hrv_Latn', 0.0), ('hun_Latn', 0.0), ('hye_Armn', 0.0), ('ibo_Latn', 0.0), ('ilo_Latn', 0.0), ('ind_Latn', 0.0), ('isl_Latn', 0.0), ('ita_Latn', 0.0), ('jav_Latn', 0.0), ('jpn_Jpan', 0.0), ('kab_Latn', 0.0), ('kac_Latn', 0.0), ('kam_Latn', 0.0), ('kan_Knda', 0.0), ('kas_Arab', 0.0), ('kas_Deva', 0.0), ('kat_Geor', 0.0), ('knc_Arab', 0.0), ('knc_Latn', 0.0), ('kaz_Cyrl', 0.0), ('kbp_Latn', 0.0), ('kea_Latn', 0.0), ('khm_Khmr', 0.0), ('kik_Latn', 0.0), ('kin_Latn', 0.0), ('kir_Cyrl', 0.0), ('kmb_Latn', 0.0), ('kon_Latn', 0.0), ('kor_Hang', 0.0), ('kmr_Latn', 0.0), ('lao_Laoo', 0.0), ('lvs_Latn', 0.0), ('lij_Latn', 0.0), ('lim_Latn', 0.0), ('lin_Latn', 0.0), ('lit_Latn', 
0.0), ('lmo_Latn', 0.0), ('ltg_Latn', 0.0), ('ltz_Latn', 0.0), ('lua_Latn', 0.0), ('lug_Latn', 0.0), ('luo_Latn', 0.0), ('lus_Latn', 0.0), ('mag_Deva', 0.0), ('mai_Deva', 0.0), ('mal_Mlym', 0.0), ('mar_Deva', 0.0), ('min_Latn', 0.0), ('mkd_Cyrl', 0.0), ('plt_Latn', 0.0), ('mlt_Latn', 0.0), ('mni_Beng', 0.0), ('khk_Cyrl', 0.0), ('mos_Latn', 0.0), ('mri_Latn', 0.0), ('zsm_Latn', 0.0), ('mya_Mymr', 0.0), ('nld_Latn', 0.0), ('nno_Latn', 0.0), ('nob_Latn', 0.0), ('npi_Deva', 0.0), ('nso_Latn', 0.0), ('nus_Latn', 0.0), ('nya_Latn', 0.0), ('oci_Latn', 0.0), ('gaz_Latn', 0.0), ('ory_Orya', 0.0), ('pag_Latn', 0.0), ('pan_Guru', 0.0), ('pap_Latn', 0.0), ('pol_Latn', 0.0), ('por_Latn', 0.0), ('prs_Arab', 0.0), ('pbt_Arab', 0.0), ('quy_Latn', 0.0), ('ron_Latn', 0.0), ('run_Latn', 0.0), ('rus_Cyrl', 0.0), ('sag_Latn', 0.0), ('san_Deva', 0.0), ('sat_Beng', 0.0), ('scn_Latn', 0.0), ('shn_Mymr', 0.0), ('sin_Sinh', 0.0), ('slk_Latn', 0.0), ('slv_Latn', 0.0), ('smo_Latn', 0.0), ('sna_Latn', 0.0), ('snd_Arab', 0.0), ('som_Latn', 0.0), ('sot_Latn', 0.0), ('spa_Latn', 0.0), ('als_Latn', 0.0), ('srd_Latn', 0.0), ('srp_Cyrl', 0.0), ('ssw_Latn', 0.0), ('sun_Latn', 0.0), ('swe_Latn', 0.0), ('swh_Latn', 0.0), ('szl_Latn', 0.0), ('tam_Taml', 0.0), ('tat_Cyrl', 0.0), ('tel_Telu', 0.0), ('tgk_Cyrl', 0.0), ('tgl_Latn', 0.0), ('tha_Thai', 0.0), ('tir_Ethi', 0.0), ('taq_Latn', 0.0), ('taq_Tfng', 0.0), ('tpi_Latn', 0.0), ('tsn_Latn', 0.0), ('tso_Latn', 0.0), ('tuk_Latn', 0.0), ('tum_Latn', 0.0), ('tur_Latn', 0.0), ('twi_Latn', 0.0), ('tzm_Tfng', 0.0), ('uig_Arab', 0.0), ('ukr_Cyrl', 0.0), ('umb_Latn', 0.0), ('urd_Arab', 0.0), ('uzn_Latn', 0.0), ('vec_Latn', 0.0), ('vie_Latn', 0.0), ('war_Latn', 0.0), ('wol_Latn', 0.0), ('xho_Latn', 0.0), ('ydd_Hebr', 0.0), ('yor_Latn', 0.0), ('yue_Hant', 0.0), ('zho_Hans', 0.0), ('zho_Hant', 0.0), ('zul_Latn', 0.0)] # fmt: skip
vocab += [("<mask>", 0.0)]
return vocab
......@@ -1128,9 +1122,7 @@ class XGLMConverter(SpmConverter):
("<unk>", 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
# fmt: off
vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)]
# fmt: on
vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)] # fmt: skip
return vocab
def unk_id(self, proto):
......
......@@ -121,7 +121,7 @@ def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any
if isinstance(first["label_ids"], torch.Tensor):
batch["labels"] = torch.stack([f["label_ids"] for f in features])
else:
dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
# Handling of all other possible keys.
......@@ -196,7 +196,7 @@ def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any
if isinstance(first["label_ids"], np.ndarray):
batch["labels"] = np.stack([f["label_ids"] for f in features])
else:
dtype = np.int64 if type(first["label_ids"][0]) is int else np.float32
dtype = np.int64 if isinstance(first["label_ids"][0], int) else np.float32
batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)
# Handling of all other possible keys.
......
......@@ -6,7 +6,6 @@ deps = {
"accelerate": "accelerate>=0.20.3",
"av": "av==9.2.0",
"beautifulsoup4": "beautifulsoup4",
"black": "black~=23.1",
"codecarbon": "codecarbon==1.2.0",
"cookiecutter": "cookiecutter==1.7.3",
"dataclasses": "dataclasses",
......@@ -62,7 +61,7 @@ deps = {
"rhoknp": "rhoknp>=1.1.0,<1.3.1",
"rjieba": "rjieba",
"rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff": "ruff>=0.0.241,<=0.0.259",
"ruff": "ruff>=0.1.5,<=0.2",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses",
"safetensors": "safetensors>=0.3.1",
......
......@@ -245,8 +245,7 @@ def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]
and isinstance(annotation["annotations"], (list, tuple))
and (
# an image can have no annotations
len(annotation["annotations"]) == 0
or isinstance(annotation["annotations"][0], dict)
len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
)
):
return True
......@@ -262,8 +261,7 @@ def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]])
and isinstance(annotation["segments_info"], (list, tuple))
and (
# an image can have no segments
len(annotation["segments_info"]) == 0
or isinstance(annotation["segments_info"][0], dict)
len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
)
):
return True
......
......@@ -179,6 +179,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
"""
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
......
......@@ -1075,6 +1075,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
"""
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
......@@ -3242,6 +3243,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the `__init__` of `tf.keras.layers.Layer`.
"""
# TODO (joao): flagged for deletion due to embeddings refactor
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
......
......@@ -1095,6 +1095,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
"""
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
......
......@@ -97,6 +97,7 @@ class AlignTextConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "align_text_model"
def __init__(
......
......@@ -100,6 +100,7 @@ class AltCLIPTextConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "altclip_text_model"
def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment