Commit 4e9ecb80 authored by Kartikay Khandelwal, committed by Facebook Github Bot

Make XLM torchscript Export-able (#765)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/765

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/614

This diff contains the changes needed to make XLM exportable with TorchScript.
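A minimal usage sketch after this change (the constructor arguments below and end-to-end scriptability of the full model at this revision are illustrative assumptions, not something this diff guarantees):

import torch
from fairseq.modules import TransformerSentenceEncoder

# export=True is the new flag threaded through by this diff; it keeps the
# encoder on TorchScript-friendly building blocks (e.g. plain nn.LayerNorm).
encoder = TransformerSentenceEncoder(
    padding_idx=1,
    vocab_size=32000,
    num_encoder_layers=2,
    embedding_dim=128,
    num_attention_heads=4,
    export=True,
)
scripted = torch.jit.script(encoder)  # sketch; actual scriptability depends on the rest of the stack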

Reviewed By: bethebunny

Differential Revision: D15497208

fbshipit-source-id: fd9645119e154e3c397f147acf9144d661d9a5c8
parent 65f46473
@@ -89,6 +89,7 @@ class TransformerSentenceEncoder(nn.Module):
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
         embed_scale: float = None,
+        export: bool = False,
     ) -> None:
         super().__init__()
@@ -139,13 +140,14 @@ class TransformerSentenceEncoder(nn.Module):
                     activation_fn=activation_fn,
                     add_bias_kv=add_bias_kv,
                     add_zero_attn=add_zero_attn,
+                    export=export,
                 )
                 for _ in range(num_encoder_layers)
             ]
         )
         if encoder_normalize_before:
-            self.emb_layer_norm = LayerNorm(self.embedding_dim)
+            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
         else:
             self.emb_layer_norm = None
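For context, the export flag is forwarded to fairseq's LayerNorm factory, which normally prefers apex's FusedLayerNorm on GPU; with export=True it falls back to the plain torch.nn.LayerNorm, which TorchScript can handle. A rough sketch of that helper (the real fairseq implementation may differ in detail):

import torch

def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    # Use the fused CUDA kernel from NVIDIA apex only when not exporting
    # and when it is available; otherwise fall back to the stock layer norm.
    if not export and torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm
            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)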
@@ -184,7 +186,7 @@ class TransformerSentenceEncoder(nn.Module):
         # account for padding while computing the representation
         if padding_mask is not None:
-            x *= (~padding_mask).unsqueeze(-1).type_as(x)
+            x *= (1 - padding_mask.unsqueeze(-1).type_as(x))
         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
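The hunk above replaces the `~` negation of the padding mask with an arithmetic equivalent, presumably because the bitwise operator was harder for TorchScript to handle at the time; both forms zero out the padded positions. A small illustrative snippet (shapes and values are made up):

import torch

# B x T x C activations and a B x T padding mask (True marks a padded position)
x = torch.randn(2, 5, 8)
padding_mask = torch.tensor([[False, False, False, True, True],
                             [False, False, True,  True, True]])

# script-friendly form used by this diff: padded positions are multiplied by 0
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))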
@@ -33,6 +33,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         activation_fn: str = 'relu',
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
+        export: bool = False,
     ) -> None:
         super().__init__()
@@ -52,12 +53,12 @@ class TransformerSentenceEncoderLayer(nn.Module):
         )
         # layer norm associated with the self attention layer
-        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)
         self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
         self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
         # layer norm associated with the position wise feed-forward NN
-        self.final_layer_norm = LayerNorm(self.embedding_dim)
+        self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)
     def forward(
         self,