Unverified Commit a73883ae authored by Radamés Ajna's avatar Radamés Ajna Committed by GitHub
Browse files

add trust_remote_code option to CLI download cmd (#24097)

* add trust_remote_code option

* require_torch
parent 8b169142
...@@ -18,7 +18,7 @@ from . import BaseTransformersCLICommand ...@@ -18,7 +18,7 @@ from . import BaseTransformersCLICommand
def download_command_factory(args):
    """Instantiate a :class:`DownloadCommand` from the parsed argparse namespace.

    Registered via ``download_parser.set_defaults(func=download_command_factory)``
    so the CLI dispatcher can build the command object from user arguments.
    """
    return DownloadCommand(
        args.model,
        args.cache_dir,
        args.force,
        args.trust_remote_code,
    )
class DownloadCommand(BaseTransformersCLICommand): class DownloadCommand(BaseTransformersCLICommand):
...@@ -31,16 +31,26 @@ class DownloadCommand(BaseTransformersCLICommand): ...@@ -31,16 +31,26 @@ class DownloadCommand(BaseTransformersCLICommand):
download_parser.add_argument( download_parser.add_argument(
"--force", action="store_true", help="Force the model to be download even if already in cache-dir" "--force", action="store_true", help="Force the model to be download even if already in cache-dir"
) )
download_parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files. Use only if you've reviewed the code as it will execute on your local machine",
)
download_parser.add_argument("model", type=str, help="Name of the model to download") download_parser.add_argument("model", type=str, help="Name of the model to download")
download_parser.set_defaults(func=download_command_factory) download_parser.set_defaults(func=download_command_factory)
def __init__(self, model: str, cache: str, force: bool): def __init__(self, model: str, cache: str, force: bool, trust_remote_code: bool):
self._model = model self._model = model
self._cache = cache self._cache = cache
self._force = force self._force = force
self._trust_remote_code = trust_remote_code
def run(self): def run(self):
from ..models.auto import AutoModel, AutoTokenizer from ..models.auto import AutoModel, AutoTokenizer
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) AutoModel.from_pretrained(
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
)
AutoTokenizer.from_pretrained(
self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
)
...@@ -18,7 +18,7 @@ import shutil ...@@ -18,7 +18,7 @@ import shutil
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test, require_torch
class CLITest(unittest.TestCase): class CLITest(unittest.TestCase):
...@@ -45,3 +45,47 @@ class CLITest(unittest.TestCase): ...@@ -45,3 +45,47 @@ class CLITest(unittest.TestCase):
# The original repo has no TF weights -- if they exist, they were created by the CLI # The original repo has no TF weights -- if they exist, they were created by the CLI
self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5")) self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
@require_torch
@patch("sys.argv", ["fakeprogrampath", "download", "hf-internal-testing/tiny-random-gptj", "--cache-dir", "/tmp"])
def test_cli_download(self):
import transformers.commands.transformers_cli
# # remove any previously downloaded model to start clean
shutil.rmtree("/tmp/models--hf-internal-testing--tiny-random-gptj", ignore_errors=True)
# run the command
transformers.commands.transformers_cli.main()
# check if the model files are downloaded correctly on /tmp/models--hf-internal-testing--tiny-random-gptj
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/blobs"))
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/refs"))
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/snapshots"))
@require_torch
@patch(
"sys.argv",
[
"fakeprogrampath",
"download",
"hf-internal-testing/test_dynamic_model_with_tokenizer",
"--trust-remote-code",
"--cache-dir",
"/tmp",
],
)
def test_cli_download_trust_remote(self):
import transformers.commands.transformers_cli
# # remove any previously downloaded model to start clean
shutil.rmtree("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer", ignore_errors=True)
# run the command
transformers.commands.transformers_cli.main()
# check if the model files are downloaded correctly on /tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/blobs"))
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/refs"))
self.assertTrue(
os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/snapshots")
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment