Unverified commit 1d63b0ec, authored by Yih-Dar and committed via GitHub.

Disallow `pickle.load` unless `TRUST_REMOTE_CODE=True` (#27776)



* fix

* fix

* Use TRUST_REMOTE_CODE

* fix doc

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent e0d2e695
We recommend switching to more recent models for improved security.

In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub.
You will need to set the environment variable `TRUST_REMOTE_CODE` to `True` in order to allow the
usage of `pickle.load()`:

```python
import os
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel

os.environ["TRUST_REMOTE_CODE"] = "True"
checkpoint = 'transfo-xl-wt103'
revision = '40a186da79458c9f9de846edfaea79c412137f97'
```
......
...@@ -34,6 +34,7 @@ from ....utils import ( ...@@ -34,6 +34,7 @@ from ....utils import (
is_torch_available, is_torch_available,
logging, logging,
requires_backends, requires_backends,
strtobool,
torch_only_method, torch_only_method,
) )
...@@ -212,6 +213,14 @@ class TransfoXLTokenizer(PreTrainedTokenizer): ...@@ -212,6 +213,14 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
vocab_dict = None vocab_dict = None
if pretrained_vocab_file is not None: if pretrained_vocab_file is not None:
# Priority on pickle files (support PyTorch and TF) # Priority on pickle files (support PyTorch and TF)
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is "
"potentially malicious. It's recommended to never unpickle data that could have come from an "
"untrusted source, or that could have been tampered with. If you already verified the pickle "
"data and decided to use it, you can set the environment variable "
"`TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(pretrained_vocab_file, "rb") as f: with open(pretrained_vocab_file, "rb") as f:
vocab_dict = pickle.load(f) vocab_dict = pickle.load(f)
...@@ -790,6 +799,13 @@ def get_lm_corpus(datadir, dataset): ...@@ -790,6 +799,13 @@ def get_lm_corpus(datadir, dataset):
corpus = torch.load(fn_pickle) corpus = torch.load(fn_pickle)
elif os.path.exists(fn): elif os.path.exists(fn):
logger.info("Loading cached dataset from pickle...") logger.info("Loading cached dataset from pickle...")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(fn, "rb") as fp: with open(fn, "rb") as fp:
corpus = pickle.load(fp) corpus = pickle.load(fp)
else: else:
......
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import BatchEncoding
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
from .configuration_rag import RagConfig from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer from .tokenization_rag import RagTokenizer
...@@ -131,6 +131,13 @@ class LegacyIndex(Index): ...@@ -131,6 +131,13 @@ class LegacyIndex(Index):
def _load_passages(self): def _load_passages(self):
logger.info(f"Loading passages from {self.index_path}") logger.info(f"Loading passages from {self.index_path}")
passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME) passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(passages_path, "rb") as passages_file: with open(passages_path, "rb") as passages_file:
passages = pickle.load(passages_file) passages = pickle.load(passages_file)
return passages return passages
...@@ -140,6 +147,13 @@ class LegacyIndex(Index): ...@@ -140,6 +147,13 @@ class LegacyIndex(Index):
resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr") resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
self.index = faiss.read_index(resolved_index_path) self.index = faiss.read_index(resolved_index_path)
resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr") resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(resolved_meta_path, "rb") as metadata_file: with open(resolved_meta_path, "rb") as metadata_file:
self.index_id_to_db_id = pickle.load(metadata_file) self.index_id_to_db_id = pickle.load(metadata_file)
assert ( assert (
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import json import json
import os import os
import pickle
import shutil import shutil
import tempfile import tempfile
from unittest import TestCase from unittest import TestCase
...@@ -174,37 +173,6 @@ class RagRetrieverTest(TestCase): ...@@ -174,37 +173,6 @@ class RagRetrieverTest(TestCase):
) )
return retriever return retriever
def get_dummy_legacy_index_retriever(self):
    """Build a ``RagRetriever`` backed by the legacy on-disk index format.

    Creates a tiny two-document dataset, writes the faiss index, the pickled
    index metadata and the pickled passages into ``self.tmpdirname`` using the
    legacy file layout, then returns a retriever configured with
    ``index_name="legacy"`` pointing at that directory.

    Returns:
        RagRetriever: a retriever usable by the legacy-index tests.
    """
    dataset = Dataset.from_dict(
        {
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            # NOTE(review): legacy embeddings carry one extra dimension vs.
            # retrieval_vector_size — presumably the legacy index format; confirm.
            "embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
        }
    )
    dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)

    index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
    dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
    # Use context managers so the pickle files are closed deterministically.
    # The original `pickle.dump(obj, open(path, "wb"))` left the handle to be
    # closed by the garbage collector, which can leak / delay the write.
    with open(index_file_name + ".index_meta.dpr", "wb") as meta_file:
        pickle.dump(dataset["id"], meta_file)

    passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
    passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
    with open(passages_file_name, "wb") as passages_file:
        pickle.dump(passages, passages_file)

    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="legacy",
        index_path=self.tmpdirname,
    )
    retriever = RagRetriever(
        config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
    )
    return retriever
def test_canonical_hf_index_retriever_retrieve(self): def test_canonical_hf_index_retriever_retrieve(self):
n_docs = 1 n_docs = 1
retriever = self.get_dummy_canonical_hf_index_retriever() retriever = self.get_dummy_canonical_hf_index_retriever()
...@@ -288,33 +256,6 @@ class RagRetrieverTest(TestCase): ...@@ -288,33 +256,6 @@ class RagRetrieverTest(TestCase):
out = retriever.retrieve(hidden_states, n_docs=1) out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None) self.assertTrue(out is not None)
def test_legacy_index_retriever_retrieve(self):
    """Retrieving from the legacy index returns correctly shaped, correctly ranked docs."""
    n_docs = 1
    retriever = self.get_dummy_legacy_index_retriever()

    # Two opposite query vectors: all ones and all minus-ones.
    query_vectors = np.array(
        [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
    )
    doc_embeds, doc_ids, doc_dicts = retriever.retrieve(query_vectors, n_docs=n_docs)

    # Shapes and structure of the returned documents.
    self.assertEqual(doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
    self.assertEqual(len(doc_dicts), 2)
    self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
    self.assertEqual(len(doc_dicts[0]["text"]), n_docs)

    # Ranking: inner product is maximized by the second doc for the all-ones
    # query, and by the first doc for the all-minus-ones query.
    self.assertEqual(doc_dicts[0]["text"][0], "bar")
    self.assertEqual(doc_dicts[1]["text"][0], "foo")
    self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
    """A legacy-index retriever round-trips through save_pretrained/from_pretrained."""
    retriever = self.get_dummy_legacy_index_retriever()
    with tempfile.TemporaryDirectory() as tmp_dirname:
        # Persist, then reload the retriever from disk.
        retriever.save_pretrained(tmp_dirname)
        retriever = RagRetriever.from_pretrained(tmp_dirname)
        self.assertIsInstance(retriever, RagRetriever)

        # The reloaded retriever must still be able to serve queries.
        query_vectors = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )
        result = retriever.retrieve(query_vectors, n_docs=1)
        self.assertTrue(result is not None)
@require_torch @require_torch
@require_tokenizers @require_tokenizers
@require_sentencepiece @require_sentencepiece
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.