Commit 7358296b authored by higgsfield, committed by Myle Ott

fixed output_proj's input_dim in attention (#226)

parent 28adb200
...
@@ -217,7 +217,7 @@ class AttentionLayer(nn.Module):
         super().__init__()
         self.input_proj = Linear(input_embed_dim, output_embed_dim, bias=False)
-        self.output_proj = Linear(2*output_embed_dim, output_embed_dim, bias=False)
+        self.output_proj = Linear(input_embed_dim + output_embed_dim, output_embed_dim, bias=False)

     def forward(self, input, source_hids, encoder_padding_mask):
         # input: bsz x input_embed_dim
...
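For context on why the fix changes the dimension: the layer's forward pass concatenates the attention-weighted encoder summary (size output_embed_dim) with the original decoder input (size input_embed_dim) before projecting, so the correct input dimension for output_proj is input_embed_dim + output_embed_dim; the old value 2*output_embed_dim only happens to be right when the two dimensions are equal. Below is a minimal sketch of that pattern, assuming a forward pass along the lines of fairseq's LSTM attention (the forward body is truncated in this diff, and the Linear helper here is a plain nn.Linear stand-in):

import torch
import torch.nn.functional as F
from torch import nn


def Linear(in_features, out_features, bias=True):
    # Stand-in for the project's Linear helper (assumed to wrap nn.Linear).
    return nn.Linear(in_features, out_features, bias=bias)


class AttentionLayer(nn.Module):
    def __init__(self, input_embed_dim, output_embed_dim):
        super().__init__()
        self.input_proj = Linear(input_embed_dim, output_embed_dim, bias=False)
        # Fixed line: the concatenation in forward() has size
        # output_embed_dim + input_embed_dim, not 2*output_embed_dim.
        self.output_proj = Linear(input_embed_dim + output_embed_dim, output_embed_dim, bias=False)

    def forward(self, input, source_hids, encoder_padding_mask=None):
        # input: bsz x input_embed_dim
        # source_hids: srclen x bsz x output_embed_dim
        # encoder_padding_mask (assumed shape): srclen x bsz, True at padding

        x = self.input_proj(input)                                # bsz x output_embed_dim

        # dot-product score of the projected input against every encoder state
        attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)   # srclen x bsz
        if encoder_padding_mask is not None:
            attn_scores = attn_scores.masked_fill(encoder_padding_mask, float('-inf'))
        attn_scores = F.softmax(attn_scores, dim=0)

        # attention-weighted sum of encoder states: bsz x output_embed_dim
        x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)

        # concat with the original input: bsz x (output_embed_dim + input_embed_dim),
        # which is why output_proj must accept input_embed_dim + output_embed_dim
        x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1)))
        return x, attn_scores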