Unverified Commit 9d37c56b authored by Patrick von Platen, committed by GitHub

[Reformer] - Cache hidden states and buckets to speed up inference (#5578)

* fix merge rebase

* add intermediate reformer code

* save intermediate caching results

* save intermediate

* save intermediate results

* save intermediate

* upload next step

* fix generate tests

* make tests work

* add named tuple output

* Apply suggestions from code review

* fix use_cache for False case

* fix tensor to gpu

* fix tensor to gpu

* refactor

* refactor and make style
parent 0b6c255a
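
The user-facing effect of this change is that Reformer generation can reuse the cached hidden states and LSH buckets of previously generated tokens instead of recomputing attention over the full prefix at every decoding step. A minimal usage sketch, assuming the pretrained checkpoint and the use_cache argument exercised by the slow test added below:

import torch
from transformers import ReformerModelWithLMHead, ReformerTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(device)
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model.eval()

input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(device)

# with use_cache=True, hidden states and buckets of already generated tokens
# are reused at every decoding step instead of being recomputed
output_ids = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=True)
print(tokenizer.decode(output_ids[0]))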
@@ -600,7 +600,7 @@ class XLNetModelOutput(ModelOutput):
 @dataclass
 class XLNetLMHeadModelOutput(ModelOutput):
     """
-    Output type of :class:`~transformers.XLNetModel`.
+    Output type of :class:`~transformers.XLNetLMHeadModel`.
     Args:
         loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
@@ -637,7 +637,7 @@ class XLNetLMHeadModelOutput(ModelOutput):
 @dataclass
 class XLNetForSequenceClassificationOutput(ModelOutput):
     """
-    Base class for outputs of sentence classification models.
+    Output type of :class:`~transformers.XLNetForSequenceClassification`.
     Args:
         loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
@@ -671,7 +671,7 @@ class XLNetForSequenceClassificationOutput(ModelOutput):
 @dataclass
 class XLNetForTokenClassificationOutput(ModelOutput):
     """
-    Base class for outputs of token classification models.
+    Output type of :class:`~transformers.XLNetForTokenClassification`.
     Args:
         loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
@@ -181,8 +181,8 @@ class ReformerModelTester:
         model = ReformerModel(config=config)
         model.to(torch_device)
         model.eval()
-        (sequence_output,) = model(input_ids, attention_mask=input_mask)
-        (sequence_output,) = model(input_ids)
+        sequence_output, _ = model(input_ids, attention_mask=input_mask)
+        sequence_output, _ = model(input_ids)
         result = {
             "sequence_output": sequence_output,
@@ -193,17 +193,21 @@ class ReformerModelTester:
         )

     def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
-        model = ReformerModelWithLMHead(config=config)
+        config.is_decoder = False
+        config.lsh_num_chunks_after = 1
+        model = ReformerForMaskedLM(config=config)
         model.to(torch_device)
         model.eval()
         loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
         loss.backward()

     def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
+        config.lsh_num_chunks_after = 0
+        config.is_decoder = True
         model = ReformerModelWithLMHead(config=config)
         model.to(torch_device)
         model.eval()
-        loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
+        loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids)
         result = {
             "loss": loss,
             "prediction_scores": prediction_scores,
@@ -332,9 +336,11 @@ class ReformerModelTester:
         config.hidden_dropout_prob = 0
         config.local_attention_probs_dropout_prob = 0
         config.lsh_attention_probs_dropout_prob = 0
+        config.lsh_num_chunks_after = 1
+        config.is_decoder = False
         torch.manual_seed(0)
-        model = ReformerModelWithLMHead(config=config)
+        model = ReformerForMaskedLM(config=config)
         model.to(torch_device)
         model.train()
         model.zero_grad()
@@ -348,7 +354,7 @@ class ReformerModelTester:
         config.chunk_size_feed_forward = 1
         torch.manual_seed(0)
-        model = ReformerModelWithLMHead(config=config)
+        model = ReformerForMaskedLM(config=config)
         model.to(torch_device)
         model.train()
         model.zero_grad()
@@ -405,7 +411,22 @@ class ReformerModelTester:
         output = model(input_ids, attention_mask=input_mask)[0]
         self.parent.assertFalse(torch.isnan(output).any().item())

+    def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
+        config.is_decoder = True
+        config.lsh_num_chunks_after = 0
+        config.bos_token_id = 0
+        config.eos_token_id = None
+        config.max_length = 20
+
+        model = ReformerModelWithLMHead(config=config)
+        model.to(torch_device)
+        model.eval()
+        output = model.generate()
+        self.parent.assertIsNotNone(output)
+
     def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
+        config.is_decoder = True
+        config.lsh_num_chunks_after = 0
         model = ReformerModelWithLMHead(config=config)
         model.to(torch_device)
         model.half()
@@ -418,13 +439,15 @@ class ReformerModelTester:
         # force chunk length to be bigger than input_ids
         config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
         config.local_attn_chunk_length = 2 * input_ids.shape[-1]
-        model = ReformerModelWithLMHead(config=config)
+        config.lsh_num_chunks_after = 1
+        config.is_decoder = False
+        model = ReformerForMaskedLM(config=config)
         model.to(torch_device)
         model.eval()
         output_logits = model(input_ids, attention_mask=input_mask)[0]
         self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])

-    def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
+    def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
         model = ReformerForQuestionAnswering(config=config)
         model.to(torch_device)
         model.eval()
@@ -440,6 +463,33 @@ class ReformerModelTester:
         self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
         self.check_loss_output(result)

+    def create_and_check_past_buckets_states(self, config, input_ids, input_mask, choice_labels):
+        config.is_decoder = True
+        config.lsh_num_chunks_before = 1
+        config.lsh_num_chunks_after = 0
+        model = ReformerModelWithLMHead(config=config)
+        model.to(torch_device)
+        model.eval()
+        input_ids_first = input_ids[:, :-1]
+        input_ids_second = input_ids[:, -1:]
+
+        # return saved cache
+        _, past_buckets_states = model(input_ids_first, use_cache=True)
+
+        # calculate last output with and without cache
+        outputs_with_cache, _ = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)
+        outputs_without_cache = model(input_ids)[0][:, -1]
+
+        # select random slice idx
+        random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item()
+
+        # outputs should be similar within range
+        self.parent.assertTrue(
+            torch.allclose(
+                outputs_with_cache[:, 0, random_slice_idx], outputs_without_cache[:, random_slice_idx], atol=1e-2
+            )
+        )
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (config, input_ids, input_mask, choice_labels) = config_and_inputs
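
The new create_and_check_past_buckets_states test above shows the mechanics of the cache: a forward pass with use_cache=True returns a past_buckets_states object that can be passed back in together with only the newest token. A minimal sketch of the same flow outside the test harness, assuming the pretrained checkpoint used by the integration tests and the two-value return format shown above:

import torch
from transformers import ReformerModelWithLMHead, ReformerTokenizer

model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment")
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model.eval()

input_ids = tokenizer.encode("A few months later", return_tensors="pt")

with torch.no_grad():
    # run everything but the last token once and keep the cache
    _, past_buckets_states = model(input_ids[:, :-1], use_cache=True)
    # feed only the last token together with the cache
    logits, past_buckets_states = model(input_ids[:, -1:], past_buckets_states=past_buckets_states, use_cache=True)

next_token_id = logits[:, -1].argmax(dim=-1)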
@@ -509,6 +559,18 @@ class ReformerTesterMixin:
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)

+    def test_reformer_qa_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_reformer_for_question_answering(*config_and_inputs)
+
+    def test_reformer_cached_inference(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_past_buckets_states(*config_and_inputs)
+
+    def test_reformer_cached_generate(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_reformer_model_generate(*config_and_inputs)
+
     @slow
     def test_dropout_random_seed_is_changing(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -621,8 +683,8 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
             "num_buckets": 2,
             "num_hashes": 4,
             "lsh_attn_chunk_length": 4,
-            "lsh_num_chunks_before": 2,
-            "lsh_num_chunks_after": 3,
+            "lsh_num_chunks_before": 1,
+            "lsh_num_chunks_after": 0,
             "chunk_size_lm_head": 5,
             "chunk_size_feed_forward": 6,
             "feed_forward_size": 32,
@@ -636,7 +698,9 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
             "axial_pos_embds": True,
             "axial_pos_shape": [4, 8],
             "axial_pos_embds_dim": [16, 48],
-            "attn_layers": ["lsh", "lsh", "lsh", "lsh"],
+            # "attn_layers": ["lsh", "lsh", "lsh", "lsh"],
+            "attn_layers": ["lsh"],
             "pad_token_id": 0,
             "eos_token_id": 2,
             "scope": None,
@@ -1049,8 +1113,23 @@ class ReformerIntegrationTests(unittest.TestCase):
         output_ids = model.generate(
             input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
         )
-        output_text = tokenizer.decode(output_ids[0])
+        output = tokenizer.decode(output_ids[0])
         self.assertEqual(
-            output_text,
+            output,
             "A few months later state expression in his ideas, at the first entrance. He was positively for an inst",
         )
+
+    @slow
+    def test_pretrained_generate_use_cache_equality(self):
+        model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device)
+        tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
+        model.eval()
+        input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device)
+        output_ids_with_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=True)
+        output_ids_without_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=False)
+
+        output_with_cache = tokenizer.decode(output_ids_with_cache[0])
+        output_without_cache = tokenizer.decode(output_ids_without_cache[0])
+        self.assertEqual(output_with_cache, output_without_cache)
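
To see the advertised inference speed-up, the same generation call can be timed with and without the cache; per the equality test above, both settings should produce identical text. A rough timing sketch (not part of the commit; absolute numbers depend on hardware):

import time

import torch
from transformers import ReformerModelWithLMHead, ReformerTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(device)
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model.eval()

input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(device)

for use_cache in (False, True):
    start = time.time()
    model.generate(input_ids, max_length=130, num_hashes=8, use_cache=use_cache)
    print(f"use_cache={use_cache}: {time.time() - start:.1f}s")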