Unverified commit 0a921b64, authored by Max Del, committed by GitHub
Browse files

BART & FSMT: fix decoder not returning hidden states from the last layer (#8597)



* Fix decoder not returning hidden states from the last layer

* Resolve conflict

* Change the way to gather hidden states

* Add decoder hidden states test

* Make pytest and black happy

* Remove redundant line

* remove new line
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 81fe0bf0
@@ -610,6 +610,12 @@ class BartDecoder(nn.Module):
                all_self_attns += (layer_self_attn,)
                all_cross_attentions += (layer_cross_attn,)
# add hidden states from the last decoder layer
if output_hidden_states:
x = x.transpose(0, 1)
all_hidden_states += (x,)
x = x.transpose(0, 1)
        if self.layer_norm:  # if config.add_final_layer_norm (mBART)
            x = self.layer_norm(x)
......
@@ -692,6 +692,12 @@ class FSMTDecoder(nn.Module):
                all_self_attns += (layer_self_attn,)
                all_cross_attns += (layer_cross_attn,)
# add hidden states from the last decoder layer
if output_hidden_states:
x = x.transpose(0, 1)
all_hidden_states += (x,)
x = x.transpose(0, 1)
        # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
        x = x.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
......
@@ -659,12 +659,14 @@ class ModelTesterMixin:
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs["hidden_states"] if "hidden_states" in outputs else outputs[-1]
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
                if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
@@ -677,6 +679,19 @@ class ModelTesterMixin:
                [seq_length, self.model_tester.hidden_size],
            )
if config.is_encoder_decoder:
hidden_states = outputs.decoder_hidden_states
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[decoder_seq_length, self.model_tester.hidden_size],
)
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment