"...text-generation-inference.git" did not exist on "610bb1f97811af2d80d16efd96f313012beb65aa"
Commit 16958d5b authored by Xiaohui Zhang, committed by Facebook GitHub Bot

Add an option to use Tanh instead of ReLU in RNNT joiner (#2319)

Summary:
Add an option to use Tanh instead of ReLU in the RNNT joiner, which can sometimes improve training performance.

 ---

Pull Request resolved: https://github.com/pytorch/audio/pull/2319

Reviewed By: hwangjeff

Differential Revision: D35422122

Pulled By: xiaohui-zhang

fbshipit-source-id: c6a0f8b25936e47081110af046b57d0e8751f9a2
parent f7afe29e
@@ -377,12 +377,20 @@ class _Joiner(torch.nn.Module):
     Args:
         input_dim (int): source and target input dimension.
         output_dim (int): output dimension.
+        activation (str, optional): activation function to use in the joiner.
+            Must be one of ("relu", "tanh"). (Default: "relu")
     """
 
-    def __init__(self, input_dim: int, output_dim: int) -> None:
+    def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
         super().__init__()
         self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
-        self.relu = torch.nn.ReLU()
+        if activation == "relu":
+            self.activation = torch.nn.ReLU()
+        elif activation == "tanh":
+            self.activation = torch.nn.Tanh()
+        else:
+            raise ValueError(f"Unsupported activation {activation}")
 
     def forward(
         self,
@@ -419,8 +427,8 @@ class _Joiner(torch.nn.Module):
             number of valid elements along dim 2 for i-th batch element in joint network output.
         """
         joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
-        relu_out = self.relu(joint_encodings)
-        output = self.linear(relu_out)
+        activation_out = self.activation(joint_encodings)
+        output = self.linear(activation_out)
         return output, source_lengths, target_lengths
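For reference, a minimal usage sketch of the new option. It assumes a torchaudio build that includes this patch; _Joiner is a private class in torchaudio.models.rnnt, so importing it directly and the tensor sizes below are purely illustrative.

# Sketch only (not part of the patch): exercise the new `activation` argument.
# Assumes torchaudio with this change applied; _Joiner is private, imported
# here only for illustration, and all dimensions below are made up.
import torch
from torchaudio.models.rnnt import _Joiner

input_dim, output_dim = 1024, 4097  # hypothetical dimensions
joiner = _Joiner(input_dim, output_dim, activation="tanh")  # "relu" remains the default;
                                                            # any other string raises ValueError

# Dummy encoder and predictor outputs: (batch, T, input_dim) and (batch, U, input_dim).
source_encodings = torch.rand(2, 10, input_dim)
target_encodings = torch.rand(2, 5, input_dim)
source_lengths = torch.tensor([10, 8])
target_lengths = torch.tensor([5, 3])

# Per the diff, forward returns (output, source_lengths, target_lengths),
# with output of shape (batch, T, U, output_dim), here torch.Size([2, 10, 5, 4097]).
output, out_source_lengths, out_target_lengths = joiner(
    source_encodings=source_encodings,
    source_lengths=source_lengths,
    target_encodings=target_encodings,
    target_lengths=target_lengths,
)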