Unverified Commit 9d37c56b authored by Patrick von Platen, committed by GitHub

[Reformer] - Cache hidden states and buckets to speed up inference (#5578)

* fix merge rebase

* add intermediate reformer code

* save intermediate caching results

* save intermediate

* save intermediate results

* save intermediate

* upload next step

* fix generate tests

* make tests work

* add named tuple output

* Apply suggestions from code review

* fix use_cache for False case

* fix tensor to gpu

* fix tensor to gpu

* refactor

* refactor and make style
parent 0b6c255a
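
In short: with this change, Reformer can return and consume `past_buckets_states` (cached hidden states plus LSH buckets) so that `generate()` only runs the forward pass for the newest token instead of re-encoding the whole prefix at every step. A minimal usage sketch, assuming the `google/reformer-crime-and-punishment` checkpoint exercised by the integration tests below:

```python
import torch
from transformers import ReformerModelWithLMHead, ReformerTokenizer

# Pretrained checkpoint used by the integration tests in this commit.
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model.eval()

input_ids = tokenizer.encode("A few months later", return_tensors="pt")

# use_cache=True enables the new caching path: hidden states and buckets of
# already-generated positions are reused instead of being recomputed.
with torch.no_grad():
    output_ids = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=True)
print(tokenizer.decode(output_ids[0]))
```

The `test_pretrained_generate_use_cache_equality` test added below checks that the cached and uncached paths produce identical generations.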
@@ -600,7 +600,7 @@ class XLNetModelOutput(ModelOutput):
@dataclass
class XLNetLMHeadModelOutput(ModelOutput):
"""
-Output type of :class:`~transformers.XLNetModel`.
+Output type of :class:`~transformers.XLNetLMHeadModel`.
Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
@@ -637,7 +637,7 @@ class XLNetLMHeadModelOutput(ModelOutput):
@dataclass
class XLNetForSequenceClassificationOutput(ModelOutput):
"""
-Base class for outputs of sentence classification models.
+Output type of :class:`~transformers.XLNetForSequenceClassification`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
@@ -671,7 +671,7 @@ class XLNetForSequenceClassificationOutput(ModelOutput):
@dataclass
class XLNetForTokenClassificationOutput(ModelOutput):
"""
-Base class for outputs of token classification models.
+Output type of :class:`~transformers.XLNetForTokenClassification`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
@@ -181,8 +181,8 @@ class ReformerModelTester:
model = ReformerModel(config=config)
model.to(torch_device)
model.eval()
-(sequence_output,) = model(input_ids, attention_mask=input_mask)
-(sequence_output,) = model(input_ids)
+sequence_output, _ = model(input_ids, attention_mask=input_mask)
+sequence_output, _ = model(input_ids)
result = {
"sequence_output": sequence_output,
@@ -193,17 +193,21 @@ class ReformerModelTester:
)
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
-model = ReformerModelWithLMHead(config=config)
+config.is_decoder = False
+config.lsh_num_chunks_after = 1
+model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
loss.backward()
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
config.lsh_num_chunks_after = 0
config.is_decoder = True
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.eval()
-loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
+loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
@@ -332,9 +336,11 @@ class ReformerModelTester:
config.hidden_dropout_prob = 0
config.local_attention_probs_dropout_prob = 0
config.lsh_attention_probs_dropout_prob = 0
+config.lsh_num_chunks_after = 1
+config.is_decoder = False
torch.manual_seed(0)
-model = ReformerModelWithLMHead(config=config)
+model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.train()
model.zero_grad()
@@ -348,7 +354,7 @@ class ReformerModelTester:
config.chunk_size_feed_forward = 1
torch.manual_seed(0)
-model = ReformerModelWithLMHead(config=config)
+model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.train()
model.zero_grad()
@@ -405,7 +411,22 @@ class ReformerModelTester:
output = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertFalse(torch.isnan(output).any().item())
+def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
+config.is_decoder = True
+config.lsh_num_chunks_after = 0
+config.bos_token_id = 0
+config.eos_token_id = None
+config.max_length = 20
+model = ReformerModelWithLMHead(config=config)
+model.to(torch_device)
+model.eval()
+output = model.generate()
+self.parent.assertIsNotNone(output)
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
+config.is_decoder = True
+config.lsh_num_chunks_after = 0
model = ReformerModelWithLMHead(config=config)
model.to(torch_device)
model.half()
@@ -418,13 +439,15 @@ class ReformerModelTester:
# force chunk length to be bigger than input_ids
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
-model = ReformerModelWithLMHead(config=config)
+config.lsh_num_chunks_after = 1
+config.is_decoder = False
+model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
output_logits = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
-def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
+def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
model = ReformerForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
@@ -440,6 +463,33 @@ class ReformerModelTester:
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
+def create_and_check_past_buckets_states(self, config, input_ids, input_mask, choice_labels):
+config.is_decoder = True
+config.lsh_num_chunks_before = 1
+config.lsh_num_chunks_after = 0
+model = ReformerModelWithLMHead(config=config)
+model.to(torch_device)
+model.eval()
+input_ids_first = input_ids[:, :-1]
+input_ids_second = input_ids[:, -1:]
+# return saved cache
+_, past_buckets_states = model(input_ids_first, use_cache=True)
+# calculate last output with and without cache
+outputs_with_cache, _ = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)
+outputs_without_cache = model(input_ids)[0][:, -1]
+# select random slice idx
+random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item()
+# outputs should be similar within range
+self.parent.assertTrue(
+torch.allclose(
+outputs_with_cache[:, 0, random_slice_idx], outputs_without_cache[:, random_slice_idx], atol=1e-2
+)
+)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, choice_labels) = config_and_inputs
@@ -509,6 +559,18 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
+def test_reformer_qa_answering(self):
+config_and_inputs = self.model_tester.prepare_config_and_inputs()
+self.model_tester.create_and_check_reformer_for_question_answering(*config_and_inputs)
+def test_reformer_cached_inference(self):
+config_and_inputs = self.model_tester.prepare_config_and_inputs()
+self.model_tester.create_and_check_past_buckets_states(*config_and_inputs)
+def test_reformer_cached_generate(self):
+config_and_inputs = self.model_tester.prepare_config_and_inputs()
+self.model_tester.create_and_check_reformer_model_generate(*config_and_inputs)
@slow
def test_dropout_random_seed_is_changing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -621,8 +683,8 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
"num_buckets": 2,
"num_hashes": 4,
"lsh_attn_chunk_length": 4,
"lsh_num_chunks_before": 2,
"lsh_num_chunks_after": 3,
"lsh_num_chunks_before": 1,
"lsh_num_chunks_after": 0,
"chunk_size_lm_head": 5,
"chunk_size_feed_forward": 6,
"feed_forward_size": 32,
@@ -636,7 +698,9 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
"axial_pos_embds": True,
"axial_pos_shape": [4, 8],
"axial_pos_embds_dim": [16, 48],
"attn_layers": ["lsh", "lsh", "lsh", "lsh"],
# sanotheu
# "attn_layers": ["lsh", "lsh", "lsh", "lsh"],
"attn_layers": ["lsh"],
"pad_token_id": 0,
"eos_token_id": 2,
"scope": None,
@@ -1049,8 +1113,23 @@ class ReformerIntegrationTests(unittest.TestCase):
output_ids = model.generate(
input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
)
-output_text = tokenizer.decode(output_ids[0])
+output = tokenizer.decode(output_ids[0])
self.assertEqual(
-output_text,
+output,
"A few months later state expression in his ideas, at the first entrance. He was positively for an inst",
)
+@slow
+def test_pretrained_generate_use_cache_equality(self):
+model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device)
+tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
+model.eval()
+input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device)
+output_ids_with_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=True)
+output_ids_without_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=False)
+output_with_cache = tokenizer.decode(output_ids_with_cache[0])
+output_without_cache = tokenizer.decode(output_ids_without_cache[0])
+self.assertEqual(output_with_cache, output_without_cache)
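
For reference, the contract that `create_and_check_past_buckets_states` verifies can be reproduced outside the test harness. A minimal sketch, assuming only the forward signature shown in the diff above (`use_cache=True` returns a `past_buckets_states` cache that can be fed back in); the checkpoint and tolerance come from the tests, everything else is illustrative:

```python
import torch
from transformers import ReformerModelWithLMHead, ReformerTokenizer

model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model.eval()

input_ids = tokenizer.encode("A few months later", return_tensors="pt")

with torch.no_grad():
    # Encode all but the last token once, keeping the cache of
    # hidden states and LSH buckets.
    _, past_buckets_states = model(input_ids[:, :-1], use_cache=True)

    # Feed only the final token, reusing the cache for the prefix.
    logits_cached = model(
        input_ids[:, -1:], past_buckets_states=past_buckets_states, use_cache=True
    )[0]

    # Same last-position logits computed without any caching.
    logits_full = model(input_ids)[0][:, -1:]

# Both paths should agree up to numerical tolerance (the test uses atol=1e-2).
print(torch.allclose(logits_cached, logits_full, atol=1e-2))
```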