Unverified Commit 78b08c26 authored by moto, committed by GitHub

[BC-Breaking] Update `extract_features` of Wav2Vec2Model (#1776)

* [BC-Breaking] Update `extract_features` of Wav2Vec2Model

Originally, the `extract_features` method returned the output of the
convolutional feature extractor module.

However, the features commonly used in downstream tasks are the outputs of the
intermediate layers of the transformer block in the encoder.

This commit updates the behavior of `extract_features` so that such features
can be retrieved selectively.
parent 599a82b7
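For illustration, here is a minimal sketch of how the updated API might look at a call site. This is hypothetical usage, not part of the commit; it assumes the `wav2vec2_base` factory function and the `num_out` argument that appear in the tests below, and that the factory is exposed from `torchaudio.models`.

    import torch
    from torchaudio.models import wav2vec2_base

    model = wav2vec2_base(num_out=32).eval()
    waveforms = torch.randn(2, 16000)

    # Previously, `extract_features` returned the single Tensor produced by the
    # convolutional feature extractor. After this change it returns a list of
    # Tensors, one per transformer layer (optionally truncated via `num_layers`),
    # together with the valid lengths.
    features, _ = model.extract_features(waveforms, num_layers=3)
    assert len(features) == 3
    assert features[0].ndim == 3  # (batch, frames, feature dimension)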
@@ -86,9 +86,10 @@ class TestFairseqIntegration(TorchaudioTestCase):
         imported = import_fairseq_model(original, 28).eval()
 
         x = torch.randn(batch_size, num_frames)
-        ref = original.feature_extractor(x).transpose(1, 2)
         hyp, _ = imported.extract_features(x)
-        self.assertEqual(ref, hyp)
+        refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
+        for i, (ref, _) in enumerate(refs['layer_results']):
+            self.assertEqual(hyp[i], ref.transpose(0, 1))
 
     @parameterized.expand(PRETRAINED_CONFIGS)
     def test_recreate_pretrained_model(self, config, factory_func):
@@ -14,11 +14,16 @@ from torchaudio_unittest.common_utils import (
 )
 from parameterized import parameterized
 
 
+def _name_func(testcase_func, _, param):
+    return f"{testcase_func.__name__}_{param[0][0].__name__}"
+
+
 factory_funcs = parameterized.expand([
     (wav2vec2_base, ),
     (wav2vec2_large, ),
     (wav2vec2_large_lv60k, ),
-])
+], name_func=_name_func)
 
 
 class TestWav2Vec2Model(TorchaudioTestCase):
@@ -47,20 +52,33 @@ class TestWav2Vec2Model(TorchaudioTestCase):
         self._smoke_test(torch.device('cuda'), dtype)
 
     @factory_funcs
-    def test_feature_extractor_smoke_test(self, factory_func):
+    def test_feature_extractor_test(self, factory_func):
         """`extract_features` method does not fail"""
         batch_size, num_frames = 3, 1024
         model = factory_func(num_out=32).eval()
+        num_layers = len(model.encoder.transformer.layers)
 
         torch.manual_seed(0)
         waveforms = torch.randn(batch_size, num_frames)
         lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
 
-        features, lengths = model.extract_features(waveforms, lengths)
-        assert features.ndim == 3
-        assert features.shape[0] == batch_size
-        assert lengths.shape == torch.Size([batch_size])
+        # Not providing num_layers returns all the intermediate features from
+        # tranformer layers
+        all_features, lengths_ = model.extract_features(waveforms, lengths, num_layers=None)
+        assert len(all_features) == num_layers
+        for features in all_features:
+            assert features.ndim == 3
+            assert features.shape[0] == batch_size
+        assert lengths_.shape == torch.Size([batch_size])
+
+        # Limiting the number of layers to `l`.
+        for l in range(1, num_layers + 1):
+            features, lengths_ = model.extract_features(waveforms, lengths, num_layers=l)
+            assert len(features) == l
+            for i in range(l):
+                self.assertEqual(all_features[i], features[i])
+            assert lengths_.shape == torch.Size([batch_size])
 
     @factory_funcs
     def test_batch_consistency(self, factory_func):
@@ -377,17 +377,21 @@ class Transformer(Module):
         self.dropout = nn.Dropout(dropout)
         self.layers = layers
 
-    def forward(
-            self,
-            x: Tensor,
-            attention_mask: Optional[Tensor] = None,
-    ):
+    def _preprocess(self, x: Tensor):
         x = x + self.pos_conv_embed(x)
 
         if self.layer_norm_first:
             x = self.layer_norm(x)
         x = self.dropout(x)
+        return x
+
+    def forward(
+            self,
+            x: Tensor,
+            attention_mask: Optional[Tensor] = None,
+    ):
+        x = self._preprocess(x)
 
         for layer in self.layers:
             if not (self.training and torch.rand(1).item() <= self.layer_drop):
                 x = layer(x, attention_mask)
@@ -397,6 +401,25 @@ class Transformer(Module):
 
         return x
 
+    def get_intermediate_outputs(
+            self,
+            x: Tensor,
+            attention_mask: Optional[Tensor] = None,
+            num_layers: Optional[int] = None,
+    ) -> List[Tensor]:
+        if num_layers is not None:
+            if not 0 < num_layers <= len(self.layers):
+                raise ValueError(f'`num_layers` must be between [1, {len(self.layers)}]')
+
+        ret: List[Tensor] = []
+        x = self._preprocess(x)
+        for layer in self.layers:
+            x = layer(x, attention_mask)
+            ret.append(x)
+            if num_layers is not None and len(ret) >= num_layers:
+                return ret
+        return ret
+
 
 class Encoder(Module):
     def __init__(
@@ -410,11 +433,11 @@ class Encoder(Module):
         self.transformer = transformer
         self.readout = readout
 
-    def forward(
+    def _preprocess(
             self,
             features: Tensor,
             lengths: Optional[Tensor] = None,
-    ) -> Tensor:
+    ) -> Tuple[Tensor, Optional[Tensor]]:
         x = self.feature_projection(features)
 
         mask: Optional[Tensor] = None
@@ -426,11 +449,28 @@ class Encoder(Module):
             # extend the mask to attention shape and set weight
             mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype)
             mask = mask.expand(batch_size, 1, max_len, max_len)
+        return x, mask
+
+    def forward(
+            self,
+            features: Tensor,
+            lengths: Optional[Tensor] = None,
+    ) -> Tensor:
+        x, mask = self._preprocess(features, lengths)
         x = self.transformer(x, attention_mask=mask)
         x = self.readout(x)
         return x
 
+    def extract_features(
+            self,
+            features: Tensor,
+            lengths: Optional[Tensor] = None,
+            num_layers: Optional[int] = None,
+    ) -> List[Tensor]:
+        x, masks = self._preprocess(features, lengths)
+        return self.transformer.get_intermediate_outputs(
+            x, attention_mask=masks, num_layers=num_layers)
+
 
 ################################################################################
 def _get_feature_extractor(
 from typing import Optional, Tuple, List
 
+import torch
 from torch import Tensor
 from torch.nn import Module
@@ -29,29 +30,41 @@ class Wav2Vec2Model(Module):
         self.feature_extractor = feature_extractor
         self.encoder = encoder
 
+    @torch.jit.export
     def extract_features(
             self,
             waveforms: Tensor,
             lengths: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+            num_layers: Optional[int] = None,
+    ) -> Tuple[List[Tensor], Optional[Tensor]]:
         """Extract feature vectors from raw waveforms
 
+        This returns the list of outputs from the intermediate layers of
+        transformer block in encoder.
+
         Args:
             waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
             lengths (Tensor or None, optional):
                 Indicates the valid length of each audio sample in the batch.
                 Shape: ``(batch, )``.
+            num_layers (int or None, optional):
+                If given, limit the number of intermediate layers to go through.
+                Providing `1` will stop the computation after going through one
+                intermediate layers. If not given, the outputs from all the
+                intermediate layers are returned.
 
         Returns:
-            Tensor:
-                Feature vectors.
+            List of Tensor:
+                Features from corresponding layers.
                 Shape: ``(batch, frames, feature dimention)``
             Tensor, optional:
                 Indicates the valid length of each feature in the batch, computed
                 based on the given ``lengths`` argument.
                 Shape: ``(batch, )``.
         """
-        return self.feature_extractor(waveforms, lengths)
+        x, lengths = self.feature_extractor(waveforms, lengths)
+        x = self.encoder.extract_features(x, lengths, num_layers)
+        return x, lengths
 
     def forward(
             self,
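Because `extract_features` is now decorated with `@torch.jit.export`, it should remain callable on a scripted model. The following is a minimal sketch of that usage, an assumption for illustration rather than part of the commit; it reuses the hypothetical `wav2vec2_base(num_out=32)` construction from the tests and assumes the model scripts cleanly.

    import torch
    from torchaudio.models import wav2vec2_base

    model = wav2vec2_base(num_out=32).eval()
    scripted = torch.jit.script(model)  # extract_features is kept via @torch.jit.export

    waveforms = torch.randn(1, 16000)
    # Request only the first intermediate layer; computation stops after it.
    features, _ = scripted.extract_features(waveforms, num_layers=1)
    last = features[-1]  # shape: (batch, frames, feature dimension)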