chenpangpang / transformers · Commit 3385ca25 (unverified)
Authored Jan 31, 2022 by Sylvain Gugger; committed by GitHub on Jan 31, 2022.

Change REALM checkpoint to new ones (#15439)

* Change REALM checkpoint to new ones
* Last checkpoint missing

Parent: 7e56ba28
Showing 6 changed files with 102 additions and 102 deletions.
src/transformers/models/realm/configuration_realm.py   +10 −10
src/transformers/models/realm/modeling_realm.py        +20 −20
src/transformers/models/realm/tokenization_realm.py    +25 −25
src/transformers/models/realm/tokenization_realm_fast.py  +33 −33
tests/test_modeling_realm.py                           +13 −13
tests/test_retrieval_realm.py                          +1 −1
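In practice, the REALM checkpoints documented throughout the library now live under the google/ organization on the Hugging Face Hub instead of qqaatw/. A minimal sketch of the updated usage, mirroring the docstring changes in the diffs below (it assumes transformers, torch, and network access to the Hub):

    import torch
    from transformers import RealmTokenizer, RealmEmbedder

    # Previously: ...from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
    tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
    model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")

    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    outputs = model(**inputs)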
src/transformers/models/realm/configuration_realm.py

@@ -21,14 +21,14 @@ from ...utils import logging

 logger = logging.get_logger(__name__)

 REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/config.json",
-    "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/config.json",
-    "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/config.json",
-    "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/config.json",
-    "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/config.json",
-    "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/config.json",
-    "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/config.json",
-    "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/config.json",
+    "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json",
+    "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json",
+    "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json",
+    "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/config.json",
+    "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/config.json",
+    "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/config.json",
+    "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/config.json",
+    "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/config.json",
     # See all REALM models at https://huggingface.co/models?filter=realm
 }

@@ -46,7 +46,7 @@ class RealmConfig(PretrainedConfig):
     It is used to instantiate an REALM model according to the specified arguments, defining the model architecture.
     Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM
-    [realm-cc-news-pretrained](https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder) architecture.
+    [realm-cc-news-pretrained](https://huggingface.co/google/realm-cc-news-pretrained-embedder) architecture.

     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.

@@ -112,7 +112,7 @@ class RealmConfig(PretrainedConfig):
     >>> # Initializing a REALM realm-cc-news-pretrained-* style configuration
     >>> configuration = RealmConfig()

-    >>> # Initializing a model from the qqaatw/realm-cc-news-pretrained-embedder style configuration
+    >>> # Initializing a model from the google/realm-cc-news-pretrained-embedder style configuration
     >>> model = RealmEmbedder(configuration)

     >>> # Accessing the model configuration
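For reference, the docstring pattern updated above can be exercised locally without downloading any checkpoint, since the default RealmConfig matches the realm-cc-news-pretrained-* architecture. A minimal sketch, assuming only a transformers and torch install:

    from transformers import RealmConfig, RealmEmbedder

    # Initializing a REALM realm-cc-news-pretrained-* style configuration
    configuration = RealmConfig()

    # Initializing a model (with random weights) from that configuration
    model = RealmEmbedder(configuration)

    # Accessing the model configuration
    configuration = model.config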
src/transformers/models/realm/modeling_realm.py

@@ -43,21 +43,21 @@ from .configuration_realm import RealmConfig

 logger = logging.get_logger(__name__)

-_EMBEDDER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-embedder"
-_ENCODER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-encoder"
-_SCORER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-scorer"
+_EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder"
+_ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder"
+_SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer"
 _CONFIG_FOR_DOC = "RealmConfig"
 _TOKENIZER_FOR_DOC = "RealmTokenizer"

 REALM_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "qqaatw/realm-cc-news-pretrained-embedder",
-    "qqaatw/realm-cc-news-pretrained-encoder",
-    "qqaatw/realm-cc-news-pretrained-scorer",
-    "qqaatw/realm-cc-news-pretrained-openqa",
-    "qqaatw/realm-orqa-nq-openqa",
-    "qqaatw/realm-orqa-nq-reader",
-    "qqaatw/realm-orqa-wq-openqa",
-    "qqaatw/realm-orqa-wq-reader",
+    "google/realm-cc-news-pretrained-embedder",
+    "google/realm-cc-news-pretrained-encoder",
+    "google/realm-cc-news-pretrained-scorer",
+    "google/realm-cc-news-pretrained-openqa",
+    "google/realm-orqa-nq-openqa",
+    "google/realm-orqa-nq-reader",
+    "google/realm-orqa-wq-openqa",
+    "google/realm-orqa-wq-reader",
     # See all REALM models at https://huggingface.co/models?filter=realm
 ]

@@ -1180,8 +1180,8 @@ class RealmEmbedder(RealmPreTrainedModel):
         >>> from transformers import RealmTokenizer, RealmEmbedder
         >>> import torch

-        >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
-        >>> model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
+        >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
+        >>> model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")

         >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
         >>> outputs = model(**inputs)

@@ -1293,8 +1293,8 @@ class RealmScorer(RealmPreTrainedModel):
         >>> import torch
         >>> from transformers import RealmTokenizer, RealmScorer

-        >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer")
-        >>> model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=2)
+        >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-scorer")
+        >>> model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2)

         >>> # batch_size = 2, num_candidates = 2
         >>> input_texts = ["How are you?", "What is the item in the picture?"]

@@ -1433,9 +1433,9 @@ class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
         >>> import torch
         >>> from transformers import RealmTokenizer, RealmKnowledgeAugEncoder

-        >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
+        >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
         >>> model = RealmKnowledgeAugEncoder.from_pretrained(
-        ...     "qqaatw/realm-cc-news-pretrained-encoder", num_candidates=2
+        ...     "google/realm-cc-news-pretrained-encoder", num_candidates=2
         ... )

         >>> # batch_size = 2, num_candidates = 2

@@ -1761,9 +1761,9 @@ class RealmForOpenQA(RealmPreTrainedModel):
         >>> import torch
         >>> from transformers import RealmForOpenQA, RealmRetriever, RealmTokenizer

-        >>> retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")
-        >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
-        >>> model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa", retriever=retriever)
+        >>> retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
+        >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
+        >>> model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)

         >>> question = "Who is the pioneer in modern computer science?"
         >>> question_ids = tokenizer([question], return_tensors="pt")
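The open-domain QA entry point is the checkpoint rename most visible to users. A minimal sketch mirroring the updated RealmForOpenQA docstring above; it assumes network access to the Hub and stops after tokenizing the question (the forward pass is shown further along in the full docstring, which this hunk truncates):

    import torch
    from transformers import RealmForOpenQA, RealmRetriever, RealmTokenizer

    # All three components now load from the google/ namespace
    retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
    tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
    model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)

    question = "Who is the pioneer in modern computer science?"
    question_ids = tokenizer([question], return_tensors="pt")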
src/transformers/models/realm/tokenization_realm.py

@@ -31,37 +31,37 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
-        "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
-        "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
-        "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
-        "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
-        "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt",
-        "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt",
-        "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt",
-        "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt",
+        "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
+        "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
+        "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
+        "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
+        "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
+        "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
+        "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
+        "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
     }
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    "qqaatw/realm-cc-news-pretrained-embedder": 512,
-    "qqaatw/realm-cc-news-pretrained-encoder": 512,
-    "qqaatw/realm-cc-news-pretrained-scorer": 512,
-    "qqaatw/realm-cc-news-pretrained-openqa": 512,
-    "qqaatw/realm-orqa-nq-openqa": 512,
-    "qqaatw/realm-orqa-nq-reader": 512,
-    "qqaatw/realm-orqa-wq-openqa": 512,
-    "qqaatw/realm-orqa-wq-reader": 512,
+    "google/realm-cc-news-pretrained-embedder": 512,
+    "google/realm-cc-news-pretrained-encoder": 512,
+    "google/realm-cc-news-pretrained-scorer": 512,
+    "google/realm-cc-news-pretrained-openqa": 512,
+    "google/realm-orqa-nq-openqa": 512,
+    "google/realm-orqa-nq-reader": 512,
+    "google/realm-orqa-wq-openqa": 512,
+    "google/realm-orqa-wq-reader": 512,
 }

 PRETRAINED_INIT_CONFIGURATION = {
-    "qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
-    "qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
-    "qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
-    "qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
-    "qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True},
-    "qqaatw/realm-orqa-nq-reader": {"do_lower_case": True},
-    "qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True},
-    "qqaatw/realm-orqa-wq-reader": {"do_lower_case": True},
+    "google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
+    "google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
+    "google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
+    "google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
+    "google/realm-orqa-nq-openqa": {"do_lower_case": True},
+    "google/realm-orqa-nq-reader": {"do_lower_case": True},
+    "google/realm-orqa-wq-openqa": {"do_lower_case": True},
+    "google/realm-orqa-wq-reader": {"do_lower_case": True},
 }

@@ -252,7 +252,7 @@ class RealmTokenizer(PreTrainedTokenizer):
         >>> # batch_size = 2, num_candidates = 2
         >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

-        >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
+        >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
         >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
         ```"""
src/transformers/models/realm/tokenization_realm_fast.py

@@ -32,47 +32,47 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso

 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
-        "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
-        "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
-        "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
-        "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
-        "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt",
-        "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt",
-        "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt",
-        "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt",
+        "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
+        "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
+        "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
+        "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
+        "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
+        "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
+        "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
+        "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
     },
     "tokenizer_file": {
-        "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont",
-        "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json",
-        "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json",
-        "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json",
-        "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/tokenizer.json",
-        "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/tokenizer.json",
-        "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/tokenizer.json",
-        "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/tokenizer.json",
+        "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont",
+        "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json",
+        "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json",
+        "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json",
+        "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json",
+        "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json",
+        "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json",
+        "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json",
     },
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    "qqaatw/realm-cc-news-pretrained-embedder": 512,
-    "qqaatw/realm-cc-news-pretrained-encoder": 512,
-    "qqaatw/realm-cc-news-pretrained-scorer": 512,
-    "qqaatw/realm-cc-news-pretrained-openqa": 512,
-    "qqaatw/realm-orqa-nq-openqa": 512,
-    "qqaatw/realm-orqa-nq-reader": 512,
-    "qqaatw/realm-orqa-wq-openqa": 512,
-    "qqaatw/realm-orqa-wq-reader": 512,
+    "google/realm-cc-news-pretrained-embedder": 512,
+    "google/realm-cc-news-pretrained-encoder": 512,
+    "google/realm-cc-news-pretrained-scorer": 512,
+    "google/realm-cc-news-pretrained-openqa": 512,
+    "google/realm-orqa-nq-openqa": 512,
+    "google/realm-orqa-nq-reader": 512,
+    "google/realm-orqa-wq-openqa": 512,
+    "google/realm-orqa-wq-reader": 512,
 }

 PRETRAINED_INIT_CONFIGURATION = {
-    "qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
-    "qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
-    "qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
-    "qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
-    "qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True},
-    "qqaatw/realm-orqa-nq-reader": {"do_lower_case": True},
-    "qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True},
-    "qqaatw/realm-orqa-wq-reader": {"do_lower_case": True},
+    "google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
+    "google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
+    "google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
+    "google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
+    "google/realm-orqa-nq-openqa": {"do_lower_case": True},
+    "google/realm-orqa-nq-reader": {"do_lower_case": True},
+    "google/realm-orqa-wq-openqa": {"do_lower_case": True},
+    "google/realm-orqa-wq-reader": {"do_lower_case": True},
 }

@@ -200,7 +200,7 @@ class RealmTokenizerFast(PreTrainedTokenizerFast):
         >>> # batch_size = 2, num_candidates = 2
         >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

-        >>> tokenizer = RealmTokenizerFast.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
+        >>> tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-encoder")
        >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
         ```"""
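Both tokenizer docstrings advertise batch_encode_candidates for encoding several candidate passages per example. A minimal sketch of that call against the renamed checkpoint, assuming google/realm-cc-news-pretrained-encoder is reachable on the Hub:

    from transformers import RealmTokenizerFast

    # batch_size = 2, num_candidates = 2
    text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

    tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-encoder")

    # Each candidate is padded/truncated to max_length; the encoded fields keep the
    # extra candidate axis, i.e. roughly (batch_size, num_candidates, max_length).
    tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")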
tests/test_modeling_realm.py

@@ -358,7 +358,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
        input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
        config.return_dict = True

-        tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
+        tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")

        # RealmKnowledgeAugEncoder training
        model = RealmKnowledgeAugEncoder(config)

@@ -411,27 +411,27 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
    @slow
    def test_embedder_from_pretrained(self):
-        model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
+        model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
        self.assertIsNotNone(model)

    @slow
    def test_encoder_from_pretrained(self):
-        model = RealmKnowledgeAugEncoder.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
+        model = RealmKnowledgeAugEncoder.from_pretrained("google/realm-cc-news-pretrained-encoder")
        self.assertIsNotNone(model)

    @slow
    def test_open_qa_from_pretrained(self):
-        model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa")
+        model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa")
        self.assertIsNotNone(model)

    @slow
    def test_reader_from_pretrained(self):
-        model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader")
+        model = RealmReader.from_pretrained("google/realm-orqa-nq-reader")
        self.assertIsNotNone(model)

    @slow
    def test_scorer_from_pretrained(self):
-        model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer")
+        model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer")
        self.assertIsNotNone(model)

@@ -441,7 +441,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
    def test_inference_embedder(self):
        retriever_projected_size = 128

-        model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
+        model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
        output = model(input_ids)[0]

@@ -457,7 +457,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
        vocab_size = 30522

        model = RealmKnowledgeAugEncoder.from_pretrained(
-            "qqaatw/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
+            "google/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
        )
        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
        relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32)

@@ -476,11 +476,11 @@ class RealmModelIntegrationTest(unittest.TestCase):
        config = RealmConfig()

-        tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
-        retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")
+        tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
+        retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")

        model = RealmForOpenQA.from_pretrained(
-            "qqaatw/realm-orqa-nq-openqa",
+            "google/realm-orqa-nq-openqa",
            retriever=retriever,
            config=config,
        )

@@ -503,7 +503,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_reader(self):
        config = RealmConfig(reader_beam_size=2, max_span_width=3)
-        model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader", config=config)
+        model = RealmReader.from_pretrained("google/realm-orqa-nq-reader", config=config)

        concat_input_ids = torch.arange(10).view((2, 5))
        concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)

@@ -532,7 +532,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
    def test_inference_scorer(self):
        num_candidates = 2

-        model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
+        model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)

        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
        candidate_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
tests/test_retrieval_realm.py

@@ -180,6 +180,6 @@ class RealmRetrieverTest(TestCase):
        mock_hf_hub_download.return_value = os.path.join(
            os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME
        )
-        retriever = RealmRetriever.from_pretrained("qqaatw/realm-cc-news-pretrained-openqa")
+        retriever = RealmRetriever.from_pretrained("google/realm-cc-news-pretrained-openqa")

        self.assertEqual(retriever.block_records[0], b"This is the first record")