Unverified commit 7b3bd1f2, authored by Li-Huai (Allan) Lin, committed by GitHub
Browse files

Fix and improve REALM fine-tuning (#15297)

* Draft

* Add test

* Update src/transformers/models/realm/modeling_realm.py

* Apply suggestion

* Add block_mask

* Update

* Update

* Add block_embedding_to

* Remove no_grad

* Use AutoTokenizer

* Remove model.to overriding
parent 439de3f7
...@@ -81,4 +81,5 @@ This model was contributed by [qqaatw](https://huggingface.co/qqaatw). The origi ...@@ -81,4 +81,5 @@ This model was contributed by [qqaatw](https://huggingface.co/qqaatw). The origi
## RealmForOpenQA ## RealmForOpenQA
[[autodoc]] RealmForOpenQA [[autodoc]] RealmForOpenQA
- block_embedding_to
- forward - forward
\ No newline at end of file
...@@ -48,6 +48,7 @@ else: ...@@ -48,6 +48,7 @@ else:
TOKENIZER_MAPPING_NAMES = OrderedDict( TOKENIZER_MAPPING_NAMES = OrderedDict(
[ [
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
......
...@@ -836,13 +836,13 @@ class RealmReaderProjection(nn.Module): ...@@ -836,13 +836,13 @@ class RealmReaderProjection(nn.Module):
self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps) self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps)
self.relu = nn.ReLU() self.relu = nn.ReLU()
def forward(self, hidden_states, token_type_ids): def forward(self, hidden_states, block_mask):
def span_candidates(masks): def span_candidates(masks):
""" """
Generate span candidates. Generate span candidates.
Args: Args:
masks: <int32> [num_retrievals, max_sequence_len] masks: <bool> [num_retrievals, max_sequence_len]
Returns: Returns:
starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans] starts: <int32> [num_spans] ends: <int32> [num_spans] span_masks: <int32> [num_retrievals, num_spans]
...@@ -875,8 +875,7 @@ class RealmReaderProjection(nn.Module): ...@@ -875,8 +875,7 @@ class RealmReaderProjection(nn.Module):
hidden_states = self.dense_intermediate(hidden_states) hidden_states = self.dense_intermediate(hidden_states)
# [reader_beam_size, max_sequence_len, span_hidden_size] # [reader_beam_size, max_sequence_len, span_hidden_size]
start_projection, end_projection = hidden_states.chunk(2, dim=-1) start_projection, end_projection = hidden_states.chunk(2, dim=-1)
block_mask = token_type_ids.detach().clone()
block_mask[:, -1] = 0
candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask) candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask)
candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts) candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts)
...@@ -1543,6 +1542,7 @@ class RealmReader(RealmPreTrainedModel): ...@@ -1543,6 +1542,7 @@ class RealmReader(RealmPreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
relevance_score=None, relevance_score=None,
block_mask=None,
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
has_answers=None, has_answers=None,
...@@ -1552,12 +1552,15 @@ class RealmReader(RealmPreTrainedModel): ...@@ -1552,12 +1552,15 @@ class RealmReader(RealmPreTrainedModel):
): ):
r""" r"""
relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*): relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
Relevance score, which must be specified if you want to compute the marginal log loss. Relevance score, which must be specified if you want to compute the logits and marginal log loss.
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*):
The mask of the evidence block, which must be specified if you want to compute the logits and marginal log
loss.
start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss. are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss. Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss. are not taken into account for computing the loss.
...@@ -1570,8 +1573,8 @@ class RealmReader(RealmPreTrainedModel): ...@@ -1570,8 +1573,8 @@ class RealmReader(RealmPreTrainedModel):
if relevance_score is None: if relevance_score is None:
raise ValueError("You have to specify `relevance_score` to calculate logits and loss.") raise ValueError("You have to specify `relevance_score` to calculate logits and loss.")
if token_type_ids is None: if block_mask is None:
raise ValueError("You have to specify `token_type_ids` to separate question block and evidence block.") raise ValueError("You have to specify `block_mask` to separate question block and evidence block.")
if token_type_ids.size(1) < self.config.max_span_width: if token_type_ids.size(1) < self.config.max_span_width:
raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.") raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.")
outputs = self.realm( outputs = self.realm(
...@@ -1590,7 +1593,9 @@ class RealmReader(RealmPreTrainedModel): ...@@ -1590,7 +1593,9 @@ class RealmReader(RealmPreTrainedModel):
sequence_output = outputs[0] sequence_output = outputs[0]
# [reader_beam_size, num_candidates], [num_candidates], [num_candidates] # [reader_beam_size, num_candidates], [num_candidates], [num_candidates]
reader_logits, candidate_starts, candidate_ends = self.qa_outputs(sequence_output, token_type_ids) reader_logits, candidate_starts, candidate_ends = self.qa_outputs(
sequence_output, block_mask[0 : self.config.reader_beam_size]
)
# [searcher_beam_size, 1] # [searcher_beam_size, 1]
retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1) retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1)
# [reader_beam_size, num_candidates] # [reader_beam_size, num_candidates]
...@@ -1737,11 +1742,21 @@ class RealmForOpenQA(RealmPreTrainedModel): ...@@ -1737,11 +1742,21 @@ class RealmForOpenQA(RealmPreTrainedModel):
self.post_init() self.post_init()
@property @property
def beam_size(self): def searcher_beam_size(self):
if self.training: if self.training:
return self.config.searcher_beam_size return self.config.searcher_beam_size
return self.config.reader_beam_size return self.config.reader_beam_size
def block_embedding_to(self, device):
"""Send `self.block_emb` to a specific device.
Args:
device (`str` or `torch.device`):
The device to which `self.block_emb` will be sent.
"""
self.block_emb = self.block_emb.to(device)
@add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length")) @add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
@replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
...@@ -1787,36 +1802,37 @@ class RealmForOpenQA(RealmPreTrainedModel): ...@@ -1787,36 +1802,37 @@ class RealmForOpenQA(RealmPreTrainedModel):
question_outputs = self.embedder( question_outputs = self.embedder(
input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True
) )
# [1, projection_size] # [1, projection_size]
question_projection = question_outputs[0] question_projection = question_outputs[0]
# CPU computation starts.
# [1, block_emb_size] # [1, block_emb_size]
batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection) batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection.to(self.block_emb.device))
# [1, searcher_beam_size] # [1, searcher_beam_size]
_, retrieved_block_ids = torch.topk(batch_scores, k=self.beam_size, dim=-1) _, retrieved_block_ids = torch.topk(batch_scores, k=self.searcher_beam_size, dim=-1)
# [searcher_beam_size] # [searcher_beam_size]
# Must convert to cpu tensor for subsequent numpy operations retrieved_block_ids = retrieved_block_ids.squeeze()
retrieved_block_ids = retrieved_block_ids.squeeze().cpu() # [searcher_beam_size, projection_size]
retrieved_block_emb = torch.index_select(self.block_emb, dim=0, index=retrieved_block_ids)
# CPU computation ends.
# Retrieve possible answers # Retrieve possible answers
has_answers, start_pos, end_pos, concat_inputs = self.retriever( has_answers, start_pos, end_pos, concat_inputs = self.retriever(
retrieved_block_ids, input_ids, answer_ids, max_length=self.config.reader_seq_len retrieved_block_ids.cpu(), input_ids, answer_ids, max_length=self.config.reader_seq_len
) )
concat_inputs = concat_inputs.to(self.reader.device)
block_mask = concat_inputs.special_tokens_mask.type(torch.bool).to(device=self.reader.device)
block_mask.logical_not_().logical_and_(concat_inputs.token_type_ids.type(torch.bool))
if has_answers is not None: if has_answers is not None:
has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device) has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device)
start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device) start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device)
end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device) end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device)
concat_inputs = concat_inputs.to(self.reader.device)
# [searcher_beam_size, projection_size]
retrieved_block_emb = torch.index_select(
self.block_emb, dim=0, index=retrieved_block_ids.to(self.block_emb.device)
)
# [searcher_beam_size] # [searcher_beam_size]
retrieved_logits = torch.einsum( retrieved_logits = torch.einsum(
"D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(question_projection.device) "D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(self.reader.device)
) )
reader_output = self.reader( reader_output = self.reader(
...@@ -1824,6 +1840,7 @@ class RealmForOpenQA(RealmPreTrainedModel): ...@@ -1824,6 +1840,7 @@ class RealmForOpenQA(RealmPreTrainedModel):
attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size], attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size],
token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size], token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size],
relevance_score=retrieved_logits, relevance_score=retrieved_logits,
block_mask=block_mask,
has_answers=has_answers, has_answers=has_answers,
start_positions=start_pos, start_positions=start_pos,
end_positions=end_pos, end_positions=end_pos,
......
...@@ -20,9 +20,9 @@ from typing import Optional, Union ...@@ -20,9 +20,9 @@ from typing import Optional, Union
import numpy as np import numpy as np
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from ...utils import logging from ...utils import logging
from .tokenization_realm import RealmTokenizer
_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy" _REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"
...@@ -97,7 +97,9 @@ class RealmRetriever: ...@@ -97,7 +97,9 @@ class RealmRetriever:
text.append(question) text.append(question)
text_pair.append(retrieved_block.decode()) text_pair.append(retrieved_block.decode())
concat_inputs = self.tokenizer(text, text_pair, padding=True, truncation=True, max_length=max_length) concat_inputs = self.tokenizer(
text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length
)
concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors) concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)
if answer_ids is not None: if answer_ids is not None:
...@@ -115,7 +117,7 @@ class RealmRetriever: ...@@ -115,7 +117,7 @@ class RealmRetriever:
) )
block_records = np.load(block_records_path, allow_pickle=True) block_records = np.load(block_records_path, allow_pickle=True)
tokenizer = RealmTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
return cls(block_records, tokenizer) return cls(block_records, tokenizer)
...@@ -133,13 +135,15 @@ class RealmRetriever: ...@@ -133,13 +135,15 @@ class RealmRetriever:
max_answers = 0 max_answers = 0
for input_id in concat_inputs.input_ids: for input_id in concat_inputs.input_ids:
input_id_list = input_id.tolist()
# Check answers between two [SEP] tokens
first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1 :].index(self.tokenizer.sep_token_id)
start_pos.append([]) start_pos.append([])
end_pos.append([]) end_pos.append([])
input_id_list = input_id.tolist()
# Checking answers after the [SEP] token
sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
for answer in answer_ids: for answer in answer_ids:
for idx in range(sep_idx, len(input_id)): for idx in range(first_sep_idx + 1, second_sep_idx):
if answer[0] == input_id_list[idx]: if answer[0] == input_id_list[idx]:
if input_id_list[idx : idx + len(answer)] == answer: if input_id_list[idx : idx + len(answer)] == answer:
start_pos[-1].append(idx) start_pos[-1].append(idx)
...@@ -158,5 +162,4 @@ class RealmRetriever: ...@@ -158,5 +162,4 @@ class RealmRetriever:
padded = [-1] * (max_answers - len(start_pos_)) padded = [-1] * (max_answers - len(start_pos_))
start_pos_ += padded start_pos_ += padded
end_pos_ += padded end_pos_ += padded
return has_answers, start_pos, end_pos return has_answers, start_pos, end_pos
...@@ -345,7 +345,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -345,7 +345,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_embedder(*config_and_inputs) self.model_tester.create_and_check_embedder(*config_and_inputs)
self.model_tester.create_and_check_encoder(*config_and_inputs) self.model_tester.create_and_check_encoder(*config_and_inputs)
def test_retriever(self): def test_scorer(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_scorer(*config_and_inputs) self.model_tester.create_and_check_scorer(*config_and_inputs)
...@@ -408,6 +408,13 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -408,6 +408,13 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).reader_output.loss loss = model(**inputs).reader_output.loss
loss.backward() loss.backward()
# Test model.block_embedding_to
device = torch.device("cpu")
model.block_embedding_to(device)
loss = model(**inputs).reader_output.loss
loss.backward()
self.assertEqual(model.block_emb.device.type, device.type)
@slow @slow
def test_embedder_from_pretrained(self): def test_embedder_from_pretrained(self):
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder") model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
...@@ -506,10 +513,15 @@ class RealmModelIntegrationTest(unittest.TestCase): ...@@ -506,10 +513,15 @@ class RealmModelIntegrationTest(unittest.TestCase):
concat_input_ids = torch.arange(10).view((2, 5)) concat_input_ids = torch.arange(10).view((2, 5))
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64) concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
concat_block_mask = torch.tensor([[0, 0, 1, 1, 0], [0, 0, 1, 1, 0]], dtype=torch.int64)
relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32) relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)
output = model( output = model(
concat_input_ids, token_type_ids=concat_token_type_ids, relevance_score=relevance_score, return_dict=True concat_input_ids,
token_type_ids=concat_token_type_ids,
relevance_score=relevance_score,
block_mask=concat_block_mask,
return_dict=True,
) )
block_idx_expected_shape = torch.Size(()) block_idx_expected_shape = torch.Size(())
......
...@@ -98,6 +98,7 @@ class RealmRetrieverTest(TestCase): ...@@ -98,6 +98,7 @@ class RealmRetrieverTest(TestCase):
b"This is the third record", b"This is the third record",
b"This is the fourth record", b"This is the fourth record",
b"This is the fifth record", b"This is the fifth record",
b"This is a longer longer longer record",
], ],
dtype=np.object, dtype=np.object,
) )
...@@ -135,6 +136,7 @@ class RealmRetrieverTest(TestCase): ...@@ -135,6 +136,7 @@ class RealmRetrieverTest(TestCase):
self.assertEqual(concat_inputs.input_ids.shape, (2, 10)) self.assertEqual(concat_inputs.input_ids.shape, (2, 10))
self.assertEqual(concat_inputs.attention_mask.shape, (2, 10)) self.assertEqual(concat_inputs.attention_mask.shape, (2, 10))
self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10)) self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10))
self.assertEqual(concat_inputs.special_tokens_mask.shape, (2, 10))
self.assertEqual( self.assertEqual(
tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]), tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]),
["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"], ["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"],
...@@ -149,10 +151,10 @@ class RealmRetrieverTest(TestCase): ...@@ -149,10 +151,10 @@ class RealmRetrieverTest(TestCase):
retriever = self.get_dummy_retriever() retriever = self.get_dummy_retriever()
tokenizer = retriever.tokenizer tokenizer = retriever.tokenizer
retrieved_block_ids = np.array([0, 3], dtype=np.long) retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
question_input_ids = tokenizer(["Test question"]).input_ids question_input_ids = tokenizer(["Test question"]).input_ids
answer_ids = tokenizer( answer_ids = tokenizer(
["the fourth"], ["the fourth", "longer longer"],
add_special_tokens=False, add_special_tokens=False,
return_token_type_ids=False, return_token_type_ids=False,
return_attention_mask=False, return_attention_mask=False,
...@@ -163,9 +165,9 @@ class RealmRetrieverTest(TestCase): ...@@ -163,9 +165,9 @@ class RealmRetrieverTest(TestCase):
retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np" retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
) )
self.assertEqual([False, True], has_answers) self.assertEqual([False, True, True], has_answers)
self.assertEqual([[-1], [6]], start_pos) self.assertEqual([[-1, -1, -1], [6, -1, -1], [6, 7, 8]], start_pos)
self.assertEqual([[-1], [7]], end_pos) self.assertEqual([[-1, -1, -1], [7, -1, -1], [7, 8, 9]], end_pos)
def test_save_load_pretrained(self): def test_save_load_pretrained(self):
retriever = self.get_dummy_retriever() retriever = self.get_dummy_retriever()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment