"vscode:/vscode.git/clone" did not exist on "dfd0c5fd8306a64cba5d281b5a419a815792d94c"
Unverified Commit 78b08c26 authored by moto, committed by GitHub

[BC-Breaking] Update `extract_features` of Wav2Vec2Model (#1776)

* [BC-Breaking] Update `extract_features` of Wav2Vec2Model

Originally, the `extract_features` method returned the result from
the convolutional feature extractor module.

The features commonly used in downstream tasks are the outputs from the
intermediate layers of the transformer block in the encoder.

This commit updates the behavior of `extract_features` to allow selectively
retrieving such features.
parent 599a82b7
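
For illustration, a minimal sketch of the updated API, assuming the `wav2vec2_base` factory exposed from `torchaudio.models` and the `num_out` argument used in the tests below; shapes and lengths are illustrative:

import torch
from torchaudio.models import wav2vec2_base

model = wav2vec2_base(num_out=32).eval()
waveforms = torch.randn(3, 16000)              # (batch, frames)
lengths = torch.tensor([16000, 12000, 8000])   # valid samples per waveform

# Default (num_layers=None): one Tensor per intermediate transformer layer.
all_features, valid = model.extract_features(waveforms, lengths)

# num_layers=2 stops the computation after the first two layers.
first_two, _ = model.extract_features(waveforms, lengths, num_layers=2)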
......@@ -86,9 +86,10 @@ class TestFairseqIntegration(TorchaudioTestCase):
imported = import_fairseq_model(original, 28).eval()
x = torch.randn(batch_size, num_frames)
-        ref = original.feature_extractor(x).transpose(1, 2)
        hyp, _ = imported.extract_features(x)
-        self.assertEqual(ref, hyp)
+        refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
+        for i, (ref, _) in enumerate(refs['layer_results']):
+            self.assertEqual(hyp[i], ref.transpose(0, 1))
@parameterized.expand(PRETRAINED_CONFIGS)
def test_recreate_pretrained_model(self, config, factory_func):
......
......@@ -14,11 +14,16 @@ from torchaudio_unittest.common_utils import (
)
from parameterized import parameterized
+def _name_func(testcase_func, _, param):
+    return f"{testcase_func.__name__}_{param[0][0].__name__}"
+
+
factory_funcs = parameterized.expand([
    (wav2vec2_base, ),
    (wav2vec2_large, ),
    (wav2vec2_large_lv60k, ),
-])
+], name_func=_name_func)
class TestWav2Vec2Model(TorchaudioTestCase):
......@@ -47,20 +52,33 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self._smoke_test(torch.device('cuda'), dtype)
@factory_funcs
-    def test_feature_extractor_smoke_test(self, factory_func):
+    def test_feature_extractor_test(self, factory_func):
        """`extract_features` method does not fail"""
        batch_size, num_frames = 3, 1024
        model = factory_func(num_out=32).eval()
+        num_layers = len(model.encoder.transformer.layers)

+        torch.manual_seed(0)
        waveforms = torch.randn(batch_size, num_frames)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
-        features, lengths = model.extract_features(waveforms, lengths)
-        assert lengths.shape == torch.Size([batch_size])
+        # Not providing num_layers returns all the intermediate features from
+        # transformer layers
+        all_features, lengths_ = model.extract_features(waveforms, lengths, num_layers=None)
+        assert len(all_features) == num_layers
+        for features in all_features:
+            assert features.ndim == 3
+            assert features.shape[0] == batch_size
+        assert lengths_.shape == torch.Size([batch_size])
+
+        # Limiting the number of layers to `l`.
+        for l in range(1, num_layers + 1):
+            features, lengths_ = model.extract_features(waveforms, lengths, num_layers=l)
+            assert len(features) == l
+            for i in range(l):
+                self.assertEqual(all_features[i], features[i])
+            assert lengths_.shape == torch.Size([batch_size])
@factory_funcs
def test_batch_consistency(self, factory_func):
......
......@@ -377,17 +377,21 @@ class Transformer(Module):
self.dropout = nn.Dropout(dropout)
self.layers = layers
-    def forward(
-            self,
-            x: Tensor,
-            attention_mask: Optional[Tensor] = None,
-    ):
+    def _preprocess(self, x: Tensor):
        x = x + self.pos_conv_embed(x)

        if self.layer_norm_first:
            x = self.layer_norm(x)

        x = self.dropout(x)
+        return x
+
+    def forward(
+            self,
+            x: Tensor,
+            attention_mask: Optional[Tensor] = None,
+    ):
+        x = self._preprocess(x)
for layer in self.layers:
if not (self.training and torch.rand(1).item() <= self.layer_drop):
x = layer(x, attention_mask)
......@@ -397,6 +401,25 @@ class Transformer(Module):
return x
+    def get_intermediate_outputs(
+            self,
+            x: Tensor,
+            attention_mask: Optional[Tensor] = None,
+            num_layers: Optional[int] = None,
+    ) -> List[Tensor]:
+        if num_layers is not None:
+            if not 0 < num_layers <= len(self.layers):
+                raise ValueError(f'`num_layers` must be within [1, {len(self.layers)}]')
+
+        ret: List[Tensor] = []
+        x = self._preprocess(x)
+        for layer in self.layers:
+            x = layer(x, attention_mask)
+            ret.append(x)
+            if num_layers is not None and len(ret) >= num_layers:
+                return ret
+        return ret
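
As a reading aid, the early return above means `num_layers=k` executes only the first `k` transformer layers. A small hypothetical check (assumes `transformer` is an instantiated `Transformer` whose feature dimension is 768, as in the base configuration):

x = torch.randn(3, 99, 768)  # (batch, frames, feature) after feature projection
outs = transformer.get_intermediate_outputs(x, num_layers=2)
assert len(outs) == 2        # layers beyond the second were never executed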
class Encoder(Module):
def __init__(
......@@ -410,11 +433,11 @@ class Encoder(Module):
self.transformer = transformer
self.readout = readout
-    def forward(
+    def _preprocess(
            self,
            features: Tensor,
            lengths: Optional[Tensor] = None,
-    ) -> Tensor:
+    ) -> Tuple[Tensor, Optional[Tensor]]:
x = self.feature_projection(features)
mask: Optional[Tensor] = None
......@@ -426,11 +449,28 @@ class Encoder(Module):
# extend the mask to attention shape and set weight
mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
mask = mask.expand(batch_size, 1, max_len, max_len)
+        return x, mask
+
+    def forward(
+            self,
+            features: Tensor,
+            lengths: Optional[Tensor] = None,
+    ) -> Tensor:
+        x, mask = self._preprocess(features, lengths)
x = self.transformer(x, attention_mask=mask)
x = self.readout(x)
return x
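
For intuition, the attention mask assembled in `_preprocess` is additive: padded positions receive a large negative weight so softmax effectively ignores them. A minimal standalone sketch with hypothetical sizes; the boolean-mask construction is an assumption, since the lines computing it are elided above:

import torch

lengths = torch.tensor([4, 2])   # valid frames per sample
batch_size, max_len = 2, 4
# True at padded positions (assumed construction; not shown in the hunk above)
padding = torch.arange(max_len)[None, :] >= lengths[:, None]
mask = -10000.0 * padding[:, None, None, :].to(torch.float32)
mask = mask.expand(batch_size, 1, max_len, max_len)  # broadcast over query positions
# Adding `mask` to the attention scores drives attention to padded keys toward zero.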
+    def extract_features(
+            self,
+            features: Tensor,
+            lengths: Optional[Tensor] = None,
+            num_layers: Optional[int] = None,
+    ) -> List[Tensor]:
+        x, masks = self._preprocess(features, lengths)
+        return self.transformer.get_intermediate_outputs(
+            x, attention_mask=masks, num_layers=num_layers)
################################################################################
def _get_feature_extractor(
......
from typing import Optional, Tuple, List
import torch
from torch import Tensor
from torch.nn import Module
......@@ -29,29 +30,41 @@ class Wav2Vec2Model(Module):
self.feature_extractor = feature_extractor
self.encoder = encoder
@torch.jit.export
def extract_features(
self,
waveforms: Tensor,
lengths: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+            num_layers: Optional[int] = None,
+    ) -> Tuple[List[Tensor], Optional[Tensor]]:
"""Extract feature vectors from raw waveforms
+        This returns the list of outputs from the intermediate layers of
+        the transformer block in the encoder.
Args:
waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
lengths (Tensor or None, optional):
Indicates the valid length of each audio sample in the batch.
Shape: ``(batch, )``.
+            num_layers (int or None, optional):
+                If given, limits the number of intermediate layers to go through.
+                Providing ``1`` stops the computation after going through one
+                intermediate layer. If not given, the outputs from all the
+                intermediate layers are returned.
Returns:
-            Tensor:
-                Feature vectors.
+            List of Tensor:
+                Features from the requested layers.
+                Shape: ``(batch, frames, feature dimension)``.
Tensor, optional:
Indicates the valid length of each feature in the batch, computed
based on the given ``lengths`` argument.
Shape: ``(batch, )``.
"""
-        return self.feature_extractor(waveforms, lengths)
+        x, lengths = self.feature_extractor(waveforms, lengths)
+        x = self.encoder.extract_features(x, lengths, num_layers)
+        return x, lengths
def forward(
self,
......
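
Because `extract_features` is decorated with `@torch.jit.export` and uses TorchScript-friendly annotations (`List[Tensor]`, `Optional[Tensor]`), it should remain callable after scripting. A hedged sketch, reusing the hypothetical `model`, `waveforms`, and `lengths` from the first example:

scripted = torch.jit.script(model)
feats, valid = scripted.extract_features(waveforms, lengths, num_layers=3)
assert len(feats) == 3   # one Tensor per traversed transformer layer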