Commit 411b5dcf authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Adjust Conformer args (#2223)

Summary:
Orders and names Conformer's initializer args to be more consistent with Emformer's.

Pull Request resolved: https://github.com/pytorch/audio/pull/2223

Reviewed By: mthrok

Differential Revision: D34226177

Pulled By: hwangjeff

fbshipit-source-id: 111c7ff27841aeac302ea5f6f7b50cc72c570829
parent bc0fcadb
......@@ -7,10 +7,10 @@ class ConformerTestImpl(TestBaseMixin):
def _gen_model(self):
conformer = (
Conformer(
num_layers=4,
input_dim=80,
num_heads=4,
ffn_dim=128,
num_attention_heads=4,
num_layers=4,
depthwise_conv_kernel_size=31,
dropout=0.1,
)
......
......@@ -196,19 +196,19 @@ class Conformer(torch.nn.Module):
[:footcite:`gulati2020conformer`].
Args:
num_layers (int): number of Conformer layers to instantiate.
input_dim (int): input dimension.
ffn_dim (int): hidden layer dimension of feedforward network.
num_attention_heads (int): number of attention heads.
depthwise_conv_kernel_size (int): kernel size of depthwise convolution layer.
num_heads (int): number of attention heads in each Conformer layer.
ffn_dim (int): hidden layer dimension of feedforward networks.
num_layers (int): number of Conformer layers to instantiate.
depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
dropout (float, optional): dropout probability. (Default: 0.0)
Examples:
>>> conformer = Conformer(
>>> num_layers=4,
>>> input_dim=80,
>>> num_heads=4,
>>> ffn_dim=128,
>>> num_attention_heads=4,
>>> num_layers=4,
>>> depthwise_conv_kernel_size=31,
>>> )
>>> lengths = torch.randint(1, 400, (10,)) # (batch,)
......@@ -218,10 +218,10 @@ class Conformer(torch.nn.Module):
def __init__(
self,
num_layers: int,
input_dim: int,
num_heads: int,
ffn_dim: int,
num_attention_heads: int,
num_layers: int,
depthwise_conv_kernel_size: int,
dropout: float = 0.0,
):
......@@ -232,7 +232,7 @@ class Conformer(torch.nn.Module):
ConformerLayer(
input_dim,
ffn_dim,
num_attention_heads,
num_heads,
depthwise_conv_kernel_size,
dropout,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment