Commit 411b5dcf authored by hwangjeff, committed by Facebook GitHub Bot

Adjust Conformer args (#2223)

Summary:
Reorders and renames Conformer's initializer args to be more consistent with Emformer's.

Pull Request resolved: https://github.com/pytorch/audio/pull/2223

Reviewed By: mthrok

Differential Revision: D34226177

Pulled By: hwangjeff

fbshipit-source-id: 111c7ff27841aeac302ea5f6f7b50cc72c570829
parent bc0fcadb
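
For callers, a minimal before/after sketch of the signature change (a reconstruction from the diff below; the argument values are the illustrative ones used in the updated test):

    # Before this commit:
    conformer = Conformer(
        num_layers=4,
        input_dim=80,
        ffn_dim=128,
        num_attention_heads=4,
        depthwise_conv_kernel_size=31,
        dropout=0.1,
    )

    # After this commit: num_attention_heads is renamed to num_heads, and the
    # arguments are reordered to match Emformer's (input_dim, num_heads,
    # ffn_dim, num_layers, ...).
    conformer = Conformer(
        input_dim=80,
        num_heads=4,
        ffn_dim=128,
        num_layers=4,
        depthwise_conv_kernel_size=31,
        dropout=0.1,
    )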
@@ -7,10 +7,10 @@ class ConformerTestImpl(TestBaseMixin):
     def _gen_model(self):
         conformer = (
             Conformer(
-                num_layers=4,
                 input_dim=80,
+                num_heads=4,
                 ffn_dim=128,
-                num_attention_heads=4,
+                num_layers=4,
                 depthwise_conv_kernel_size=31,
                 dropout=0.1,
             )
@@ -196,19 +196,19 @@ class Conformer(torch.nn.Module):
     [:footcite:`gulati2020conformer`].
 
     Args:
-        num_layers (int): number of Conformer layers to instantiate.
         input_dim (int): input dimension.
-        ffn_dim (int): hidden layer dimension of feedforward network.
-        num_attention_heads (int): number of attention heads.
-        depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
+        num_heads (int): number of attention heads in each Conformer layer.
+        ffn_dim (int): hidden layer dimension of feedforward networks.
+        num_layers (int): number of Conformer layers to instantiate.
+        depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
         dropout (float, optional): dropout probability. (Default: 0.0)
 
     Examples:
         >>> conformer = Conformer(
-        >>>     num_layers=4,
         >>>     input_dim=80,
+        >>>     num_heads=4,
         >>>     ffn_dim=128,
-        >>>     num_attention_heads=4,
+        >>>     num_layers=4,
         >>>     depthwise_conv_kernel_size=31,
         >>> )
         >>> lengths = torch.randint(1, 400, (10,))  # (batch,)
@@ -218,10 +218,10 @@ class Conformer(torch.nn.Module):
 
     def __init__(
         self,
-        num_layers: int,
         input_dim: int,
+        num_heads: int,
         ffn_dim: int,
-        num_attention_heads: int,
+        num_layers: int,
         depthwise_conv_kernel_size: int,
         dropout: float = 0.0,
     ):
@@ -232,7 +232,7 @@ class Conformer(torch.nn.Module):
             ConformerLayer(
                 input_dim,
                 ffn_dim,
-                num_attention_heads,
+                num_heads,
                 depthwise_conv_kernel_size,
                 dropout,
             )
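
For reference, a short end-to-end sketch of the updated API, extending the docstring example above with a forward pass. This assumes the import path torchaudio.models.Conformer (where the model lives in current releases; the path at this commit may differ) and the documented (batch, num_frames, input_dim) input convention:

    import torch
    from torchaudio.models import Conformer

    conformer = Conformer(
        input_dim=80,
        num_heads=4,
        ffn_dim=128,
        num_layers=4,
        depthwise_conv_kernel_size=31,
    )
    lengths = torch.randint(1, 400, (10,))  # (batch,)
    input = torch.rand(10, int(lengths.max()), 80)  # (batch, num_frames, input_dim)
    output, output_lengths = conformer(input, lengths)  # output: (batch, num_frames, input_dim)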