Unverified commit 3385ca25, authored by Sylvain Gugger and committed by GitHub
Browse files

Change REALM checkpoint to new ones (#15439)

* Change REALM checkpoint to new ones

* Last checkpoint missing
parent 7e56ba28
......@@ -21,14 +21,14 @@ from ...utils import logging
logger = logging.get_logger(__name__)
# Maps each REALM checkpoint identifier to the URL of its `config.json`.
# Every URL has the form https://huggingface.co/<checkpoint>/resolve/main/config.json.
REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/config.json",
    "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/config.json",
    "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/config.json",
    "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/resolve/main/config.json",
    "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/config.json",
    "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/config.json",
    "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/config.json",
    "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/config.json",
    "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json",
    "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json",
    "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json",
    "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/resolve/main/config.json",
    "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/config.json",
    "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/config.json",
    "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/config.json",
    "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/config.json",
    # See all REALM models at https://huggingface.co/models?filter=realm
}
......@@ -46,7 +46,7 @@ class RealmConfig(PretrainedConfig):
It is used to instantiate an REALM model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM
[realm-cc-news-pretrained](https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder) architecture.
[realm-cc-news-pretrained](https://huggingface.co/google/realm-cc-news-pretrained-embedder) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
......@@ -112,7 +112,7 @@ class RealmConfig(PretrainedConfig):
>>> # Initializing a REALM realm-cc-news-pretrained-* style configuration
>>> configuration = RealmConfig()
>>> # Initializing a model from the qqaatw/realm-cc-news-pretrained-embedder style configuration
>>> # Initializing a model from the google/realm-cc-news-pretrained-embedder style configuration
>>> model = RealmEmbedder(configuration)
>>> # Accessing the model configuration
......
......@@ -43,21 +43,21 @@ from .configuration_realm import RealmConfig
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
# Checkpoint identifiers and class names kept as module constants.
# NOTE(review): the `_FOR_DOC` suffix suggests they feed docstring examples — confirm against their usages.
# Each checkpoint constant is assigned twice below (a "qqaatw/..." value, then a "google/..." value);
# the later assignment is the one that takes effect.
_EMBEDDER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-embedder"
_ENCODER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-encoder"
_SCORER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-scorer"
_EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder"
_ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder"
_SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer"
_CONFIG_FOR_DOC = "RealmConfig"
_TOKENIZER_FOR_DOC = "RealmTokenizer"
# List of REALM checkpoint identifiers, in "<namespace>/<model-name>" form.
# It contains the same eight model names under both the "qqaatw" and "google" namespaces.
REALM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"qqaatw/realm-cc-news-pretrained-embedder",
"qqaatw/realm-cc-news-pretrained-encoder",
"qqaatw/realm-cc-news-pretrained-scorer",
"qqaatw/realm-cc-news-pretrained-openqa",
"qqaatw/realm-orqa-nq-openqa",
"qqaatw/realm-orqa-nq-reader",
"qqaatw/realm-orqa-wq-openqa",
"qqaatw/realm-orqa-wq-reader",
"google/realm-cc-news-pretrained-embedder",
"google/realm-cc-news-pretrained-encoder",
"google/realm-cc-news-pretrained-scorer",
"google/realm-cc-news-pretrained-openqa",
"google/realm-orqa-nq-openqa",
"google/realm-orqa-nq-reader",
"google/realm-orqa-wq-openqa",
"google/realm-orqa-wq-reader",
# See all REALM models at https://huggingface.co/models?filter=realm
]
......@@ -1180,8 +1180,8 @@ class RealmEmbedder(RealmPreTrainedModel):
>>> from transformers import RealmTokenizer, RealmEmbedder
>>> import torch
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
>>> model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
>>> model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
......@@ -1293,8 +1293,8 @@ class RealmScorer(RealmPreTrainedModel):
>>> import torch
>>> from transformers import RealmTokenizer, RealmScorer
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer")
>>> model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=2)
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-scorer")
>>> model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2)
>>> # batch_size = 2, num_candidates = 2
>>> input_texts = ["How are you?", "What is the item in the picture?"]
......@@ -1433,9 +1433,9 @@ class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
>>> import torch
>>> from transformers import RealmTokenizer, RealmKnowledgeAugEncoder
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
>>> model = RealmKnowledgeAugEncoder.from_pretrained(
... "qqaatw/realm-cc-news-pretrained-encoder", num_candidates=2
... "google/realm-cc-news-pretrained-encoder", num_candidates=2
... )
>>> # batch_size = 2, num_candidates = 2
......@@ -1761,9 +1761,9 @@ class RealmForOpenQA(RealmPreTrainedModel):
>>> import torch
>>> from transformers import RealmForOpenQA, RealmRetriever, RealmTokenizer
>>> retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
>>> model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa", retriever=retriever)
>>> retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
>>> model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)
>>> question = "Who is the pioneer in modern computer science?"
>>> question_ids = tokenizer([question], return_tensors="pt")
......
......@@ -31,37 +31,37 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
# Maps each vocabulary file kind to a {checkpoint identifier: download URL} table.
# Every URL has the form https://huggingface.co/<checkpoint>/resolve/main/vocab.txt.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
        "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
        "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
        "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/resolve/main/vocab.txt",
        "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt",
        "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt",
        "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt",
        "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt",
        "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
        "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
        "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
        "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
        "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
    }
}
# Maps each REALM checkpoint identifier to 512.
# NOTE(review): presumably the number of position embeddings (maximum input length) per checkpoint —
# confirm against how the tokenizer base class consumes this table.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"qqaatw/realm-cc-news-pretrained-embedder": 512,
"qqaatw/realm-cc-news-pretrained-encoder": 512,
"qqaatw/realm-cc-news-pretrained-scorer": 512,
"qqaatw/realm-cc-news-pretrained-openqa": 512,
"qqaatw/realm-orqa-nq-openqa": 512,
"qqaatw/realm-orqa-nq-reader": 512,
"qqaatw/realm-orqa-wq-openqa": 512,
"qqaatw/realm-orqa-wq-reader": 512,
"google/realm-cc-news-pretrained-embedder": 512,
"google/realm-cc-news-pretrained-encoder": 512,
"google/realm-cc-news-pretrained-scorer": 512,
"google/realm-cc-news-pretrained-openqa": 512,
"google/realm-orqa-nq-openqa": 512,
"google/realm-orqa-nq-reader": 512,
"google/realm-orqa-wq-openqa": 512,
"google/realm-orqa-wq-reader": 512,
}
# Maps each REALM checkpoint identifier to a dict of settings; every entry sets `do_lower_case` to True.
# NOTE(review): presumably used as default tokenizer init arguments per checkpoint — confirm against the caller.
PRETRAINED_INIT_CONFIGURATION = {
"qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-nq-reader": {"do_lower_case": True},
"qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-wq-reader": {"do_lower_case": True},
"google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
"google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
"google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
"google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
"google/realm-orqa-nq-openqa": {"do_lower_case": True},
"google/realm-orqa-nq-reader": {"do_lower_case": True},
"google/realm-orqa-wq-openqa": {"do_lower_case": True},
"google/realm-orqa-wq-reader": {"do_lower_case": True},
}
......@@ -252,7 +252,7 @@ class RealmTokenizer(PreTrainedTokenizer):
>>> # batch_size = 2, num_candidates = 2
>>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
>>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
```"""
......
......@@ -32,47 +32,47 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso
# Maps each tokenizer file kind ("vocab_file", "tokenizer_file") to a
# {checkpoint identifier: download URL} table. Every URL has the form
# https://huggingface.co/<checkpoint>/resolve/main/<file>, where <file> is
# `vocab.txt` for "vocab_file" and `tokenizer.json` for "tokenizer_file".
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
        "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
        "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
        "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/resolve/main/vocab.txt",
        "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt",
        "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt",
        "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt",
        "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt",
        "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
        "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
        "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
        "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
        "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
        "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
    },
    "tokenizer_file": {
        "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.json",
        "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json",
        "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json",
        "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/resolve/main/tokenizer.json",
        "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/tokenizer.json",
        "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/tokenizer.json",
        "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/tokenizer.json",
        "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/tokenizer.json",
        "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.json",
        "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json",
        "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json",
        "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/resolve/main/tokenizer.json",
        "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json",
        "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json",
        "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json",
        "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json",
    },
}
# Maps each REALM checkpoint identifier to 512.
# NOTE(review): presumably the number of position embeddings (maximum input length) per checkpoint —
# confirm against how the fast tokenizer base class consumes this table.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"qqaatw/realm-cc-news-pretrained-embedder": 512,
"qqaatw/realm-cc-news-pretrained-encoder": 512,
"qqaatw/realm-cc-news-pretrained-scorer": 512,
"qqaatw/realm-cc-news-pretrained-openqa": 512,
"qqaatw/realm-orqa-nq-openqa": 512,
"qqaatw/realm-orqa-nq-reader": 512,
"qqaatw/realm-orqa-wq-openqa": 512,
"qqaatw/realm-orqa-wq-reader": 512,
"google/realm-cc-news-pretrained-embedder": 512,
"google/realm-cc-news-pretrained-encoder": 512,
"google/realm-cc-news-pretrained-scorer": 512,
"google/realm-cc-news-pretrained-openqa": 512,
"google/realm-orqa-nq-openqa": 512,
"google/realm-orqa-nq-reader": 512,
"google/realm-orqa-wq-openqa": 512,
"google/realm-orqa-wq-reader": 512,
}
# Maps each REALM checkpoint identifier to a dict of settings; every entry sets `do_lower_case` to True.
# NOTE(review): presumably used as default fast-tokenizer init arguments per checkpoint — confirm against the caller.
PRETRAINED_INIT_CONFIGURATION = {
"qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-nq-reader": {"do_lower_case": True},
"qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-wq-reader": {"do_lower_case": True},
"google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
"google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
"google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
"google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
"google/realm-orqa-nq-openqa": {"do_lower_case": True},
"google/realm-orqa-nq-reader": {"do_lower_case": True},
"google/realm-orqa-wq-openqa": {"do_lower_case": True},
"google/realm-orqa-wq-reader": {"do_lower_case": True},
}
......@@ -200,7 +200,7 @@ class RealmTokenizerFast(PreTrainedTokenizerFast):
>>> # batch_size = 2, num_candidates = 2
>>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
>>> tokenizer = RealmTokenizerFast.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
>>> tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-encoder")
>>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
```"""
......
......@@ -358,7 +358,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
config.return_dict = True
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
# RealmKnowledgeAugEncoder training
model = RealmKnowledgeAugEncoder(config)
......@@ -411,27 +411,27 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
@slow
def test_embedder_from_pretrained(self):
model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
self.assertIsNotNone(model)
@slow
def test_encoder_from_pretrained(self):
model = RealmKnowledgeAugEncoder.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
model = RealmKnowledgeAugEncoder.from_pretrained("google/realm-cc-news-pretrained-encoder")
self.assertIsNotNone(model)
@slow
def test_open_qa_from_pretrained(self):
model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa")
model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa")
self.assertIsNotNone(model)
@slow
def test_reader_from_pretrained(self):
model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader")
model = RealmReader.from_pretrained("google/realm-orqa-nq-reader")
self.assertIsNotNone(model)
@slow
def test_scorer_from_pretrained(self):
model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer")
model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer")
self.assertIsNotNone(model)
......@@ -441,7 +441,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
def test_inference_embedder(self):
retriever_projected_size = 128
model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
......@@ -457,7 +457,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
vocab_size = 30522
model = RealmKnowledgeAugEncoder.from_pretrained(
"qqaatw/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
"google/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32)
......@@ -476,11 +476,11 @@ class RealmModelIntegrationTest(unittest.TestCase):
config = RealmConfig()
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")
tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
model = RealmForOpenQA.from_pretrained(
"qqaatw/realm-orqa-nq-openqa",
"google/realm-orqa-nq-openqa",
retriever=retriever,
config=config,
)
......@@ -503,7 +503,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_reader(self):
config = RealmConfig(reader_beam_size=2, max_span_width=3)
model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader", config=config)
model = RealmReader.from_pretrained("google/realm-orqa-nq-reader", config=config)
concat_input_ids = torch.arange(10).view((2, 5))
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
......@@ -532,7 +532,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
def test_inference_scorer(self):
num_candidates = 2
model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
candidate_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
......
......@@ -180,6 +180,6 @@ class RealmRetrieverTest(TestCase):
mock_hf_hub_download.return_value = os.path.join(
os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME
)
retriever = RealmRetriever.from_pretrained("qqaatw/realm-cc-news-pretrained-openqa")
retriever = RealmRetriever.from_pretrained("google/realm-cc-news-pretrained-openqa")
self.assertEqual(retriever.block_records[0], b"This is the first record")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment