Unverified Commit a0e94bf1 authored by Geewook Kim's avatar Geewook Kim Committed by GitHub
Browse files

Merge pull request #165 from dotneet/fix/past_key_values

Supports the latest transformers releases by accepting `past_key_values` in `prepare_inputs_for_inference`, while keeping the deprecated `past` argument for compatibility with transformers==4.11.x.
parents 217cffb1 48479fe8
......@@ -206,7 +206,7 @@ class BARTDecoder(nn.Module):
if newly_added_num > 0:
self.model.resize_token_embeddings(len(self.tokenizer))
def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None):
def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past_key_values=None, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None):
"""
Args:
input_ids: (batch_size, sequence_lenth)
......@@ -215,13 +215,16 @@ class BARTDecoder(nn.Module):
attention_mask: (batch_size, sequence_length)
encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
"""
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
# for compatibility with transformers==4.11.x
if past is not None:
past_key_values = past
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
if past_key_values is not None:
input_ids = input_ids[:, -1:]
output = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"past_key_values": past_key_values,
"use_cache": use_cache,
"encoder_hidden_states": encoder_outputs.last_hidden_state,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment