Commit 9e1c880f authored by Dmytro Okhonko's avatar Dmytro Okhonko Committed by Facebook Github Bot

FairseqEncoderModel

Summary: Base class for encoder-only models. Some models don't have a decoder part.

Reviewed By: myleott

Differential Revision: D14413406

fbshipit-source-id: f36473b91dcf3c835fd6d50e2eb6002afa75f11a
parent 7fc9a3be
fairseq/models/__init__.py
@@ -17,6 +17,7 @@ from .fairseq_model import (
     FairseqModel, # noqa: F401
     FairseqMultiModel, # noqa: F401
     FairseqLanguageModel, # noqa: F401
+    FairseqEncoderModel, # noqa: F401
 )
 from .composite_encoder import CompositeEncoder # noqa: F401
fairseq/models/fairseq_model.py
@@ -297,3 +297,43 @@ class FairseqLanguageModel(BaseFairseqModel):
     def remove_head(self):
         """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
         raise NotImplementedError()
+
+
+class FairseqEncoderModel(BaseFairseqModel):
+    """Base class for encoder-only models.
+
+    Args:
+        encoder (FairseqEncoder): the encoder
+    """
+
+    def __init__(self, encoder):
+        super().__init__()
+        self.encoder = encoder
+        assert isinstance(self.encoder, FairseqEncoder)
+
+    def forward(self, src_tokens, src_lengths, **kwargs):
+        """
+        Run the forward pass for an encoder-only model.
+
+        Feeds a batch of tokens through the encoder to generate logits.
+
+        Args:
+            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
+            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
+
+        Returns:
+            the encoder's output, typically of shape `(batch, seq_len, vocab)`
+        """
+        return self.encoder(src_tokens, src_lengths)
+
+    def max_positions(self):
+        """Maximum length supported by the model."""
+        return self.encoder.max_positions()
+
+    @property
+    def supported_targets(self):
+        return {'future'}
+
+    def remove_head(self):
+        """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
+        raise NotImplementedError()
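
To illustrate how the new base class is meant to be used, here is a minimal sketch (not part of this commit) of an encoder-only model built on FairseqEncoderModel. The ToyEncoder class, its embedding/projection layers, and the embed_dim value are hypothetical illustrations; only FairseqEncoder, FairseqEncoderModel, and the build_model(cls, args, task) hook are existing fairseq APIs.

```python
import torch.nn as nn

from fairseq.models import FairseqEncoder, FairseqEncoderModel


class ToyEncoder(FairseqEncoder):
    """Hypothetical encoder: embeds tokens and projects them to vocabulary logits."""

    def __init__(self, dictionary, embed_dim=128):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim)
        self.proj = nn.Linear(embed_dim, len(dictionary))

    def forward(self, src_tokens, src_lengths):
        x = self.embed(src_tokens)   # (batch, src_len, embed_dim)
        return self.proj(x)          # (batch, src_len, vocab)

    def max_positions(self):
        # maximum source length this toy encoder accepts
        return 1024


class ToyEncoderModel(FairseqEncoderModel):
    """Encoder-only model that simply wraps ToyEncoder via the new base class."""

    @classmethod
    def build_model(cls, args, task):
        return cls(ToyEncoder(task.source_dictionary))
```

In practice such a model would also be registered with fairseq's @register_model decorator and define add_args and named architectures so it can be selected from the command line; the sketch omits those steps.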