Unverified Commit 1a5aefc9 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[Seq2Seq Generation] Call encoder before expanding input_ids (#3370)

parent 39371ee4
...@@ -113,6 +113,7 @@ class PretrainedBartModel(PreTrainedModel): ...@@ -113,6 +113,7 @@ class PretrainedBartModel(PreTrainedModel):
config_class = BartConfig config_class = BartConfig
base_model_prefix = "model" base_model_prefix = "model"
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
encoder_outputs_batch_dim_idx = 1 # outputs shaped (seq_len, bs, ...)
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std
...@@ -888,7 +889,6 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -888,7 +889,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_outputs, decoder_cached_states = past, None encoder_outputs, decoder_cached_states = past, None
else: else:
encoder_outputs, decoder_cached_states = past encoder_outputs, decoder_cached_states = past
return { return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed "input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs, "encoder_outputs": encoder_outputs,
......
...@@ -457,6 +457,7 @@ class T5PreTrainedModel(PreTrainedModel): ...@@ -457,6 +457,7 @@ class T5PreTrainedModel(PreTrainedModel):
pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_t5 load_tf_weights = load_tf_weights_in_t5
base_model_prefix = "transformer" base_model_prefix = "transformer"
encoder_outputs_batch_dim_idx = 0 # outputs shaped (bs, ...)
@property @property
def dummy_inputs(self): def dummy_inputs(self):
......
...@@ -895,6 +895,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -895,6 +895,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
effective_batch_size = batch_size effective_batch_size = batch_size
effective_batch_mult = 1 effective_batch_mult = 1
if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id
assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
# Expand input ids if num_beams > 1 or num_return_sequences > 1 # Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1: if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1] input_ids_len = input_ids.shape[-1]
...@@ -911,20 +926,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -911,20 +926,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
) # shape: (batch_size * num_return_sequences * num_beams, cur_len) ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id
assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
# create empty decoder_input_ids # create empty decoder_input_ids
input_ids = torch.full( input_ids = torch.full(
(effective_batch_size * num_beams, 1), (effective_batch_size * num_beams, 1),
...@@ -933,6 +934,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -933,6 +934,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
device=next(self.parameters()).device, device=next(self.parameters()).device,
) )
cur_len = 1 cur_len = 1
batch_idx = self.encoder_outputs_batch_dim_idx
assert (
batch_size == encoder_outputs[0].shape[batch_idx]
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} "
expanded_idx = (
torch.arange(batch_size)
.view(-1, 1)
.repeat(1, num_beams * effective_batch_mult)
.view(-1)
.to(input_ids.device)
)
encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
else: else:
encoder_outputs = None encoder_outputs = None
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment