Unverified Commit 1d63b0ec authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Disallow `pickle.load` unless `TRUST_REMOTE_CODE=True` (#27776)



* fix

* fix

* Use TRUST_REMOTE_CODE

* fix doc

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent e0d2e695
......@@ -22,11 +22,17 @@ This model is in maintenance mode only, so we won't accept any new PRs changing
We recommend switching to more recent models for improved security.
In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub:
In case you would still like to use `TransfoXL` in your experiments, we recommend using the [Hub checkpoint](https://huggingface.co/transfo-xl-wt103) with a specific revision to ensure you are downloading safe files from the Hub.
```
You will need to set the environment variable `TRUST_REMOTE_CODE` to `True` in order to allow the
usage of `pickle.load()`:
```python
import os
from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel
os.environ["TRUST_REMOTE_CODE"] = "True"
checkpoint = 'transfo-xl-wt103'
revision = '40a186da79458c9f9de846edfaea79c412137f97'
......
......@@ -34,6 +34,7 @@ from ....utils import (
is_torch_available,
logging,
requires_backends,
strtobool,
torch_only_method,
)
......@@ -212,6 +213,14 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
vocab_dict = None
if pretrained_vocab_file is not None:
# Priority on pickle files (support PyTorch and TF)
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is "
"potentially malicious. It's recommended to never unpickle data that could have come from an "
"untrusted source, or that could have been tampered with. If you already verified the pickle "
"data and decided to use it, you can set the environment variable "
"`TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(pretrained_vocab_file, "rb") as f:
vocab_dict = pickle.load(f)
......@@ -790,6 +799,13 @@ def get_lm_corpus(datadir, dataset):
corpus = torch.load(fn_pickle)
elif os.path.exists(fn):
logger.info("Loading cached dataset from pickle...")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(fn, "rb") as fp:
corpus = pickle.load(fp)
else:
......
......@@ -23,7 +23,7 @@ import numpy as np
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer
......@@ -131,6 +131,13 @@ class LegacyIndex(Index):
def _load_passages(self):
logger.info(f"Loading passages from {self.index_path}")
passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(passages_path, "rb") as passages_file:
passages = pickle.load(passages_file)
return passages
......@@ -140,6 +147,13 @@ class LegacyIndex(Index):
resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
self.index = faiss.read_index(resolved_index_path)
resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
raise ValueError(
"This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
"malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
"that could have been tampered with. If you already verified the pickle data and decided to use it, "
"you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
)
with open(resolved_meta_path, "rb") as metadata_file:
self.index_id_to_db_id = pickle.load(metadata_file)
assert (
......
......@@ -14,7 +14,6 @@
import json
import os
import pickle
import shutil
import tempfile
from unittest import TestCase
......@@ -174,37 +173,6 @@ class RagRetrieverTest(TestCase):
)
return retriever
def get_dummy_legacy_index_retriever(self):
dataset = Dataset.from_dict(
{
"id": ["0", "1"],
"text": ["foo", "bar"],
"title": ["Foo", "Bar"],
"embeddings": [np.ones(self.retrieval_vector_size + 1), 2 * np.ones(self.retrieval_vector_size + 1)],
}
)
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
pickle.dump(dataset["id"], open(index_file_name + ".index_meta.dpr", "wb"))
passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
pickle.dump(passages, open(passages_file_name, "wb"))
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="legacy",
index_path=self.tmpdirname,
)
retriever = RagRetriever(
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
)
return retriever
def test_canonical_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_canonical_hf_index_retriever()
......@@ -288,33 +256,6 @@ class RagRetrieverTest(TestCase):
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
def test_legacy_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_legacy_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
self.assertEqual(doc_dicts[0]["text"][0], "bar") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
retriever = self.get_dummy_legacy_index_retriever()
with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
@require_torch
@require_tokenizers
@require_sentencepiece
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment