Commit 7a5996fd authored by Naman Goyal, committed by Facebook Github Bot

use fused layer norm in transformer sentence encoder (#702)

Summary:
We can later get rid of `BertLayerNorm` as well, since I think its implementation is exactly the same as `LayerNorm`. (I will confirm with jingfeidu on that.)
But this should be a drop-in replacement.
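For reference, a minimal sketch of how a `LayerNorm` factory like the one imported in this diff can dispatch to a fused kernel. The use of apex's `FusedLayerNorm` and the CUDA-availability check are assumptions for illustration, not necessarily the exact fairseq implementation:

```python
import torch
import torch.nn as nn

try:
    # apex provides a fused CUDA LayerNorm kernel (assumption: this is the fused op used)
    from apex.normalization import FusedLayerNorm
    has_fused_layernorm = True
except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    """Return a fused LayerNorm when available, otherwise fall back to nn.LayerNorm."""
    if torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    return nn.LayerNorm(normalized_shape, eps, elementwise_affine)
```

Because both branches accept the same arguments, call sites such as `LayerNorm(self.embedding_dim, eps=1e-12)` stay unchanged, which is what makes this a drop-in replacement.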
Pull Request resolved: https://github.com/pytorch/fairseq/pull/702

Differential Revision: D15213116

Pulled By: myleott

fbshipit-source-id: ba5c00e1129a4443ef5d3d8bebd0bb6c6ee3b188
parent 657a8836
@@ -10,7 +10,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fairseq.modules import gelu, MultiheadAttention, BertLayerNorm
+from fairseq.modules import gelu, MultiheadAttention, BertLayerNorm, LayerNorm
 
 
 class TransformerSentenceEncoderLayer(nn.Module):
@@ -19,7 +19,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
     models.
 
     If the flag use_bert_layer_norm is set then we use the custom
-    BertLayerNorm module instead of nn.LayerNorm.
+    BertLayerNorm module instead of LayerNorm.
     """
 
     def __init__(
@@ -52,7 +52,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         self.self_attn_layer_norm = (
             BertLayerNorm(self.embedding_dim)
             if use_bert_layer_norm
-            else nn.LayerNorm(self.embedding_dim, eps=1e-12)
+            else LayerNorm(self.embedding_dim, eps=1e-12)
         )
         self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
         self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
@@ -61,7 +61,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         self.final_layer_norm = (
             BertLayerNorm(self.embedding_dim)
             if use_bert_layer_norm
-            else nn.LayerNorm(self.embedding_dim, eps=1e-12)
+            else LayerNorm(self.embedding_dim, eps=1e-12)
         )
 
     def _maybe_layer_norm(