Commit eb23a242 authored by Xiaohui Zhang, committed by Facebook GitHub Bot

Support GroupNorm and re-ordering Convolution/MHA in Conformer (#2320)

Summary:
Add an option to use GroupNorm rather than BatchNorm1d, and another option to re-order the Convolution/MHA modules in the Conformer model.
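
As a rough usage illustration (a sketch only, not part of the commit: the import path and the hyperparameter values below are assumptions, and the input layout follows the existing (time, batch, feature) convention of ConformerLayer):

    import torch
    from torchaudio.models.conformer import ConformerLayer  # assumed module path; may differ across releases

    # Illustrative hyperparameters; only use_group_norm / convolution_first are new in this change.
    layer = ConformerLayer(
        input_dim=256,
        ffn_dim=1024,
        num_attention_heads=4,
        depthwise_conv_kernel_size=31,
        dropout=0.1,
        use_group_norm=True,      # GroupNorm(num_groups=1) instead of BatchNorm1d in the conv module
        convolution_first=True,   # run the convolution module before self-attention
    )

    x = torch.rand(400, 8, 256)   # (time, batch, feature)
    out = layer(x, key_padding_mask=None)
    print(out.shape)              # torch.Size([400, 8, 256])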

Pull Request resolved: https://github.com/pytorch/audio/pull/2320

Reviewed By: hwangjeff

Differential Revision: D35422112

Pulled By: xiaohui-zhang

fbshipit-source-id: 360a8aaa37b883b0f656da2e4f654e86688ac270
parent 16958d5b
@@ -22,7 +22,9 @@ class _ConvolutionModule(torch.nn.Module):
         input_dim (int): input dimension.
         num_channels (int): number of depthwise convolution layer input channels.
         depthwise_kernel_size (int): kernel size of depthwise convolution layer.
+        dropout (float, optional): dropout probability. (Default: 0.0)
         bias (bool, optional): indicates whether to add bias term to each convolution layer. (Default: ``False``)
+        use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``)
     """

     def __init__(
@@ -30,8 +32,9 @@ class _ConvolutionModule(torch.nn.Module):
         input_dim: int,
         num_channels: int,
         depthwise_kernel_size: int,
-        bias: bool = False,
         dropout: float = 0.0,
+        bias: bool = False,
+        use_group_norm: bool = False,
     ) -> None:
         super().__init__()
         assert (depthwise_kernel_size - 1) % 2 == 0, "depthwise_kernel_size must be odd to achieve 'SAME' padding."
@@ -55,7 +58,9 @@ class _ConvolutionModule(torch.nn.Module):
                 groups=num_channels,
                 bias=bias,
             ),
-            torch.nn.BatchNorm1d(num_channels),
+            torch.nn.GroupNorm(num_groups=1, num_channels=num_channels)
+            if use_group_norm
+            else torch.nn.BatchNorm1d(num_channels),
             torch.nn.SiLU(),
             torch.nn.Conv1d(
                 num_channels,
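
For context on the normalization change (an illustration, not part of the diff): with num_groups=1, GroupNorm normalizes each sample over all channels and time steps, whereas BatchNorm1d normalizes each channel across the batch, so the GroupNorm variant does not depend on batch composition. Both accept the (batch, channels, time) activations flowing through the convolution stack; a minimal sketch:

    import torch

    x = torch.rand(8, 144, 400)  # (batch, channels, time); 144 channels is an arbitrary example

    bn = torch.nn.BatchNorm1d(144)                            # per-channel stats across batch and time
    gn = torch.nn.GroupNorm(num_groups=1, num_channels=144)   # per-sample stats across channels and time

    assert bn(x).shape == gn(x).shape == x.shape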
@@ -122,6 +127,10 @@ class ConformerLayer(torch.nn.Module):
         num_attention_heads (int): number of attention heads.
         depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
         dropout (float, optional): dropout probability. (Default: 0.0)
+        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
+            in the convolution module. (Default: ``False``)
+        convolution_first (bool, optional): apply the convolution module ahead of
+            the attention module. (Default: ``False``)
     """

     def __init__(
@@ -131,6 +140,8 @@ class ConformerLayer(torch.nn.Module):
         num_attention_heads: int,
         depthwise_conv_kernel_size: int,
         dropout: float = 0.0,
+        use_group_norm: bool = False,
+        convolution_first: bool = False,
     ) -> None:
         super().__init__()
@@ -144,12 +155,22 @@ class ConformerLayer(torch.nn.Module):
         self.conv_module = _ConvolutionModule(
             input_dim=input_dim,
             num_channels=input_dim,
             depthwise_kernel_size=depthwise_conv_kernel_size,
-            bias=True,
             dropout=dropout,
+            bias=True,
+            use_group_norm=use_group_norm,
         )
         self.ffn2 = _FeedForwardModule(input_dim, ffn_dim, dropout=dropout)
         self.final_layer_norm = torch.nn.LayerNorm(input_dim)
+        self.convolution_first = convolution_first

+    def _apply_convolution(self, input: torch.Tensor) -> torch.Tensor:
+        residual = input
+        input = input.transpose(0, 1)
+        input = self.conv_module(input)
+        input = input.transpose(0, 1)
+        input = residual + input
+        return input
+
     def forward(self, input: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
         r"""
@@ -164,6 +185,9 @@ class ConformerLayer(torch.nn.Module):
         x = self.ffn1(input)
         x = x * 0.5 + residual

+        if self.convolution_first:
+            x = self._apply_convolution(x)
+
         residual = x
         x = self.self_attn_layer_norm(x)
         x, _ = self.self_attn(
@@ -176,11 +200,8 @@ class ConformerLayer(torch.nn.Module):
         x = self.self_attn_dropout(x)
         x = x + residual

-        residual = x
-        x = x.transpose(0, 1)
-        x = self.conv_module(x)
-        x = x.transpose(0, 1)
-        x = residual + x
+        if not self.convolution_first:
+            x = self._apply_convolution(x)

         residual = x
         x = self.ffn2(x)
...
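
Putting the forward-path changes together (a simplified sketch of the module ordering only; dropout, the pre-attention LayerNorm, and padding masks are omitted, so this is not the literal library code):

    import torch

    def conformer_block(x, ffn1, attn, conv, ffn2, norm, convolution_first=False):
        x = x + 0.5 * ffn1(x)          # Macaron-style half-step feed-forward
        if convolution_first:
            x = x + conv(x)            # new option: convolution module ahead of self-attention
            x = x + attn(x)
        else:
            x = x + attn(x)            # default order: self-attention, then convolution
            x = x + conv(x)
        x = x + 0.5 * ffn2(x)
        return norm(x)

    # Toy check with Identity stand-ins for the real sub-modules.
    ident = torch.nn.Identity()
    x = torch.rand(400, 8, 256)        # (time, batch, feature)
    for conv_first in (False, True):
        assert conformer_block(x, ident, ident, ident, ident, ident, conv_first).shape == x.shape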