"tests/models/auto/test_modeling_tf_auto.py" did not exist on "13deb95a405bbd1037ad233c692d7fd1de9d31e3"
Unverified Commit 7b3bd1f2 authored by Li-Huai (Allan) Lin, committed by GitHub

Fix and improve REALM fine-tuning (#15297)

* Draft

* Add test

* Update src/transformers/models/realm/modeling_realm.py

* Apply suggestion

* Add block_mask

* Update

* Update

* Add block_embedding_to

* Remove no_grad

* Use AutoTokenizer

* Remove model.to overriding
parent 439de3f7
......@@ -81,4 +81,5 @@ This model was contributed by [qqaatw](https://huggingface.co/qqaatw). The origi
## RealmForOpenQA
[[autodoc]] RealmForOpenQA
- block_embedding_to
- forward
\ No newline at end of file
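A hedged usage sketch of the `block_embedding_to` helper documented above. The checkpoint name and the `retriever=` wiring follow the REALM docs; keeping the large block-embedding buffer on CPU while the rest of the model sits on GPU is an assumed, illustrative use case, not part of this commit:

import torch
from transformers import RealmForOpenQA, RealmRetriever

retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)

# `model.block_emb` holds pre-computed embeddings for every evidence block, so it can
# be kept on CPU to save GPU memory while the rest of the model trains on GPU.
model.to(torch.device("cuda:0"))
model.block_embedding_to(torch.device("cpu"))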
......@@ -48,6 +48,7 @@ else:
TOKENIZER_MAPPING_NAMES = OrderedDict(
[
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
......
......@@ -836,13 +836,13 @@ class RealmReaderProjection(nn.Module):
self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps)
self.relu = nn.ReLU()
def forward(self, hidden_states, token_type_ids):
def forward(self, hidden_states, block_mask):
def span_candidates(masks):
"""
Generate span candidates.
Args:
masks: <int32> [num_retrievals, max_sequence_len]
masks: <bool> [num_retrievals, max_sequence_len]
Returns:
starts: <int32> [num_spans]
ends: <int32> [num_spans]
span_masks: <int32> [num_retrievals, num_spans]
......@@ -875,8 +875,7 @@ class RealmReaderProjection(nn.Module):
hidden_states = self.dense_intermediate(hidden_states)
# [reader_beam_size, max_sequence_len, span_hidden_size]
start_projection, end_projection = hidden_states.chunk(2, dim=-1)
block_mask = token_type_ids.detach().clone()
block_mask[:, -1] = 0
candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask)
candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts)
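For context, a minimal sketch of what `span_candidates` computes from the boolean `block_mask` that now replaces the `token_type_ids`-derived mask. The enumeration below is an assumption based on the docstring shapes, not a copy of the implementation: every span of width up to `max_span_width` is enumerated, and a span is kept only when both of its endpoints are evidence-block tokens.

import torch

def span_candidates_sketch(masks, max_span_width=10):
    # masks: <bool> [num_retrievals, max_sequence_len]
    _, max_sequence_len = masks.shape
    starts, ends = [], []
    for width in range(1, max_span_width + 1):
        starts.append(torch.arange(0, max_sequence_len - width + 1))
        ends.append(torch.arange(width - 1, max_sequence_len))
    starts = torch.cat(starts)  # [num_spans]
    ends = torch.cat(ends)      # [num_spans]
    # A span is a valid candidate only if both its start and end token lie inside the evidence block.
    span_masks = masks[:, starts] & masks[:, ends]  # <bool> [num_retrievals, num_spans]
    return starts, ends, span_masks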
......@@ -1543,6 +1542,7 @@ class RealmReader(RealmPreTrainedModel):
head_mask=None,
inputs_embeds=None,
relevance_score=None,
block_mask=None,
start_positions=None,
end_positions=None,
has_answers=None,
......@@ -1552,12 +1552,15 @@ class RealmReader(RealmPreTrainedModel):
):
r"""
relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
Relevance score, which must be specified if you want to compute the marginal log loss.
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Relevance score, which must be specified if you want to compute the logits and marginal log loss.
block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*):
The mask of the evidence block, which must be specified if you want to compute the logits and marginal log
loss.
start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
......@@ -1570,8 +1573,8 @@ class RealmReader(RealmPreTrainedModel):
if relevance_score is None:
raise ValueError("You have to specify `relevance_score` to calculate logits and loss.")
if token_type_ids is None:
raise ValueError("You have to specify `token_type_ids` to separate question block and evidence block.")
if block_mask is None:
raise ValueError("You have to specify `block_mask` to separate question block and evidence block.")
if token_type_ids.size(1) < self.config.max_span_width:
raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.")
outputs = self.realm(
......@@ -1590,7 +1593,9 @@ class RealmReader(RealmPreTrainedModel):
sequence_output = outputs[0]
# [reader_beam_size, num_candidates], [num_candidates], [num_candidates]
reader_logits, candidate_starts, candidate_ends = self.qa_outputs(sequence_output, token_type_ids)
reader_logits, candidate_starts, candidate_ends = self.qa_outputs(
sequence_output, block_mask[0 : self.config.reader_beam_size]
)
# [searcher_beam_size, 1]
retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1)
# [reader_beam_size, num_candidates]
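The reader logits and the sliced retriever logits above are combined into the marginal log loss mentioned in the docstring. A hedged sketch of what such a loss could look like (the helper name, shapes, and mean reduction are assumptions, not the code in modeling_realm.py): it rewards the total probability mass assigned to any correct candidate span.

import torch

def marginal_log_loss_sketch(logits, is_correct):
    # logits: <float> [reader_beam_size, num_candidates]; is_correct: <bool> same shape.
    log_numerator = torch.logsumexp(logits.masked_fill(~is_correct, float("-inf")), dim=-1)
    log_denominator = torch.logsumexp(logits, dim=-1)
    return (log_denominator - log_numerator).mean()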
......@@ -1737,11 +1742,21 @@ class RealmForOpenQA(RealmPreTrainedModel):
self.post_init()
@property
def beam_size(self):
def searcher_beam_size(self):
if self.training:
return self.config.searcher_beam_size
return self.config.reader_beam_size
def block_embedding_to(self, device):
"""Send `self.block_emb` to a specific device.
Args:
device (`str` or `torch.device`):
The device to which `self.block_emb` will be sent.
"""
self.block_emb = self.block_emb.to(device)
@add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
@replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
def forward(
......@@ -1787,36 +1802,37 @@ class RealmForOpenQA(RealmPreTrainedModel):
question_outputs = self.embedder(
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True
)
# [1, projection_size]
question_projection = question_outputs[0]
# CPU computation starts.
# [1, block_emb_size]
batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection)
batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection.to(self.block_emb.device))
# [1, searcher_beam_size]
_, retrieved_block_ids = torch.topk(batch_scores, k=self.beam_size, dim=-1)
_, retrieved_block_ids = torch.topk(batch_scores, k=self.searcher_beam_size, dim=-1)
# [searcher_beam_size]
# Must convert to cpu tensor for subsequent numpy operations
retrieved_block_ids = retrieved_block_ids.squeeze().cpu()
retrieved_block_ids = retrieved_block_ids.squeeze()
# [searcher_beam_size, projection_size]
retrieved_block_emb = torch.index_select(self.block_emb, dim=0, index=retrieved_block_ids)
# CPU computation ends.
# Retrieve possible answers
has_answers, start_pos, end_pos, concat_inputs = self.retriever(
retrieved_block_ids, input_ids, answer_ids, max_length=self.config.reader_seq_len
retrieved_block_ids.cpu(), input_ids, answer_ids, max_length=self.config.reader_seq_len
)
concat_inputs = concat_inputs.to(self.reader.device)
block_mask = concat_inputs.special_tokens_mask.type(torch.bool).to(device=self.reader.device)
block_mask.logical_not_().logical_and_(concat_inputs.token_type_ids.type(torch.bool))
if has_answers is not None:
has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device)
start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device)
end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device)
concat_inputs = concat_inputs.to(self.reader.device)
# [searcher_beam_size, projection_size]
retrieved_block_emb = torch.index_select(
self.block_emb, dim=0, index=retrieved_block_ids.to(self.block_emb.device)
)
# [searcher_beam_size]
retrieved_logits = torch.einsum(
"D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(question_projection.device)
"D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(self.reader.device)
)
reader_output = self.reader(
......@@ -1824,6 +1840,7 @@ class RealmForOpenQA(RealmPreTrainedModel):
attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size],
token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size],
relevance_score=retrieved_logits,
block_mask=block_mask,
has_answers=has_answers,
start_positions=start_pos,
end_positions=end_pos,
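A small worked example (token values assumed purely for illustration) of the `block_mask` construction added above: `special_tokens_mask` flags [CLS]/[SEP]/[PAD], so negating it and AND-ing with `token_type_ids` leaves True only on the non-special tokens of the retrieved evidence block.

import torch

# tokens:                           [CLS]  q1  q2 [SEP]  b1  b2 [SEP]
special_tokens_mask = torch.tensor([[  1,   0,  0,   1,   0,  0,   1]])
token_type_ids      = torch.tensor([[  0,   0,  0,   0,   1,  1,   1]])

block_mask = special_tokens_mask.type(torch.bool)
block_mask.logical_not_().logical_and_(token_type_ids.type(torch.bool))
print(block_mask)
# expected: tensor([[False, False, False, False,  True,  True, False]])
# Only the evidence tokens b1 and b2 stay True; the trailing [SEP] is excluded,
# which the old `block_mask[:, -1] = 0` workaround only approximated for padded batches.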
......
......@@ -20,9 +20,9 @@ from typing import Optional, Union
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from ...utils import logging
from .tokenization_realm import RealmTokenizer
_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"
......@@ -97,7 +97,9 @@ class RealmRetriever:
text.append(question)
text_pair.append(retrieved_block.decode())
concat_inputs = self.tokenizer(text, text_pair, padding=True, truncation=True, max_length=max_length)
concat_inputs = self.tokenizer(
text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length
)
concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)
if answer_ids is not None:
......@@ -115,7 +117,7 @@ class RealmRetriever:
)
block_records = np.load(block_records_path, allow_pickle=True)
tokenizer = RealmTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
return cls(block_records, tokenizer)
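Because the retriever now builds its tokenizer through `AutoTokenizer` (change above), the fast `RealmTokenizerFast` registered in tokenization_auto.py is picked up whenever the tokenizers library is installed. A hedged loading sketch, with the checkpoint name taken from the REALM docs:

from transformers import RealmRetriever

retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
# Expected to be RealmTokenizerFast when `tokenizers` is available, RealmTokenizer otherwise.
print(type(retriever.tokenizer).__name__)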
......@@ -133,13 +135,15 @@ class RealmRetriever:
max_answers = 0
for input_id in concat_inputs.input_ids:
input_id_list = input_id.tolist()
# Check answers between two [SEP] tokens
first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1 :].index(self.tokenizer.sep_token_id)
start_pos.append([])
end_pos.append([])
input_id_list = input_id.tolist()
# Checking answers after the [SEP] token
sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
for answer in answer_ids:
for idx in range(sep_idx, len(input_id)):
for idx in range(first_sep_idx + 1, second_sep_idx):
if answer[0] == input_id_list[idx]:
if input_id_list[idx : idx + len(answer)] == answer:
start_pos[-1].append(idx)
......@@ -158,5 +162,4 @@ class RealmRetriever:
padded = [-1] * (max_answers - len(start_pos_))
start_pos_ += padded
end_pos_ += padded
return has_answers, start_pos, end_pos
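A minimal plain-Python sketch (function name assumed) of the corrected answer search shown above: candidate answers are now matched only between the first and second [SEP] tokens, i.e. inside the retrieved evidence block, instead of anywhere after the first [SEP].

def find_answer_spans(input_id_list, answer_ids, sep_token_id):
    first_sep_idx = input_id_list.index(sep_token_id)
    second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1:].index(sep_token_id)
    starts, ends = [], []
    for answer in answer_ids:
        # Slide over the evidence block only and record every exact token match.
        for idx in range(first_sep_idx + 1, second_sep_idx):
            if input_id_list[idx: idx + len(answer)] == answer:
                starts.append(idx)
                ends.append(idx + len(answer) - 1)
    return starts, ends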
......@@ -345,7 +345,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_embedder(*config_and_inputs)
self.model_tester.create_and_check_encoder(*config_and_inputs)
def test_retriever(self):
def test_scorer(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_scorer(*config_and_inputs)
......@@ -408,6 +408,13 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).reader_output.loss
loss.backward()
# Test model.block_embedding_to
device = torch.device("cpu")
model.block_embedding_to(device)
loss = model(**inputs).reader_output.loss
loss.backward()
self.assertEqual(model.block_emb.device.type, device.type)
@slow
def test_embedder_from_pretrained(self):
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
......@@ -506,10 +513,15 @@ class RealmModelIntegrationTest(unittest.TestCase):
concat_input_ids = torch.arange(10).view((2, 5))
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
concat_block_mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 1, 1, 0]], dtype=torch.int64)
relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)
output = model(
concat_input_ids, token_type_ids=concat_token_type_ids, relevance_score=relevance_score, return_dict=True
concat_input_ids,
token_type_ids=concat_token_type_ids,
relevance_score=relevance_score,
block_mask=concat_block_mask,
return_dict=True,
)
block_idx_expected_shape = torch.Size(())
......
......@@ -98,6 +98,7 @@ class RealmRetrieverTest(TestCase):
b"This is the third record",
b"This is the fourth record",
b"This is the fifth record",
b"This is a longer longer longer record",
],
dtype=np.object,
)
......@@ -135,6 +136,7 @@ class RealmRetrieverTest(TestCase):
self.assertEqual(concat_inputs.input_ids.shape, (2, 10))
self.assertEqual(concat_inputs.attention_mask.shape, (2, 10))
self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10))
self.assertEqual(concat_inputs.special_tokens_mask.shape, (2, 10))
self.assertEqual(
tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]),
["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"],
......@@ -149,10 +151,10 @@ class RealmRetrieverTest(TestCase):
retriever = self.get_dummy_retriever()
tokenizer = retriever.tokenizer
retrieved_block_ids = np.array([0, 3], dtype=np.long)
retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
question_input_ids = tokenizer(["Test question"]).input_ids
answer_ids = tokenizer(
["the fourth"],
["the fourth", "longer longer"],
add_special_tokens=False,
return_token_type_ids=False,
return_attention_mask=False,
......@@ -163,9 +165,9 @@ class RealmRetrieverTest(TestCase):
retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
)
self.assertEqual([False, True], has_answers)
self.assertEqual([[-1], [6]], start_pos)
self.assertEqual([[-1], [7]], end_pos)
self.assertEqual([False, True, True], has_answers)
self.assertEqual([[-1, -1, -1], [6, -1, -1], [6, 7, 8]], start_pos)
self.assertEqual([[-1, -1, -1], [7, -1, -1], [7, 8, 9]], end_pos)
def test_save_load_pretrained(self):
retriever = self.get_dummy_retriever()
......