Commit 40f16872 authored by Myle Ott, committed by Facebook Github Bot

Add return_all_hiddens flag to hub interface

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/909

Differential Revision: D16532919

Pulled By: myleott

fbshipit-source-id: 16ce884cf3d84579026e4406a75ba3c01a128dbd
parent 17fcc72a
@@ -20,6 +20,7 @@ Model | Description | # params | Download
 ```
 >>> import torch
 >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
+>>> roberta.eval()  # disable dropout (or leave in train mode to finetune)
 ```

 ##### Apply Byte-Pair Encoding (BPE) to input text:
@@ -31,9 +32,16 @@ tensor([ 0, 31414, 232, 328, 2])

 ##### Extract features from RoBERTa:
 ```
->>> features = roberta.extract_features(tokens)
->>> features.size()
+>>> last_layer_features = roberta.extract_features(tokens)
+>>> last_layer_features.size()
 torch.Size([1, 5, 1024])
+
+>>> all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
+>>> len(all_layers)
+25
+>>> torch.all(all_layers[-1] == last_layer_features)
+tensor(1, dtype=torch.uint8)
 ```

 ##### Use RoBERTa for sentence-pair classification tasks:
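As an illustration of what the new `return_all_hiddens` flag enables, here is a minimal sketch (not part of this commit) that mean-pools an intermediate layer's hidden states into one vector per sentence. It assumes the `roberta.encode()` helper used in the surrounding README examples and the `roberta.large` checkpoint, whose 25 returned states are the embedding output plus 24 transformer layers:

```
>>> import torch
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
>>> roberta.eval()  # disable dropout
>>> tokens = roberta.encode('Hello world!')

>>> # every element of the returned list is B x T x C, same shape as the default last-layer output
>>> all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
>>> all_layers[12].size()
torch.Size([1, 5, 1024])

>>> # e.g. mean-pool layer 12 over the token dimension to get one vector per sentence
>>> pooled = all_layers[12].mean(dim=1)
>>> pooled.size()
torch.Size([1, 1024])
```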
@@ -33,6 +33,7 @@ class RobertaHubInterface(nn.Module):
     Load RoBERTa::

         >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
+        >>> roberta.eval()  # disable dropout (or leave in train mode to finetune)

     Apply Byte-Pair Encoding (BPE) to input text::
@@ -42,10 +43,16 @@ class RobertaHubInterface(nn.Module):
     Extract features from RoBERTa::

-        >>> features = roberta.extract_features(tokens)
-        >>> features.size()
+        >>> last_layer_features = roberta.extract_features(tokens)
+        >>> last_layer_features.size()
         torch.Size([1, 5, 1024])
+
+        >>> all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
+        >>> len(all_layers)
+        25
+        >>> torch.all(all_layers[-1] == last_layer_features)
+        tensor(1, dtype=torch.uint8)

     Use RoBERTa for sentence-pair classification tasks::

         >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')  # already finetuned
@@ -100,11 +107,20 @@ class RobertaHubInterface(nn.Module):
         tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=True)
         return tokens.long()

-    def extract_features(self, tokens: torch.LongTensor) -> torch.Tensor:
+    def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
         if tokens.dim() == 1:
             tokens = tokens.unsqueeze(0)
-        features, _ = self.model(tokens.to(device=self.device), features_only=True)
-        return features
+        features, extra = self.model(
+            tokens.to(device=self.device),
+            features_only=True,
+            return_all_hiddens=return_all_hiddens,
+        )
+        if return_all_hiddens:
+            # convert from T x B x C -> B x T x C
+            inner_states = extra['inner_states']
+            return [inner_state.transpose(0, 1) for inner_state in inner_states]
+        else:
+            return features  # just the last layer's features

     def register_classification_head(
         self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
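Taken together, the new code path returns a list of hidden states already transposed from T x B x C to B x T x C, with the final element matching the default last-layer features. Below is a minimal sanity check written against that documented behaviour; the checkpoint name and the `encode()` call are taken from the README example above, not from this diff:

```
import torch

roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout

tokens = roberta.encode('Hello world!')

last_layer_features = roberta.extract_features(tokens)                  # B x T x C
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)  # list of B x T x C tensors

# each hidden state comes back batch-first, matching the default output's shape
assert all(layer.shape == last_layer_features.shape for layer in all_layers)

# the last inner state is exactly what the default call returns
assert torch.all(all_layers[-1] == last_layer_features)
```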