Unverified Commit 9a498c37 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Rely on huggingface_hub for common tools (#13100)

* Remove hf_api module and use huggingface_hub

* Style

* Fix to test_fetcher

* Quality
parent 6900dded
...@@ -103,8 +103,8 @@ Here is the code to see all available pretrained models on the hub: ...@@ -103,8 +103,8 @@ Here is the code to see all available pretrained models on the hub:
.. code-block:: python .. code-block:: python
from transformers.hf_api import HfApi from huggingface_hub.hf_api import HfApi
model_list = HfApi().model_list() model_list = HfApi().list_models()
org = "Helsinki-NLP" org = "Helsinki-NLP"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
suffix = [x.split('/')[1] for x in model_ids] suffix = [x.split('/')[1] for x in model_ids]
......
...@@ -14,10 +14,10 @@ import lightning_base ...@@ -14,10 +14,10 @@ import lightning_base
from convert_pl_checkpoint_to_hf import convert_pl_to_hf from convert_pl_checkpoint_to_hf import convert_pl_to_hf
from distillation import distill_main from distillation import distill_main
from finetune import SummarizationModule, main from finetune import SummarizationModule, main
from huggingface_hub.hf_api import HfApi
from parameterized import parameterized from parameterized import parameterized
from run_eval import generate_summaries_or_translations from run_eval import generate_summaries_or_translations
from transformers import AutoConfig, AutoModelForSeq2SeqLM from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
from utils import label_smoothed_nll_loss, lmap, load_json from utils import label_smoothed_nll_loss, lmap, load_json
...@@ -130,7 +130,7 @@ class TestSummarizationDistiller(TestCasePlus): ...@@ -130,7 +130,7 @@ class TestSummarizationDistiller(TestCasePlus):
def test_hub_configs(self): def test_hub_configs(self):
"""I put require_torch_gpu cause I only want this to run with self-scheduled.""" """I put require_torch_gpu cause I only want this to run with self-scheduled."""
model_list = HfApi().model_list() model_list = HfApi().list_models()
org = "sshleifer" org = "sshleifer"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"] allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
......
...@@ -15,14 +15,13 @@ ...@@ -15,14 +15,13 @@
import os import os
import subprocess import subprocess
import sys import sys
import warnings
from argparse import ArgumentParser from argparse import ArgumentParser
from getpass import getpass from getpass import getpass
from typing import List, Union from typing import List, Union
from huggingface_hub.hf_api import HfApi, HfFolder
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from ..hf_api import HfApi, HfFolder
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
...@@ -148,6 +147,12 @@ class BaseUserCommand: ...@@ -148,6 +147,12 @@ class BaseUserCommand:
class LoginCommand(BaseUserCommand): class LoginCommand(BaseUserCommand):
def run(self): def run(self):
print(
ANSI.red(
"WARNING! `transformers-cli login` is deprecated and will be removed in v5. Please use "
"`huggingface-cli login` instead."
)
)
print( # docstyle-ignore print( # docstyle-ignore
""" """
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_| _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
...@@ -175,6 +180,12 @@ class LoginCommand(BaseUserCommand): ...@@ -175,6 +180,12 @@ class LoginCommand(BaseUserCommand):
class WhoamiCommand(BaseUserCommand): class WhoamiCommand(BaseUserCommand):
def run(self): def run(self):
print(
ANSI.red(
"WARNING! `transformers-cli whoami` is deprecated and will be removed in v5. Please use "
"`huggingface-cli whoami` instead."
)
)
token = HfFolder.get_token() token = HfFolder.get_token()
if token is None: if token is None:
print("Not logged in") print("Not logged in")
...@@ -192,6 +203,12 @@ class WhoamiCommand(BaseUserCommand): ...@@ -192,6 +203,12 @@ class WhoamiCommand(BaseUserCommand):
class LogoutCommand(BaseUserCommand): class LogoutCommand(BaseUserCommand):
def run(self): def run(self):
print(
ANSI.red(
"WARNING! `transformers-cli logout` is deprecated and will be removed in v5. Please use "
"`huggingface-cli logout` instead."
)
)
token = HfFolder.get_token() token = HfFolder.get_token()
if token is None: if token is None:
print("Not logged in") print("Not logged in")
...@@ -203,8 +220,11 @@ class LogoutCommand(BaseUserCommand): ...@@ -203,8 +220,11 @@ class LogoutCommand(BaseUserCommand):
class ListObjsCommand(BaseUserCommand): class ListObjsCommand(BaseUserCommand):
def run(self): def run(self):
warnings.warn( print(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead." ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
) )
token = HfFolder.get_token() token = HfFolder.get_token()
if token is None: if token is None:
...@@ -225,8 +245,11 @@ class ListObjsCommand(BaseUserCommand): ...@@ -225,8 +245,11 @@ class ListObjsCommand(BaseUserCommand):
class DeleteObjCommand(BaseUserCommand): class DeleteObjCommand(BaseUserCommand):
def run(self): def run(self):
warnings.warn( print(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead." ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
) )
token = HfFolder.get_token() token = HfFolder.get_token()
if token is None: if token is None:
...@@ -243,8 +266,11 @@ class DeleteObjCommand(BaseUserCommand): ...@@ -243,8 +266,11 @@ class DeleteObjCommand(BaseUserCommand):
class ListReposObjsCommand(BaseUserCommand): class ListReposObjsCommand(BaseUserCommand):
def run(self): def run(self):
warnings.warn( print(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead." ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
) )
token = HfFolder.get_token() token = HfFolder.get_token()
if token is None: if token is None:
...@@ -265,8 +291,11 @@ class ListReposObjsCommand(BaseUserCommand): ...@@ -265,8 +291,11 @@ class ListReposObjsCommand(BaseUserCommand):
class RepoCreateCommand(BaseUserCommand): class RepoCreateCommand(BaseUserCommand):
def run(self): def run(self):
warnings.warn( print(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead." ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
) )
token = HfFolder.get_token() token = HfFolder.get_token()
if token is None: if token is None:
...@@ -339,8 +368,11 @@ class UploadCommand(BaseUserCommand): ...@@ -339,8 +368,11 @@ class UploadCommand(BaseUserCommand):
return files return files
def run(self): def run(self):
warnings.warn( print(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead." ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
) )
token = HfFolder.get_token() token = HfFolder.get_token()
if token is None: if token is None:
......
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from os.path import expanduser
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm
import requests
# Base URL of the production huggingface.co API; overridable per HfApi instance.
ENDPOINT = "https://huggingface.co"
class RepoObj:
    """
    HuggingFace git-based system, data structure that represents a file belonging to the current user.
    """

    def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
        # Any extra server-side fields arrive via **kwargs and are deliberately dropped.
        self.filename, self.lastModified = filename, lastModified
        self.commit, self.size = commit, size
class ModelSibling:
    """
    Data structure that represents a public file inside a model, accessible from huggingface.co
    """

    def __init__(self, rfilename: str, **kwargs):
        # Path of the file relative to the model repo root.
        self.rfilename = rfilename
        # Keep any additional server-provided fields as plain attributes.
        for extra_key, extra_value in kwargs.items():
            setattr(self, extra_key, extra_value)
class ModelInfo:
    """
    Info about a public model accessible from huggingface.co
    """

    def __init__(
        self,
        modelId: Optional[str] = None,  # id of model
        tags: Optional[List[str]] = None,
        pipeline_tag: Optional[str] = None,
        siblings: Optional[List[Dict]] = None,  # list of files that constitute the model
        **kwargs
    ):
        self.modelId = modelId
        # Fix: `tags` previously defaulted to a shared mutable list (`[]`), so
        # mutating one instance's tags leaked into every other instance created
        # with the default. Use None as sentinel and build a fresh list each time.
        self.tags = tags if tags is not None else []
        self.pipeline_tag = pipeline_tag
        self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
        # Preserve any extra fields returned by the API as plain attributes.
        for k, v in kwargs.items():
            setattr(self, k, v)
class HfApi:
    """Minimal client for the huggingface.co HTTP API: authentication, model
    listing and git-repo management."""

    def __init__(self, endpoint=None):
        # Default to the public production endpoint; tests pass a staging URL.
        self.endpoint = ENDPOINT if endpoint is None else endpoint

    def login(self, username: str, password: str) -> str:
        """
        Call HF API to sign in a user and get a token if credentials are valid.

        Outputs: token if credentials are valid

        Throws: requests.exceptions.HTTPError if credentials are invalid
        """
        response = requests.post(
            f"{self.endpoint}/api/login",
            json={"username": username, "password": password},
        )
        response.raise_for_status()
        return response.json()["token"]

    def whoami(self, token: str) -> Tuple[str, List[str]]:
        """
        Call HF API to know "whoami": returns (username, organizations).
        """
        response = requests.get(
            f"{self.endpoint}/api/whoami",
            headers={"authorization": f"Bearer {token}"},
        )
        response.raise_for_status()
        payload = response.json()
        return payload["user"], payload["orgs"]

    def logout(self, token: str) -> None:
        """
        Call HF API to log out, invalidating the token server-side.
        """
        response = requests.post(
            f"{self.endpoint}/api/logout",
            headers={"authorization": f"Bearer {token}"},
        )
        response.raise_for_status()

    def model_list(self) -> List[ModelInfo]:
        """
        Get the public list of all the models on huggingface.co
        """
        response = requests.get(f"{self.endpoint}/api/models")
        response.raise_for_status()
        return [ModelInfo(**entry) for entry in response.json()]

    def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:
        """
        HuggingFace git-based system, used for models.

        Call HF API to list all stored files for user (or one of their organizations).
        """
        query = None if organization is None else {"organization": organization}
        response = requests.get(
            f"{self.endpoint}/api/repos/ls",
            params=query,
            headers={"authorization": f"Bearer {token}"},
        )
        response.raise_for_status()
        return [RepoObj(**entry) for entry in response.json()]

    def create_repo(
        self,
        token: str,
        name: str,
        organization: Optional[str] = None,
        private: Optional[bool] = None,
        exist_ok=False,
        lfsmultipartthresh: Optional[int] = None,
    ) -> str:
        """
        HuggingFace git-based system, used for models.

        Call HF API to create a whole repo.

        Params:
            private: Whether the model repo should be private (requires a paid huggingface.co account)

            exist_ok: Do not raise an error if repo already exists

            lfsmultipartthresh: Optional: internal param for testing purposes.
        """
        payload = {"name": name, "organization": organization, "private": private}
        if lfsmultipartthresh is not None:
            payload["lfsmultipartthresh"] = lfsmultipartthresh
        response = requests.post(
            f"{self.endpoint}/api/repos/create",
            headers={"authorization": f"Bearer {token}"},
            json=payload,
        )
        # 409 Conflict means the repo already exists; tolerated when exist_ok.
        if exist_ok and response.status_code == 409:
            return ""
        response.raise_for_status()
        return response.json()["url"]

    def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
        """
        HuggingFace git-based system, used for models.

        Call HF API to delete a whole repo.

        CAUTION(this is irreversible).
        """
        response = requests.delete(
            f"{self.endpoint}/api/repos/delete",
            headers={"authorization": f"Bearer {token}"},
            json={"name": name, "organization": organization},
        )
        response.raise_for_status()
class TqdmProgressFileReader:
    """
    Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a
    tqdm progress bar.

    see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details.
    """

    def __init__(self, f: io.BufferedReader):
        self.f = f
        # Total file size (via fstat on the open descriptor) is the bar's upper bound.
        self.total_size = os.fstat(f.fileno()).st_size
        self.pbar = tqdm(total=self.total_size, leave=False)
        # Save the original bound read() BEFORE monkey-patching, so _read can delegate to it.
        self.read = f.read
        f.read = self._read

    def _read(self, n=-1):
        # NOTE(review): advances the bar by the *requested* size n, not the number
        # of bytes actually returned — the two can differ near EOF. Confirm intended.
        self.pbar.update(n)
        return self.read(n)

    def close(self):
        # Closes only the progress bar; the caller still owns (and must close) the file.
        self.pbar.close()
class HfFolder:
    """Single-file store for the user's auth token under the home directory."""

    path_token = expanduser("~/.huggingface/token")

    @classmethod
    def save_token(cls, token):
        """
        Save token, creating folder as needed.
        """
        parent_dir = os.path.dirname(cls.path_token)
        os.makedirs(parent_dir, exist_ok=True)
        with open(cls.path_token, "w+") as token_file:
            token_file.write(token)

    @classmethod
    def get_token(cls):
        """
        Get token or None if not existent.
        """
        try:
            with open(cls.path_token, "r") as token_file:
                return token_file.read()
        except FileNotFoundError:
            return None

    @classmethod
    def delete_token(cls):
        """
        Delete token. Do not fail if token does not exist.
        """
        try:
            os.remove(cls.path_token)
        except FileNotFoundError:
            pass
...@@ -27,8 +27,8 @@ import torch ...@@ -27,8 +27,8 @@ import torch
from torch import nn from torch import nn
from tqdm import tqdm from tqdm import tqdm
from huggingface_hub.hf_api import HfApi
from transformers import MarianConfig, MarianMTModel, MarianTokenizer from transformers import MarianConfig, MarianMTModel, MarianTokenizer
from transformers.hf_api import HfApi
def remove_suffix(text: str, suffix: str): def remove_suffix(text: str, suffix: str):
...@@ -65,7 +65,7 @@ def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]: ...@@ -65,7 +65,7 @@ def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
"""Find models that can accept src_lang as input and return tgt_lang as output.""" """Find models that can accept src_lang as input and return tgt_lang as output."""
prefix = "Helsinki-NLP/opus-mt-" prefix = "Helsinki-NLP/opus-mt-"
api = HfApi() api = HfApi()
model_list = api.model_list() model_list = api.list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")] model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
src_and_targ = [ src_and_targ = [
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
......
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import subprocess
import time
import unittest
from requests.exceptions import HTTPError
from transformers.hf_api import HfApi, HfFolder, ModelInfo, RepoObj
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test, require_git_lfs
# Staging endpoint with HTTP basic-auth credentials embedded in the URL, so a
# plain `git clone` can authenticate without prompting (staging-only creds).
ENDPOINT_STAGING_BASIC_AUTH = f"https://{USER}:{PASS}@moon-staging.huggingface.co"

# Timestamped repo names so concurrent CI runs do not collide with each other.
REPO_NAME = f"my-model-{int(time.time())}"
REPO_NAME_LARGE_FILE = f"my-model-largefiles-{int(time.time())}"

# Local scratch directory for the git clones made by the LFS tests.
WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/working_repo")

# Publicly hosted fixtures used to exercise the LFS multipart thresholds.
LARGE_FILE_14MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.epub"
LARGE_FILE_18MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.pdf"
class HfApiCommonTest(unittest.TestCase):
    # Shared API client pointed at the staging endpoint; subclasses reuse it.
    _api = HfApi(endpoint=ENDPOINT_STAGING)
class HfApiLoginTest(HfApiCommonTest):
    def test_login_invalid(self):
        """Wrong credentials must surface as an HTTPError."""
        with self.assertRaises(HTTPError):
            self._api.login(username=USER, password="fake")

    def test_login_valid(self):
        """Valid credentials return a string token."""
        auth_token = self._api.login(username=USER, password=PASS)
        self.assertIsInstance(auth_token, str)
class HfApiEndpointsTest(HfApiCommonTest):
    @classmethod
    def setUpClass(cls):
        """
        Share this valid token in all tests below.
        """
        cls._token = cls._api.login(username=USER, password=PASS)

    def test_whoami(self):
        """The staging user is reported along with a list of its orgs."""
        username, organizations = self._api.whoami(token=self._token)
        self.assertEqual(username, USER)
        self.assertIsInstance(organizations, list)

    def test_list_repos_objs(self):
        """Listing repo objects returns a (possibly empty) list of RepoObj."""
        repo_objs = self._api.list_repos_objs(token=self._token)
        self.assertIsInstance(repo_objs, list)
        if repo_objs:
            self.assertIsInstance(repo_objs[-1], RepoObj)

    def test_create_and_delete_repo(self):
        """A repo can be created and then removed again."""
        self._api.create_repo(token=self._token, name=REPO_NAME)
        self._api.delete_repo(token=self._token, name=REPO_NAME)
class HfApiPublicTest(unittest.TestCase):
    def test_staging_model_list(self):
        """Smoke test: the staging endpoint serves a model list at all."""
        staging_api = HfApi(endpoint=ENDPOINT_STAGING)
        staging_api.model_list()

    def test_model_list(self):
        """The production hub lists a substantial number of ModelInfo entries."""
        public_api = HfApi()
        models = public_api.model_list()
        self.assertGreater(len(models), 100)
        self.assertIsInstance(models[0], ModelInfo)
class HfFolderTest(unittest.TestCase):
    def test_token_workflow(self):
        """
        Test the whole token save/get/delete workflow,
        with the desired behavior with respect to non-existent tokens.
        """
        fresh_token = f"token-{int(time.time())}"
        HfFolder.save_token(fresh_token)
        self.assertEqual(HfFolder.get_token(), fresh_token)
        HfFolder.delete_token()
        # Deleting twice must not raise: the second call is a no-op
        # on a missing token file.
        HfFolder.delete_token()
        self.assertEqual(HfFolder.get_token(), None)
@require_git_lfs
@is_staging_test
class HfLargefilesTest(HfApiCommonTest):
    """End-to-end tests of git-LFS pushes through the custom transfer agent
    installed by `transformers-cli lfs-enable-largefiles`."""

    @classmethod
    def setUpClass(cls):
        """
        Share this valid token in all tests below.
        """
        cls._token = cls._api.login(username=USER, password=PASS)

    def setUp(self):
        # Start each test from a clean working directory; missing dir is fine.
        try:
            shutil.rmtree(WORKING_REPO_DIR)
        except FileNotFoundError:
            pass

    def tearDown(self):
        # Remove the remote repo the test created on staging.
        self._api.delete_repo(token=self._token, name=REPO_NAME_LARGE_FILE)

    def setup_local_clone(self, REMOTE_URL):
        # Clone via HTTPS with basic-auth baked into the URL, then track the
        # fixture extensions with git-lfs inside the fresh clone.
        REMOTE_URL_AUTH = REMOTE_URL.replace(ENDPOINT_STAGING, ENDPOINT_STAGING_BASIC_AUTH)
        subprocess.run(["git", "clone", REMOTE_URL_AUTH, WORKING_REPO_DIR], check=True, capture_output=True)
        subprocess.run(["git", "lfs", "track", "*.pdf"], check=True, cwd=WORKING_REPO_DIR)
        subprocess.run(["git", "lfs", "track", "*.epub"], check=True, cwd=WORKING_REPO_DIR)

    def test_end_to_end_thresh_6M(self):
        # 6MB multipart threshold: the 18MB fixture must be uploaded multipart.
        REMOTE_URL = self._api.create_repo(
            token=self._token, name=REPO_NAME_LARGE_FILE, lfsmultipartthresh=6 * 10 ** 6
        )
        self.setup_local_clone(REMOTE_URL)
        subprocess.run(["wget", LARGE_FILE_18MB], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
        subprocess.run(["git", "add", "*"], check=True, cwd=WORKING_REPO_DIR)
        subprocess.run(["git", "commit", "-m", "commit message"], check=True, cwd=WORKING_REPO_DIR)

        # This will fail as we haven't set up our custom transfer agent yet.
        failed_process = subprocess.run(["git", "push"], capture_output=True, cwd=WORKING_REPO_DIR)
        self.assertEqual(failed_process.returncode, 1)
        self.assertIn("transformers-cli lfs-enable-largefiles", failed_process.stderr.decode())
        # ^ Instructions on how to fix this are included in the error message.

        subprocess.run(["transformers-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True)

        start_time = time.time()
        subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR)
        print("took", time.time() - start_time)

        # To be 100% sure, let's download the resolved file
        pdf_url = f"{REMOTE_URL}/resolve/main/progit.pdf"
        DEST_FILENAME = "uploaded.pdf"
        subprocess.run(["wget", pdf_url, "-O", DEST_FILENAME], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
        dest_filesize = os.stat(os.path.join(WORKING_REPO_DIR, DEST_FILENAME)).st_size
        # Known byte size of the progit.pdf fixture: a mismatch means corruption.
        self.assertEqual(dest_filesize, 18685041)

    def test_end_to_end_thresh_16M(self):
        # Here we'll push one multipart and one non-multipart file in the same commit, and see what happens
        REMOTE_URL = self._api.create_repo(
            token=self._token, name=REPO_NAME_LARGE_FILE, lfsmultipartthresh=16 * 10 ** 6
        )
        self.setup_local_clone(REMOTE_URL)
        subprocess.run(["wget", LARGE_FILE_18MB], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
        subprocess.run(["wget", LARGE_FILE_14MB], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
        subprocess.run(["git", "add", "*"], check=True, cwd=WORKING_REPO_DIR)
        subprocess.run(["git", "commit", "-m", "both files in same commit"], check=True, cwd=WORKING_REPO_DIR)
        subprocess.run(["transformers-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True)

        start_time = time.time()
        subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR)
        print("took", time.time() - start_time)
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
import tempfile import tempfile
import unittest import unittest
from huggingface_hub.hf_api import HfApi
from transformers import MarianConfig, is_torch_available from transformers import MarianConfig, is_torch_available
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.hf_api import HfApi
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -296,7 +296,7 @@ class ModelManagementTests(unittest.TestCase): ...@@ -296,7 +296,7 @@ class ModelManagementTests(unittest.TestCase):
@slow @slow
@require_torch @require_torch
def test_model_names(self): def test_model_names(self):
model_list = HfApi().model_list() model_list = HfApi().list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)] model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
bad_model_ids = [mid for mid in model_ids if "+" in model_ids] bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
self.assertListEqual([], bad_model_ids) self.assertListEqual([], bad_model_ids)
......
...@@ -412,6 +412,8 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None): ...@@ -412,6 +412,8 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
# Remove duplicates # Remove duplicates
test_files_to_run = sorted(list(set(test_files_to_run))) test_files_to_run = sorted(list(set(test_files_to_run)))
# Make sure we did not end up with a test file that was removed
test_files_to_run = [f for f in test_files_to_run if os.path.isfile(f) or os.path.isdir(f)]
if filters is not None: if filters is not None:
for filter in filters: for filter in filters:
test_files_to_run = [f for f in test_files_to_run if f.startswith(filter)] test_files_to_run = [f for f in test_files_to_run if f.startswith(filter)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment