Commit 48c4c6d3 authored by ngimel's avatar ngimel Committed by Myle Ott
Browse files

use implicit padding when possible (#152)

parent 66ee3df9
......@@ -98,9 +98,13 @@ class FConvEncoder(FairseqEncoder):
for (out_channels, kernel_size) in convolutions:
self.projections.append(Linear(in_channels, out_channels)
if in_channels != out_channels else None)
if kernel_size % 2 == 1:
padding = kernel_size //2
else:
padding = 0
self.convolutions.append(
ConvTBC(in_channels, out_channels * 2, kernel_size,
dropout=dropout)
dropout=dropout, padding=padding)
)
in_channels = out_channels
self.fc2 = Linear(in_channels, embed_dim)
......@@ -121,10 +125,14 @@ class FConvEncoder(FairseqEncoder):
for proj, conv in zip(self.projections, self.convolutions):
residual = x if proj is None else proj(x)
x = F.dropout(x, p=self.dropout, training=self.training)
padding_l = (conv.kernel_size[0] - 1) // 2
padding_r = conv.kernel_size[0] // 2
x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
x = conv(x)
if conv.kernel_size[0] % 2 == 1:
# padding is implicit in the conv
x = conv(x)
else:
padding_l = (conv.kernel_size[0] - 1) // 2
padding_r = conv.kernel_size[0] // 2
x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
x = conv(x)
x = F.glu(x, dim=2)
x = (x + residual) * math.sqrt(0.5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment