Commit 4e9ecb80 authored by Kartikay Khandelwal, committed by Facebook GitHub Bot

Make XLM TorchScript-exportable (#765)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/765

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/614

This diff contains the changes needed to make the XLM model exportable via TorchScript.
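Roughly, the new export flag is threaded through the encoder so that export-friendly LayerNorm implementations are used everywhere. A minimal usage sketch follows; the constructor arguments shown are abbreviated and assumed, and the scripting call is illustrative rather than part of this diff:

    import torch
    from fairseq.modules import TransformerSentenceEncoder

    # Illustrative only: most constructor arguments are omitted/assumed here.
    # The relevant part is the new export flag, which switches the model to
    # export-friendly LayerNorm implementations throughout.
    encoder = TransformerSentenceEncoder(
        padding_idx=1,
        vocab_size=32000,
        num_encoder_layers=12,
        embedding_dim=768,
        export=True,
    )
    encoder.eval()

    # Whether the full module scripts cleanly depends on the rest of the
    # model code; this call is just the intended end goal of the change.
    scripted = torch.jit.script(encoder)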

Reviewed By: bethebunny

Differential Revision: D15497208

fbshipit-source-id: fd9645119e154e3c397f147acf9144d661d9a5c8
parent 65f46473
@@ -89,6 +89,7 @@ class TransformerSentenceEncoder(nn.Module):
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
         embed_scale: float = None,
+        export: bool = False,
     ) -> None:
         super().__init__()
@@ -139,13 +140,14 @@ class TransformerSentenceEncoder(nn.Module):
                     activation_fn=activation_fn,
                     add_bias_kv=add_bias_kv,
                     add_zero_attn=add_zero_attn,
+                    export=export,
                 )
                 for _ in range(num_encoder_layers)
             ]
         )

         if encoder_normalize_before:
-            self.emb_layer_norm = LayerNorm(self.embedding_dim)
+            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
         else:
             self.emb_layer_norm = None
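The export flag is forwarded both to each TransformerSentenceEncoderLayer and to fairseq's LayerNorm helper. A rough paraphrase of what that helper is assumed to do with the flag (a sketch, not the exact fairseq source): when exporting, always fall back to the plain torch.nn.LayerNorm instead of apex's fused kernel, which cannot go through the export paths.

    import torch

    def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
        # Paraphrased sketch of the helper: prefer apex's FusedLayerNorm for
        # speed, but never when export is requested, since the fused kernel
        # is assumed not to be supported by the TorchScript/ONNX export paths.
        if not export and torch.cuda.is_available():
            try:
                from apex.normalization import FusedLayerNorm
                return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
            except ImportError:
                pass
        return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)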
@@ -184,7 +186,7 @@ class TransformerSentenceEncoder(nn.Module):
         # account for padding while computing the representation
         if padding_mask is not None:
-            x *= (~padding_mask).unsqueeze(-1).type_as(x)
+            x *= (1 - padding_mask.unsqueeze(-1).type_as(x))

         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
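The padding-mask change replaces bitwise negation of the mask with an arithmetic 1 - mask, presumably because the negation operator on the mask tensor was not handled by the export path at the time. A small equivalence check (illustrative, not part of the diff):

    import torch

    # Both forms zero out padded positions; the new form avoids bitwise
    # negation of the mask, which is assumed to be unsupported on export.
    x = torch.randn(2, 4, 8)                                    # B x T x C
    padding_mask = torch.tensor([[False, False, True, True],
                                 [False, False, False, True]])  # B x T, True = pad

    old_style = x * (~padding_mask).unsqueeze(-1).type_as(x)
    new_style = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
    assert torch.equal(old_style, new_style)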
@@ -33,6 +33,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
         activation_fn: str = 'relu',
         add_bias_kv: bool = False,
         add_zero_attn: bool = False,
+        export: bool = False,
     ) -> None:
         super().__init__()
@@ -52,12 +53,12 @@ class TransformerSentenceEncoderLayer(nn.Module):
         )
         # layer norm associated with the self attention layer
-        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export)

         self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
         self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

         # layer norm associated with the position wise feed-forward NN
-        self.final_layer_norm = LayerNorm(self.embedding_dim)
+        self.final_layer_norm = LayerNorm(self.embedding_dim, export=export)

     def forward(
         self,
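A hypothetical sanity check for the layer-level change, using the attribute names visible in the diff and assuming the layer is importable from fairseq.modules with all other constructor arguments defaulted: with export=True both layer norms should come back as plain torch.nn.LayerNorm rather than a fused apex implementation.

    import torch
    from fairseq.modules import TransformerSentenceEncoderLayer

    # Hypothetical check: with export=True, both layer norms should be the
    # plain torch.nn.LayerNorm, never the apex FusedLayerNorm.
    layer = TransformerSentenceEncoderLayer(export=True)
    assert isinstance(layer.self_attn_layer_norm, torch.nn.LayerNorm)
    assert isinstance(layer.final_layer_norm, torch.nn.LayerNorm)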