Unverified Commit a8339b9e authored by Sam Shleifer, committed by GitHub

Fix bart shape comment (#8423)

parent 46509d1c
@@ -585,7 +585,7 @@ class BartDecoder(nn.Module):
         x = F.dropout(x, p=self.dropout, training=self.training)
-        # Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
+        # Convert to Bart output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
         x = x.transpose(0, 1)
         encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
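The PyTorch hunk above only corrects the comment; the transpose itself is unchanged. A minimal sketch, outside the commit, showing that x.transpose(0, 1) performs exactly the conversion the corrected comment describes, (BS, seq_len, model_dim) -> (seq_len, BS, model_dim); the sizes below are arbitrary placeholders, not values from the model:

import torch

batch_size, seq_len, model_dim = 2, 5, 16             # illustrative sizes only
x = torch.randn(batch_size, seq_len, model_dim)       # embeddings come out batch-first

x = x.transpose(0, 1)                                  # swap the batch and time axes
assert x.shape == (seq_len, batch_size, model_dim)     # time-first layout expected by the decoder layers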
@@ -570,7 +570,7 @@ class TFBartDecoder(Layer):
         x = self.layernorm_embedding(x + positions)
         x = tf.nn.dropout(x, rate=self.dropout if training else 0)
-        # Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
+        # Convert to Bart output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
         x = tf.transpose(x, perm=(1, 0, 2))
         assert len(shape_list(encoder_hidden_states)) == 3, "encoder_hidden_states must be a 3D tensor"
         encoder_hidden_states = tf.transpose(encoder_hidden_states, perm=(1, 0, 2))
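The TensorFlow hunk receives the same comment fix; tf.transpose with perm=(1, 0, 2) swaps the first two axes just like transpose(0, 1) in PyTorch. A minimal sketch, outside the commit, with placeholder sizes:

import tensorflow as tf

batch_size, seq_len, model_dim = 2, 5, 16              # illustrative sizes only
x = tf.random.normal((batch_size, seq_len, model_dim))

x = tf.transpose(x, perm=(1, 0, 2))                    # (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
assert x.shape == (seq_len, batch_size, model_dim)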