Commit 2a3adcdc authored by Ruty Rinott, committed by Facebook Github Bot

Fix biTransformer export (#583)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/583

D14610694 fixed issues in exporting LayerNorm by choosing the implementation conditionally. D15260838 changed the implementation of TransformerDecoderLayer to the one under transformer, thus losing that fix. This diff brings it back; a sketch of the conditional helper appears before the diff below.

Reviewed By: myleott, geof90, liaimi

Differential Revision: D15357119

fbshipit-source-id: e29e053ca5beca0008d7a8dad9880a483a14c7b9
parent 74c936dc
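
For context, the fix works because fairseq's LayerNorm is a factory function: it returns apex's FusedLayerNorm for speed when available, but must fall back to torch.nn.LayerNorm when the module will be exported, since the fused kernel cannot be traced. A minimal sketch of such a factory, assuming this general shape rather than reproducing the exact fairseq code:

import torch

def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    # Prefer apex's fused CUDA kernel when running on GPU and not
    # exporting; FusedLayerNorm is not traceable for export.
    if not export and torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm
            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    # The pure-PyTorch LayerNorm is safe to trace and export.
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)

Passing export=getattr(args, 'char_inputs', False), as the diff below does, routes exported models to the traceable implementation.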
@@ -601,7 +601,11 @@ class TransformerDecoderLayer(nn.Module):
         self.activation_dropout = getattr(args, 'relu_dropout', 0)
         self.normalize_before = args.decoder_normalize_before
-        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+        # use LayerNorm rather than FusedLayerNorm for exporting.
+        # char_inputs can be used to determine this.
+        # TODO remove this once we update apex with the fix
+        export = getattr(args, 'char_inputs', False)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
         if no_encoder_attn:
             self.encoder_attn = None
@@ -611,12 +615,12 @@
             self.embed_dim, args.decoder_attention_heads,
             dropout=args.attention_dropout,
         )
-        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
+        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
         self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
         self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
-        self.final_layer_norm = LayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
         self.need_attn = True
         self.onnx_trace = False