Commit 8446cb63 authored by Ilia Cherniavskii, committed by Facebook Github Bot

TorchScript-ify BERT training (#887)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/887

Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1052

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1250

Adds a "use_torchscript" config parameter that enables TorchScript for BERT
training.
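
In the encoder this is surfaced as a "traceable" flag (see the diff below).
The flag matters because torch.jit.trace records a single execution path:
data-dependent Python branches, such as dropping the padding mask when the
batch happens to contain no padding, get frozen into whichever branch ran on
the example input. A minimal standalone sketch of that pattern, assuming a
hypothetical padding index of 1 (an illustration, not the fairseq code
itself):

    import torch

    def padding_mask_untraceable(tokens: torch.Tensor, padding_idx: int = 1):
        # Data-dependent branch: under torch.jit.trace, whichever branch
        # runs on the example input is baked into the traced graph.
        padding_mask = tokens.eq(padding_idx)
        if not padding_mask.any():
            padding_mask = None
        return padding_mask

    def padding_mask_traceable(tokens: torch.Tensor, padding_idx: int = 1):
        # Always compute and return the mask, so the traced graph stays
        # valid for inputs both with and without padding.
        return tokens.eq(padding_idx)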

Reviewed By: chenyangyu1988

Differential Revision: D17872083

fbshipit-source-id: 00ac4b04e7f26aa56fe84fe9feaded676d6deb71
Parent: e4047852
@@ -95,6 +95,7 @@ class TransformerSentenceEncoder(nn.Module):
         freeze_embeddings: bool = False,
         n_trans_layers_to_freeze: int = 0,
         export: bool = False,
+        traceable: bool = False,
     ) -> None:

         super().__init__()
@@ -108,6 +109,7 @@ class TransformerSentenceEncoder(nn.Module):
         self.use_position_embeddings = use_position_embeddings
         self.apply_bert_init = apply_bert_init
         self.learned_pos_embedding = learned_pos_embedding
+        self.traceable = traceable

         self.embed_tokens = nn.Embedding(
             self.vocab_size, self.embedding_dim, self.padding_idx
@@ -182,7 +184,7 @@ class TransformerSentenceEncoder(nn.Module):

         # compute padding mask. This is needed for multi-head attention
         padding_mask = tokens.eq(self.padding_idx)
-        if not padding_mask.any():
+        if not self.traceable and not padding_mask.any():
             padding_mask = None

         x = self.embed_tokens(tokens)
@@ -229,4 +231,7 @@ class TransformerSentenceEncoder(nn.Module):
         if last_state_only:
             inner_states = [x]

-        return inner_states, sentence_rep
+        if self.traceable:
+            return torch.stack(inner_states), sentence_rep
+        else:
+            return inner_states, sentence_rep
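
A hedged usage sketch of the new flag. The hyperparameter values, the zero
segment labels, and the exact forward signature are assumptions for
illustration; only the "traceable" argument and the stacked-tensor return
come from this diff. Because both returned values are Tensors when
traceable=True, the module becomes amenable to torch.jit.trace:

    import torch
    from fairseq.modules import TransformerSentenceEncoder

    # Hypothetical sizes, chosen only for illustration.
    encoder = TransformerSentenceEncoder(
        padding_idx=1,
        vocab_size=32000,
        traceable=True,  # the flag added in this commit
    )
    encoder.eval()

    tokens = torch.randint(2, 32000, (8, 128))   # batch of token ids
    segment_labels = torch.zeros_like(tokens)    # single-segment input

    # With traceable=True, forward returns (stacked inner states, sentence
    # representation) as Tensors instead of (list, Tensor), so tracing works.
    traced = torch.jit.trace(encoder, (tokens, segment_labels))
    inner_states, sentence_rep = traced(tokens, segment_labels)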