"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fe06f8dcacfc3c8ffb39dcfc56a182b18c19e79d"
Unverified Commit 3385ca25 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Change REALM checkpoint to new ones (#15439)

* Change REALM checkpoint to new ones

* Last checkpoint missing
parent 7e56ba28
...@@ -21,14 +21,14 @@ from ...utils import logging ...@@ -21,14 +21,14 @@ from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = { REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/config.json", "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json",
"qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/config.json", "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json",
"qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/config.json", "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json",
"qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/config.json", "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/config.json",
"qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/config.json", "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/config.json",
"qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/config.json", "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/config.json",
"qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/config.json", "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/config.json",
"qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/config.json", "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/config.json",
# See all REALM models at https://huggingface.co/models?filter=realm # See all REALM models at https://huggingface.co/models?filter=realm
} }
...@@ -46,7 +46,7 @@ class RealmConfig(PretrainedConfig): ...@@ -46,7 +46,7 @@ class RealmConfig(PretrainedConfig):
It is used to instantiate an REALM model according to the specified arguments, defining the model architecture. It is used to instantiate an REALM model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM
[realm-cc-news-pretrained](https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder) architecture. [realm-cc-news-pretrained](https://huggingface.co/google/realm-cc-news-pretrained-embedder) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
...@@ -112,7 +112,7 @@ class RealmConfig(PretrainedConfig): ...@@ -112,7 +112,7 @@ class RealmConfig(PretrainedConfig):
>>> # Initializing a REALM realm-cc-news-pretrained-* style configuration >>> # Initializing a REALM realm-cc-news-pretrained-* style configuration
>>> configuration = RealmConfig() >>> configuration = RealmConfig()
>>> # Initializing a model from the qqaatw/realm-cc-news-pretrained-embedder style configuration >>> # Initializing a model from the google/realm-cc-news-pretrained-embedder style configuration
>>> model = RealmEmbedder(configuration) >>> model = RealmEmbedder(configuration)
>>> # Accessing the model configuration >>> # Accessing the model configuration
......
...@@ -43,21 +43,21 @@ from .configuration_realm import RealmConfig ...@@ -43,21 +43,21 @@ from .configuration_realm import RealmConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_EMBEDDER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-embedder" _EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder"
_ENCODER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-encoder" _ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder"
_SCORER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-scorer" _SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer"
_CONFIG_FOR_DOC = "RealmConfig" _CONFIG_FOR_DOC = "RealmConfig"
_TOKENIZER_FOR_DOC = "RealmTokenizer" _TOKENIZER_FOR_DOC = "RealmTokenizer"
REALM_PRETRAINED_MODEL_ARCHIVE_LIST = [ REALM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"qqaatw/realm-cc-news-pretrained-embedder", "google/realm-cc-news-pretrained-embedder",
"qqaatw/realm-cc-news-pretrained-encoder", "google/realm-cc-news-pretrained-encoder",
"qqaatw/realm-cc-news-pretrained-scorer", "google/realm-cc-news-pretrained-scorer",
"qqaatw/realm-cc-news-pretrained-openqa", "google/realm-cc-news-pretrained-openqa",
"qqaatw/realm-orqa-nq-openqa", "google/realm-orqa-nq-openqa",
"qqaatw/realm-orqa-nq-reader", "google/realm-orqa-nq-reader",
"qqaatw/realm-orqa-wq-openqa", "google/realm-orqa-wq-openqa",
"qqaatw/realm-orqa-wq-reader", "google/realm-orqa-wq-reader",
# See all REALM models at https://huggingface.co/models?filter=realm # See all REALM models at https://huggingface.co/models?filter=realm
] ]
...@@ -1180,8 +1180,8 @@ class RealmEmbedder(RealmPreTrainedModel): ...@@ -1180,8 +1180,8 @@ class RealmEmbedder(RealmPreTrainedModel):
>>> from transformers import RealmTokenizer, RealmEmbedder >>> from transformers import RealmTokenizer, RealmEmbedder
>>> import torch >>> import torch
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder") >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
>>> model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder") >>> model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
...@@ -1293,8 +1293,8 @@ class RealmScorer(RealmPreTrainedModel): ...@@ -1293,8 +1293,8 @@ class RealmScorer(RealmPreTrainedModel):
>>> import torch >>> import torch
>>> from transformers import RealmTokenizer, RealmScorer >>> from transformers import RealmTokenizer, RealmScorer
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer") >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-scorer")
>>> model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=2) >>> model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2)
>>> # batch_size = 2, num_candidates = 2 >>> # batch_size = 2, num_candidates = 2
>>> input_texts = ["How are you?", "What is the item in the picture?"] >>> input_texts = ["How are you?", "What is the item in the picture?"]
...@@ -1433,9 +1433,9 @@ class RealmKnowledgeAugEncoder(RealmPreTrainedModel): ...@@ -1433,9 +1433,9 @@ class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
>>> import torch >>> import torch
>>> from transformers import RealmTokenizer, RealmKnowledgeAugEncoder >>> from transformers import RealmTokenizer, RealmKnowledgeAugEncoder
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder") >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
>>> model = RealmKnowledgeAugEncoder.from_pretrained( >>> model = RealmKnowledgeAugEncoder.from_pretrained(
... "qqaatw/realm-cc-news-pretrained-encoder", num_candidates=2 ... "google/realm-cc-news-pretrained-encoder", num_candidates=2
... ) ... )
>>> # batch_size = 2, num_candidates = 2 >>> # batch_size = 2, num_candidates = 2
...@@ -1761,9 +1761,9 @@ class RealmForOpenQA(RealmPreTrainedModel): ...@@ -1761,9 +1761,9 @@ class RealmForOpenQA(RealmPreTrainedModel):
>>> import torch >>> import torch
>>> from transformers import RealmForOpenQA, RealmRetriever, RealmTokenizer >>> from transformers import RealmForOpenQA, RealmRetriever, RealmTokenizer
>>> retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa") >>> retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa") >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
>>> model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa", retriever=retriever) >>> model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)
>>> question = "Who is the pioneer in modern computer science?" >>> question = "Who is the pioneer in modern computer science?"
>>> question_ids = tokenizer([question], return_tensors="pt") >>> question_ids = tokenizer([question], return_tensors="pt")
......
...@@ -31,37 +31,37 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} ...@@ -31,37 +31,37 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": { "vocab_file": {
"qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt", "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
"qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt", "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
"qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt", "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
"qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt", "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
"qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt", "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
"qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt", "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
"qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt", "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
"qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt", "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
} }
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"qqaatw/realm-cc-news-pretrained-embedder": 512, "google/realm-cc-news-pretrained-embedder": 512,
"qqaatw/realm-cc-news-pretrained-encoder": 512, "google/realm-cc-news-pretrained-encoder": 512,
"qqaatw/realm-cc-news-pretrained-scorer": 512, "google/realm-cc-news-pretrained-scorer": 512,
"qqaatw/realm-cc-news-pretrained-openqa": 512, "google/realm-cc-news-pretrained-openqa": 512,
"qqaatw/realm-orqa-nq-openqa": 512, "google/realm-orqa-nq-openqa": 512,
"qqaatw/realm-orqa-nq-reader": 512, "google/realm-orqa-nq-reader": 512,
"qqaatw/realm-orqa-wq-openqa": 512, "google/realm-orqa-wq-openqa": 512,
"qqaatw/realm-orqa-wq-reader": 512, "google/realm-orqa-wq-reader": 512,
} }
PRETRAINED_INIT_CONFIGURATION = { PRETRAINED_INIT_CONFIGURATION = {
"qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True}, "google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True}, "google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True}, "google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True}, "google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True}, "google/realm-orqa-nq-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-nq-reader": {"do_lower_case": True}, "google/realm-orqa-nq-reader": {"do_lower_case": True},
"qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True}, "google/realm-orqa-wq-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-wq-reader": {"do_lower_case": True}, "google/realm-orqa-wq-reader": {"do_lower_case": True},
} }
...@@ -252,7 +252,7 @@ class RealmTokenizer(PreTrainedTokenizer): ...@@ -252,7 +252,7 @@ class RealmTokenizer(PreTrainedTokenizer):
>>> # batch_size = 2, num_candidates = 2 >>> # batch_size = 2, num_candidates = 2
>>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]] >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder") >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
>>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt") >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
```""" ```"""
......
...@@ -32,47 +32,47 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso ...@@ -32,47 +32,47 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": { "vocab_file": {
"qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt", "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
"qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt", "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
"qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt", "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
"qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt", "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
"qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt", "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
"qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt", "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
"qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt", "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
"qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt", "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
}, },
"tokenizer_file": { "tokenizer_file": {
"qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont", "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont",
"qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json", "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json",
"qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json", "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json",
"qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json", "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json",
"qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/tokenizer.json", "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json",
"qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/tokenizer.json", "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json",
"qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/tokenizer.json", "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json",
"qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/tokenizer.json", "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json",
}, },
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"qqaatw/realm-cc-news-pretrained-embedder": 512, "google/realm-cc-news-pretrained-embedder": 512,
"qqaatw/realm-cc-news-pretrained-encoder": 512, "google/realm-cc-news-pretrained-encoder": 512,
"qqaatw/realm-cc-news-pretrained-scorer": 512, "google/realm-cc-news-pretrained-scorer": 512,
"qqaatw/realm-cc-news-pretrained-openqa": 512, "google/realm-cc-news-pretrained-openqa": 512,
"qqaatw/realm-orqa-nq-openqa": 512, "google/realm-orqa-nq-openqa": 512,
"qqaatw/realm-orqa-nq-reader": 512, "google/realm-orqa-nq-reader": 512,
"qqaatw/realm-orqa-wq-openqa": 512, "google/realm-orqa-wq-openqa": 512,
"qqaatw/realm-orqa-wq-reader": 512, "google/realm-orqa-wq-reader": 512,
} }
PRETRAINED_INIT_CONFIGURATION = { PRETRAINED_INIT_CONFIGURATION = {
"qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True}, "google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True}, "google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True}, "google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
"qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True}, "google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True}, "google/realm-orqa-nq-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-nq-reader": {"do_lower_case": True}, "google/realm-orqa-nq-reader": {"do_lower_case": True},
"qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True}, "google/realm-orqa-wq-openqa": {"do_lower_case": True},
"qqaatw/realm-orqa-wq-reader": {"do_lower_case": True}, "google/realm-orqa-wq-reader": {"do_lower_case": True},
} }
...@@ -200,7 +200,7 @@ class RealmTokenizerFast(PreTrainedTokenizerFast): ...@@ -200,7 +200,7 @@ class RealmTokenizerFast(PreTrainedTokenizerFast):
>>> # batch_size = 2, num_candidates = 2 >>> # batch_size = 2, num_candidates = 2
>>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]] >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
>>> tokenizer = RealmTokenizerFast.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder") >>> tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-encoder")
>>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt") >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
```""" ```"""
......
...@@ -358,7 +358,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -358,7 +358,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4] input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
config.return_dict = True config.return_dict = True
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa") tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
# RealmKnowledgeAugEncoder training # RealmKnowledgeAugEncoder training
model = RealmKnowledgeAugEncoder(config) model = RealmKnowledgeAugEncoder(config)
...@@ -411,27 +411,27 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -411,27 +411,27 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
@slow @slow
def test_embedder_from_pretrained(self): def test_embedder_from_pretrained(self):
model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder") model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
self.assertIsNotNone(model) self.assertIsNotNone(model)
@slow @slow
def test_encoder_from_pretrained(self): def test_encoder_from_pretrained(self):
model = RealmKnowledgeAugEncoder.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder") model = RealmKnowledgeAugEncoder.from_pretrained("google/realm-cc-news-pretrained-encoder")
self.assertIsNotNone(model) self.assertIsNotNone(model)
@slow @slow
def test_open_qa_from_pretrained(self): def test_open_qa_from_pretrained(self):
model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa") model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa")
self.assertIsNotNone(model) self.assertIsNotNone(model)
@slow @slow
def test_reader_from_pretrained(self): def test_reader_from_pretrained(self):
model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader") model = RealmReader.from_pretrained("google/realm-orqa-nq-reader")
self.assertIsNotNone(model) self.assertIsNotNone(model)
@slow @slow
def test_scorer_from_pretrained(self): def test_scorer_from_pretrained(self):
model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer") model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer")
self.assertIsNotNone(model) self.assertIsNotNone(model)
...@@ -441,7 +441,7 @@ class RealmModelIntegrationTest(unittest.TestCase): ...@@ -441,7 +441,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
def test_inference_embedder(self): def test_inference_embedder(self):
retriever_projected_size = 128 retriever_projected_size = 128
model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder") model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0] output = model(input_ids)[0]
...@@ -457,7 +457,7 @@ class RealmModelIntegrationTest(unittest.TestCase): ...@@ -457,7 +457,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
vocab_size = 30522 vocab_size = 30522
model = RealmKnowledgeAugEncoder.from_pretrained( model = RealmKnowledgeAugEncoder.from_pretrained(
"qqaatw/realm-cc-news-pretrained-encoder", num_candidates=num_candidates "google/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
) )
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32) relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32)
...@@ -476,11 +476,11 @@ class RealmModelIntegrationTest(unittest.TestCase): ...@@ -476,11 +476,11 @@ class RealmModelIntegrationTest(unittest.TestCase):
config = RealmConfig() config = RealmConfig()
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa") tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa") retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
model = RealmForOpenQA.from_pretrained( model = RealmForOpenQA.from_pretrained(
"qqaatw/realm-orqa-nq-openqa", "google/realm-orqa-nq-openqa",
retriever=retriever, retriever=retriever,
config=config, config=config,
) )
...@@ -503,7 +503,7 @@ class RealmModelIntegrationTest(unittest.TestCase): ...@@ -503,7 +503,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
@slow @slow
def test_inference_reader(self): def test_inference_reader(self):
config = RealmConfig(reader_beam_size=2, max_span_width=3) config = RealmConfig(reader_beam_size=2, max_span_width=3)
model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader", config=config) model = RealmReader.from_pretrained("google/realm-orqa-nq-reader", config=config)
concat_input_ids = torch.arange(10).view((2, 5)) concat_input_ids = torch.arange(10).view((2, 5))
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64) concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
...@@ -532,7 +532,7 @@ class RealmModelIntegrationTest(unittest.TestCase): ...@@ -532,7 +532,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
def test_inference_scorer(self): def test_inference_scorer(self):
num_candidates = 2 num_candidates = 2
model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=num_candidates) model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
candidate_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]) candidate_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
......
...@@ -180,6 +180,6 @@ class RealmRetrieverTest(TestCase): ...@@ -180,6 +180,6 @@ class RealmRetrieverTest(TestCase):
mock_hf_hub_download.return_value = os.path.join( mock_hf_hub_download.return_value = os.path.join(
os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME
) )
retriever = RealmRetriever.from_pretrained("qqaatw/realm-cc-news-pretrained-openqa") retriever = RealmRetriever.from_pretrained("google/realm-cc-news-pretrained-openqa")
self.assertEqual(retriever.block_records[0], b"This is the first record") self.assertEqual(retriever.block_records[0], b"This is the first record")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment