Unverified Commit 390c1285 authored by Patrick von Platen, committed by GitHub

[Encoder-Decoder] Force models outputs to always have batch_size as their first dim (#3536)

* solve conflicts

* improve comments
parent ab5d06a0
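
In short: before this commit, Bart's encoder and decoder returned time-first tensors of shape (seq_len, batch_size, dim) while T5 returned batch-first tensors, so generate() had to consult a per-model encoder_outputs_batch_dim_idx attribute to know which dimension was the batch. The diff removes that attribute and transposes Bart's outputs so that every model returns batch_size as its first dimension. A minimal sketch of the convention being enforced (illustrative only, not code from the commit):

```python
import torch

batch_size, seq_len, model_dim = 2, 5, 8
x = torch.randn(batch_size, seq_len, model_dim)  # inputs arrive batch-first

# Bart's layers still compute in time-first layout internally ...
x_internal = x.transpose(0, 1)  # (seq_len, batch_size, model_dim)

# ... but every tensor is transposed back before forward() returns,
# so callers such as generate() can always treat dim 0 as the batch.
x_out = x_internal.transpose(0, 1)
assert x_out.shape[0] == batch_size
```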
@@ -116,7 +116,6 @@ class PretrainedBartModel(PreTrainedModel):
     config_class = BartConfig
     base_model_prefix = "model"
     pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
-    encoder_outputs_batch_dim_idx = 1  # outputs shaped (seq_len, bs, ...)
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -294,7 +293,10 @@ class BartEncoder(nn.Module):
         if self.output_hidden_states:
             encoder_states.append(x)
 
+        # T x B x C -> B x T x C
         encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
+        x = x.transpose(0, 1)
+
         return x, encoder_states, all_attentions
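
Each entry of encoder_states is a per-layer hidden-state snapshot, so those are transposed as well; with this change the final x joins them in batch-first layout. A quick shape check under the same assumed dimensions as above:

```python
import torch

# assumed: 2-layer encoder, batch_size=2, seq_len=5, model_dim=8
encoder_states = [torch.randn(5, 2, 8) for _ in range(2)]     # T x B x C, as collected in the layer loop
encoder_states = [h.transpose(0, 1) for h in encoder_states]  # B x T x C, as returned
assert all(h.shape == (2, 5, 8) for h in encoder_states)
```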
@@ -448,7 +450,11 @@ class BartDecoder(nn.Module):
         x = self.layernorm_embedding(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
-        x = x.transpose(0, 1)  # (seq_len, BS, model_dim)
+
+        # Convert to the time-first format used inside the decoder layers: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
+        x = x.transpose(0, 1)
+        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
 
         # decoder layers
         all_hidden_states = ()
         all_self_attns = ()
@@ -477,9 +483,10 @@ class BartDecoder(nn.Module):
             if self.output_attentions:
                 all_self_attns += (layer_self_attn,)
 
-        # Convert shapes from (seq_len, BS, model_dim) to (BS, seq_len, model_dim)
+        # Convert back to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
         all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
         x = x.transpose(0, 1)
+        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
 
         if self.output_past:
             next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
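
Note that encoder_hidden_states is transposed back to batch-first before being packed into next_cache, so the encoder output stored in past is batch-first too; that is what allows _reorder_cache below to switch from dim 1 to dim 0. A rough sketch of the cache layout (shapes assumed, not taken from the diff):

```python
import torch

batch_size, src_len, model_dim = 2, 5, 8
encoder_hidden_states = torch.randn(batch_size, src_len, model_dim)  # batch-first after the transpose above
encoder_padding_mask = torch.zeros(batch_size, src_len, dtype=torch.bool)
next_decoder_cache = []  # one attention-cache dict per decoder layer during generation

next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
```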
@@ -930,10 +937,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
             layer_past_new = {
                 attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
             }
-            # reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
-            # reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
             reordered_past.append(layer_past_new)
-        new_enc_out = enc_out if enc_out is None else enc_out.index_select(1, beam_idx)
+        new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
         new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
         past = ((new_enc_out, new_enc_mask), reordered_past)
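
Because the cached encoder output is now batch-first, beam search reorders it along dim 0 rather than dim 1. A standalone illustration of what enc_out.index_select(0, beam_idx) does during reordering (values invented):

```python
import torch

enc_out = torch.arange(6).view(3, 2).float()  # row i is the encoder output for beam i
beam_idx = torch.tensor([0, 0, 2])            # beam 1 died; beam 0 was duplicated

print(enc_out.index_select(0, beam_idx))
# tensor([[0., 1.],
#         [0., 1.],
#         [4., 5.]])
```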
@@ -457,7 +457,6 @@ class T5PreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_t5
     base_model_prefix = "transformer"
-    encoder_outputs_batch_dim_idx = 0  # outputs shaped (bs, ...)
 
     @property
     def dummy_inputs(self):
@@ -948,18 +948,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 device=next(self.parameters()).device,
             )
             cur_len = 1
 
-            batch_idx = self.encoder_outputs_batch_dim_idx
             assert (
-                batch_size == encoder_outputs[0].shape[batch_idx]
-            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[1]} "
-            expanded_idx = (
+                batch_size == encoder_outputs[0].shape[0]
+            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
+
+            # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
+            expanded_batch_idxs = (
                 torch.arange(batch_size)
                 .view(-1, 1)
                 .repeat(1, num_beams * effective_batch_mult)
                 .view(-1)
                 .to(input_ids.device)
             )
-            encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
+
+            # expand encoder_outputs
+            encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
         else:
             encoder_outputs = None
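
For concreteness, here is what the expanded_batch_idxs construction yields for small values (a worked example, not part of the diff), and why index_select(0, ...) then gives each beam its own copy of the right encoder output:

```python
import torch

batch_size, num_beams, effective_batch_mult = 2, 3, 1

expanded_batch_idxs = (
    torch.arange(batch_size)                      # tensor([0, 1])
    .view(-1, 1)                                  # tensor([[0], [1]])
    .repeat(1, num_beams * effective_batch_mult)  # tensor([[0, 0, 0], [1, 1, 1]])
    .view(-1)                                     # tensor([0, 0, 0, 1, 1, 1])
)

# With batch-first encoder outputs, index_select along dim 0 repeats each
# example's encoder output once per beam:
enc = torch.randn(batch_size, 5, 8)  # assumed (bs, seq_len, dim)
expanded = enc.index_select(0, expanded_batch_idxs)
assert expanded.shape == (batch_size * num_beams, 5, 8)
```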