"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9ff72433fa5a4d9f9e2f2c599e394480b581c614"
Commit b6d420c2 authored by Myle Ott, committed by Facebook GitHub Bot

Make segment_labels optional

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/703

Differential Revision: D16072305

Pulled By: myleott

fbshipit-source-id: b77019bdcfbfb95f2817a29a74515bc8f5b682bf
parent 1757ef69
@@ -101,8 +101,8 @@ class MaskedLMModel(BaseFairseqModel):
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
 
-    def forward(self, src_tokens, segment_labels, **kwargs):
-        return self.encoder(src_tokens, segment_labels, **kwargs)
+    def forward(self, src_tokens, segment_labels=None, **kwargs):
+        return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs)
 
     def max_positions(self):
         return self.encoder.max_positions
@@ -192,7 +192,7 @@ class MaskedLMEncoder(FairseqEncoder):
                 bias=False
             )
 
-    def forward(self, src_tokens, segment_labels, **unused):
+    def forward(self, src_tokens, segment_labels=None, **unused):
         """
         Forward pass for Masked LM encoder. This first computes the token
         embedding using the token embedding matrix, position embeddings (if
@@ -216,7 +216,10 @@ class MaskedLMEncoder(FairseqEncoder):
             this is specified in the input arguments.
         """
 
-        inner_states, sentence_rep = self.sentence_encoder(src_tokens, segment_labels)
+        inner_states, sentence_rep = self.sentence_encoder(
+            src_tokens,
+            segment_labels=segment_labels,
+        )
 
         x = inner_states[-1].transpose(0, 1)
         x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))
......
@@ -170,7 +170,7 @@ class TransformerSentenceEncoder(nn.Module):
     def forward(
         self,
         tokens: torch.Tensor,
-        segment_labels: torch.Tensor,
+        segment_labels: torch.Tensor = None,
         last_state_only: bool = False,
         positions: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
......
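For illustration only, a minimal self-contained sketch of the pattern this commit applies (a toy module, not fairseq code and not part of this diff): the segment-label argument defaults to None and segment embeddings are only added when labels are actually passed, so callers that have no segment information can simply omit the argument.

# Toy example of an optional segment_labels argument; ToySentenceEncoder is
# a hypothetical stand-in for fairseq's sentence encoder.
from typing import Optional

import torch
import torch.nn as nn


class ToySentenceEncoder(nn.Module):
    def __init__(self, vocab_size: int = 100, num_segments: int = 2, dim: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.embed_segments = nn.Embedding(num_segments, dim)

    def forward(
        self,
        tokens: torch.Tensor,
        segment_labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = self.embed_tokens(tokens)
        # Segment embeddings are only added when labels are provided.
        if segment_labels is not None:
            x = x + self.embed_segments(segment_labels)
        return x


encoder = ToySentenceEncoder()
src_tokens = torch.randint(0, 100, (2, 8))
out_without = encoder(src_tokens)  # segment_labels omitted
out_with = encoder(src_tokens, segment_labels=torch.zeros(2, 8, dtype=torch.long))
print(out_without.shape, out_with.shape)  # torch.Size([2, 8, 16]) for both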