"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "32e1992924929a9b79e880ed6f5bdc74089e8c73"
Commit 7a5996fd authored by Naman Goyal, committed by Facebook GitHub Bot

use fused layer norm in transformer sentence encoder (#702)

Summary:
We can later get rid of `BertLayerNorm` as well, since I think its implementation is exactly the same as `LayerNorm` (will confirm with jingfeidu on that).
But this should be a drop-in replacement.
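
The commit title says the change picks up a fused layer norm: the `LayerNorm` imported from `fairseq.modules` selects apex's fused CUDA kernel when it is available and falls back to `torch.nn.LayerNorm` otherwise. Below is a minimal sketch of that selection pattern, assuming the optional NVIDIA apex dependency; it is illustrative only, not the verbatim fairseq source.

import torch

def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    # Prefer apex's fused CUDA layer norm when the package is installed
    # and a GPU is available; otherwise fall back to the stock PyTorch module.
    if torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm
            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)

Because the factory keeps the `nn.LayerNorm` interface, each call site only needs to swap the imported name, which is exactly what the diff below does.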
Pull Request resolved: https://github.com/pytorch/fairseq/pull/702

Differential Revision: D15213116

Pulled By: myleott

fbshipit-source-id: ba5c00e1129a4443ef5d3d8bebd0bb6c6ee3b188
parent 657a8836
@@ -10,7 +10,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fairseq.modules import gelu, MultiheadAttention, BertLayerNorm
+from fairseq.modules import gelu, MultiheadAttention, BertLayerNorm, LayerNorm


 class TransformerSentenceEncoderLayer(nn.Module):
@@ -19,7 +19,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
     models.

     If the flag use_bert_layer_norm is set then we use the custom
-    BertLayerNorm module instead of nn.LayerNorm.
+    BertLayerNorm module instead of LayerNorm.
     """

     def __init__(
@@ -52,7 +52,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         self.self_attn_layer_norm = (
             BertLayerNorm(self.embedding_dim)
             if use_bert_layer_norm
-            else nn.LayerNorm(self.embedding_dim, eps=1e-12)
+            else LayerNorm(self.embedding_dim, eps=1e-12)
         )
         self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
         self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
@@ -61,7 +61,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         self.final_layer_norm = (
             BertLayerNorm(self.embedding_dim)
             if use_bert_layer_norm
-            else nn.LayerNorm(self.embedding_dim, eps=1e-12)
+            else LayerNorm(self.embedding_dim, eps=1e-12)
         )

     def _maybe_layer_norm(
...
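
As a quick illustration of the "drop-in replacement" claim in the summary, the sketch below compares torch.nn.LayerNorm against apex's FusedLayerNorm at the eps=1e-12 used in the diff. The embedding dimension, tolerance, and the apex/CUDA requirement are assumptions made only for this example; none of it is part of the actual change.

import torch
import torch.nn as nn

dim = 768  # hypothetical embedding dimension, for illustration only
x = torch.randn(4, 16, dim, device="cuda")

ref = nn.LayerNorm(dim, eps=1e-12).cuda()

try:
    from apex.normalization import FusedLayerNorm
    fused = FusedLayerNorm(dim, eps=1e-12).cuda()
    fused.load_state_dict(ref.state_dict())  # both expose weight/bias
    # Outputs should agree to numerical tolerance if the swap is drop-in.
    print(torch.allclose(ref(x), fused(x), atol=1e-5))
except ImportError:
    print("apex not installed; fairseq's LayerNorm falls back to nn.LayerNorm")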