Commit b6d420c2 authored by Myle Ott, committed by Facebook Github Bot

Make segment_labels optional

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/703

Differential Revision: D16072305

Pulled By: myleott

fbshipit-source-id: b77019bdcfbfb95f2817a29a74515bc8f5b682bf
parent 1757ef69
@@ -101,8 +101,8 @@ class MaskedLMModel(BaseFairseqModel):
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
 
-    def forward(self, src_tokens, segment_labels, **kwargs):
-        return self.encoder(src_tokens, segment_labels, **kwargs)
+    def forward(self, src_tokens, segment_labels=None, **kwargs):
+        return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs)
 
     def max_positions(self):
         return self.encoder.max_positions
@@ -192,7 +192,7 @@ class MaskedLMEncoder(FairseqEncoder):
             bias=False
         )
 
-    def forward(self, src_tokens, segment_labels, **unused):
+    def forward(self, src_tokens, segment_labels=None, **unused):
         """
         Forward pass for Masked LM encoder. This first computes the token
         embedding using the token embedding matrix, position embeddings (if
@@ -216,7 +216,10 @@ class MaskedLMEncoder(FairseqEncoder):
            this is specified in the input arguments.
        """
-        inner_states, sentence_rep = self.sentence_encoder(src_tokens, segment_labels)
+        inner_states, sentence_rep = self.sentence_encoder(
+            src_tokens,
+            segment_labels=segment_labels,
+        )
 
         x = inner_states[-1].transpose(0, 1)
 
         x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))
@@ -170,7 +170,7 @@ class TransformerSentenceEncoder(nn.Module):
     def forward(
         self,
         tokens: torch.Tensor,
-        segment_labels: torch.Tensor,
+        segment_labels: torch.Tensor = None,
         last_state_only: bool = False,
         positions: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
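The hunks above all make the same change: segment_labels gains a default of None and is forwarded by keyword, so callers of tasks without segment information can simply omit it. Below is a minimal, self-contained sketch of that pattern, not the fairseq code itself; the class and argument names (TinySentenceEncoder, num_segments, dim) are made up for illustration, and the sketch assumes the encoder skips its segment embedding when no labels are given, which is the natural companion to this signature change.

```python
# Minimal sketch (not fairseq code) of the pattern introduced by this commit:
# segment_labels defaults to None and the segment embedding is applied only
# when labels are provided. Class and argument names here are illustrative.
from typing import Optional

import torch
import torch.nn as nn


class TinySentenceEncoder(nn.Module):
    def __init__(self, vocab_size: int = 1000, num_segments: int = 2, dim: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        # num_segments == 0 would disable segment embeddings entirely
        self.embed_segments = nn.Embedding(num_segments, dim) if num_segments > 0 else None

    def forward(
        self,
        tokens: torch.Tensor,
        segment_labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = self.embed_tokens(tokens)
        # Skip segment embeddings when the caller passes no labels.
        if self.embed_segments is not None and segment_labels is not None:
            x = x + self.embed_segments(segment_labels)
        return x


encoder = TinySentenceEncoder()
tokens = torch.randint(0, 1000, (2, 8))
out_plain = encoder(tokens)                                   # segment_labels omitted
out_segmented = encoder(tokens, segment_labels=torch.zeros_like(tokens))
print(out_plain.shape, out_segmented.shape)                   # torch.Size([2, 8, 16]) twice
```

In the actual model, forwarding segment_labels by keyword rather than positionally means the encoder's own default of None applies whenever the argument is omitted upstream.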