Unverified Commit 5ff2b3a5 authored by Geewook Kim's avatar Geewook Kim Committed by GitHub
Browse files

Merge pull request #56 from SamSamhuns/fix_for_new_transformers_lib_ver

Change model_kwargs argument to encoder_outputs to support transformers>=4.22.1
parents 362f844b 68b30a37
...@@ -206,7 +206,7 @@ class BARTDecoder(nn.Module): ...@@ -206,7 +206,7 @@ class BARTDecoder(nn.Module):
if newly_added_num > 0: if newly_added_num > 0:
self.model.resize_token_embeddings(len(self.tokenizer)) self.model.resize_token_embeddings(len(self.tokenizer))
def prepare_inputs_for_inference(self, input_ids: torch.Tensor, past=None, use_cache: bool = None, **model_kwargs): def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None):
""" """
Args: Args:
input_ids: (batch_size, sequence_lenth) input_ids: (batch_size, sequence_lenth)
...@@ -223,7 +223,7 @@ class BARTDecoder(nn.Module): ...@@ -223,7 +223,7 @@ class BARTDecoder(nn.Module):
"attention_mask": attention_mask, "attention_mask": attention_mask,
"past_key_values": past, "past_key_values": past,
"use_cache": use_cache, "use_cache": use_cache,
"encoder_hidden_states": model_kwargs["encoder_outputs"].last_hidden_state, "encoder_hidden_states": encoder_outputs.last_hidden_state,
} }
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment