Commit 34c9ebf0 authored by Felix Wu's avatar Felix Wu Committed by Facebook Github Bot
Browse files

Fixing a bug of DynamicConv in the unfolding mode (#593)

Summary:
The unfold1d.py has the same name as the function `unfold1d` function, which will cause an error when using DynamicConv1dTBC with `unfold=True`.
This doesn't affect the NMT models which don't use the unfolding mode though.

I rename `unfold1d.py` as `unfold.py` to fix this bug.

Originally we would get `TypeError` when running this code:
```
import torch
from fairseq.modules import LightweightConv1dTBC, DynamicConv1dTBC

x = torch.rand(4, 10, 8)
m = LightweightConv1dTBC(8, 4, 3)
o = m(x, unfold=True)

m = DynamicConv1dTBC(8, 4, 3)
o = m(x, unfold=True)
```
Pull Request resolved: https://github.com/pytorch/fairseq/pull/593

Differential Revision: D14597117

Pulled By: myleott

fbshipit-source-id: 59752fd7ff62c53a4aba8b56b83155291e5f5792
parent 8e66a12f
...@@ -23,7 +23,7 @@ from .mean_pool_gating_network import MeanPoolGatingNetwork ...@@ -23,7 +23,7 @@ from .mean_pool_gating_network import MeanPoolGatingNetwork
from .multihead_attention import MultiheadAttention from .multihead_attention import MultiheadAttention
from .scalar_bias import ScalarBias from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
from .unfold1d import unfold1d from .unfold import unfold1d
__all__ = [ __all__ = [
'AdaptiveInput', 'AdaptiveInput',
......
...@@ -10,7 +10,7 @@ import torch.nn as nn ...@@ -10,7 +10,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from fairseq.modules import unfold1d from .unfold import unfold1d
def Linear(in_features, out_features, bias=True): def Linear(in_features, out_features, bias=True):
......
...@@ -12,7 +12,7 @@ import torch.nn as nn ...@@ -12,7 +12,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from fairseq.modules import unfold1d from .unfold import unfold1d
class LightweightConv1d(nn.Module): class LightweightConv1d(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment