Commit 9e1c880f authored by Dmytro Okhonko's avatar Dmytro Okhonko Committed by Facebook Github Bot
Browse files

FairseqEncoderModel

Summary: Base class for encoder-only models. Some models doesn't have decoder part.

Reviewed By: myleott

Differential Revision: D14413406

fbshipit-source-id: f36473b91dcf3c835fd6d50e2eb6002afa75f11a
parent 7fc9a3be
...@@ -17,6 +17,7 @@ from .fairseq_model import ( ...@@ -17,6 +17,7 @@ from .fairseq_model import (
FairseqModel, # noqa: F401 FairseqModel, # noqa: F401
FairseqMultiModel, # noqa: F401 FairseqMultiModel, # noqa: F401
FairseqLanguageModel, # noqa: F401 FairseqLanguageModel, # noqa: F401
FairseqEncoderModel, # noqa: F401
) )
from .composite_encoder import CompositeEncoder # noqa: F401 from .composite_encoder import CompositeEncoder # noqa: F401
......
...@@ -297,3 +297,43 @@ class FairseqLanguageModel(BaseFairseqModel): ...@@ -297,3 +297,43 @@ class FairseqLanguageModel(BaseFairseqModel):
def remove_head(self): def remove_head(self):
"""Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed""" """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
raise NotImplementedError() raise NotImplementedError()
class FairseqEncoderModel(BaseFairseqModel):
"""Base class for encoder-only models.
Args:
encoder (FairseqEncoder): the encoder
"""
def __init__(self, encoder):
super().__init__()
self.encoder = encoder
assert isinstance(self.encoder, FairseqEncoder)
def forward(self, src_tokens, src_lengths, **kwargs):
"""
Run the forward pass for a encoder-only model.
Feeds a batch of tokens through the encoder to generate logits.
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
src_lengths (LongTensor): source sentence lengths of shape `(batch)`
Returns:
the encoder's output, typically of shape `(batch, seq_len, vocab)`
"""
return self.encoder(src_tokens, src_lengths)
def max_positions(self):
"""Maximum length supported by the model."""
return self.encoder.max_positions()
@property
def supported_targets(self):
return {'future'}
def remove_head(self):
"""Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
raise NotImplementedError()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment