Unverified Commit 40f2a085 authored by moto's avatar moto Committed by GitHub
Browse files

[BC-Breaking] Move fine-tune specific module out of wav2vec2 encoder (#1782)

Previously, the Linear module (called `readout`, which is used only for an ASR fine-tuning
task) was placed in encoder module. Conceptually, the encoder has nothing to
do with a module specific to fine-tuning / downstream task.

The problems here are that;
1. encoder can be also used in pre-training phase, in which such a module should
not present
2. The choice of Linear module is arbitral, and it is inconvenient for users
to have hard-coded module structure in encoder.

Therefore, this commit moves the Linear module out the encoder, and places it
as `aux` attribute of `Wav2Vec2Model`. (as a result `Wav2Vec2Model` has
`feature_extractor`, `encoder` and `aux` attributes.)

An alternative approach is to define another module and place `Wav2Vec2Model`
and aux module along each other. But that will introduce a new class we need
to maintain.
The expected use of `aux` is only  for 1. loading the pre-trained parameters 
published by `fairseq` (and it's variations from HF) and 2. creating the same model 
architectures for comparison experiment.
The newly introduced class will not be general enough for downstream adaptations, 
where there will be a bunch of different more complicated models. (i.e. s3prl)

Therefore, based on the minimalistic approach, we put them inside of `Wav2Vec2Model`.
parent e9cab8f8
...@@ -118,7 +118,7 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -118,7 +118,7 @@ class TestHFIntegration(TorchaudioTestCase):
# Readout # Readout
x = torch.randn(3, 10, config["hidden_size"]) x = torch.randn(3, 10, config["hidden_size"])
ref = original.lm_head(x) ref = original.lm_head(x)
hyp = imported.encoder.readout(x) hyp = imported.aux(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# The whole model without mask # The whole model without mask
x = torch.randn(3, 1024) x = torch.randn(3, 1024)
...@@ -195,8 +195,8 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -195,8 +195,8 @@ class TestHFIntegration(TorchaudioTestCase):
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# Readout # Readout
x = torch.randn(3, 10, config["hidden_size"]) x = torch.randn(3, 10, config["hidden_size"])
ref = imported.encoder.readout(x) ref = imported.aux(x)
hyp = reloaded.encoder.readout(x) hyp = reloaded.aux(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# The whole model # The whole model
x = torch.randn(3, 1024) x = torch.randn(3, 1024)
...@@ -208,7 +208,7 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -208,7 +208,7 @@ class TestHFIntegration(TorchaudioTestCase):
def test_recreate_pretrain(self, config, factory_func): def test_recreate_pretrain(self, config, factory_func):
"""Imported models can be recreated via a factory function without Hugging Face transformers.""" """Imported models can be recreated via a factory function without Hugging Face transformers."""
imported = import_huggingface_model(self._get_model(config)).eval() imported = import_huggingface_model(self._get_model(config)).eval()
reloaded = factory_func(num_out=imported.encoder.readout.out_features) reloaded = factory_func(num_out=imported.aux.out_features)
reloaded.load_state_dict(imported.state_dict()) reloaded.load_state_dict(imported.state_dict())
reloaded.eval() reloaded.eval()
self._test_recreate(imported, reloaded, config) self._test_recreate(imported, reloaded, config)
...@@ -217,7 +217,7 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -217,7 +217,7 @@ class TestHFIntegration(TorchaudioTestCase):
def test_recreate_finetune(self, config, factory_func): def test_recreate_finetune(self, config, factory_func):
"""Imported models can be recreated via a factory function without Hugging Face transformers.""" """Imported models can be recreated via a factory function without Hugging Face transformers."""
imported = import_huggingface_model(self._get_model(config)).eval() imported = import_huggingface_model(self._get_model(config)).eval()
reloaded = factory_func(num_out=imported.encoder.readout.out_features) reloaded = factory_func(num_out=imported.aux.out_features)
reloaded.load_state_dict(imported.state_dict()) reloaded.load_state_dict(imported.state_dict())
reloaded.eval() reloaded.eval()
self._test_recreate(imported, reloaded, config) self._test_recreate(imported, reloaded, config)
...@@ -426,12 +426,10 @@ class Encoder(Module): ...@@ -426,12 +426,10 @@ class Encoder(Module):
self, self,
feature_projection: Module, feature_projection: Module,
transformer: Module, transformer: Module,
readout: Module,
): ):
super().__init__() super().__init__()
self.feature_projection = feature_projection self.feature_projection = feature_projection
self.transformer = transformer self.transformer = transformer
self.readout = readout
def _preprocess( def _preprocess(
self, self,
...@@ -458,7 +456,6 @@ class Encoder(Module): ...@@ -458,7 +456,6 @@ class Encoder(Module):
) -> Tensor: ) -> Tensor:
x, mask = self._preprocess(features, lengths) x, mask = self._preprocess(features, lengths)
x = self.transformer(x, attention_mask=mask) x = self.transformer(x, attention_mask=mask)
x = self.readout(x)
return x return x
def extract_features( def extract_features(
...@@ -561,7 +558,6 @@ def _get_encoder( ...@@ -561,7 +558,6 @@ def _get_encoder(
dropout: float, dropout: float,
layer_norm_first: bool, layer_norm_first: bool,
layer_drop: float, layer_drop: float,
num_out: int,
) -> Encoder: ) -> Encoder:
""" """
Args: Args:
...@@ -720,8 +716,4 @@ def _get_encoder( ...@@ -720,8 +716,4 @@ def _get_encoder(
layer_norm_first=not layer_norm_first, layer_norm_first=not layer_norm_first,
layer_drop=layer_drop, layer_drop=layer_drop,
) )
readout = nn.Linear( return Encoder(feature_projection, transformer)
in_features=embed_dim,
out_features=num_out,
)
return Encoder(feature_projection, transformer, readout)
...@@ -20,15 +20,20 @@ class Wav2Vec2Model(Module): ...@@ -20,15 +20,20 @@ class Wav2Vec2Model(Module):
encoder (torch.nn.Module): encoder (torch.nn.Module):
Encoder that converts the audio features into the sequence of probability Encoder that converts the audio features into the sequence of probability
distribution (in negative log-likelihood) over labels. distribution (in negative log-likelihood) over labels.
aux (torch.nn.Module or None, optional):
Auxiliary module. If provided, the output from encoder is passed to this module.
""" """
def __init__( def __init__(
self, self,
feature_extractor: Module, feature_extractor: Module,
encoder: Module, encoder: Module,
aux: Optional[Module] = None,
): ):
super().__init__() super().__init__()
self.feature_extractor = feature_extractor self.feature_extractor = feature_extractor
self.encoder = encoder self.encoder = encoder
self.aux = aux
@torch.jit.export @torch.jit.export
def extract_features( def extract_features(
...@@ -89,7 +94,10 @@ class Wav2Vec2Model(Module): ...@@ -89,7 +94,10 @@ class Wav2Vec2Model(Module):
Shape: ``(batch, )``. Shape: ``(batch, )``.
""" """
x, lengths = self.feature_extractor(waveforms, lengths) x, lengths = self.feature_extractor(waveforms, lengths)
return self.encoder(x, lengths), lengths x = self.encoder(x, lengths)
if self.aux is not None:
x = self.aux(x)
return x, lengths
def _get_model( def _get_model(
...@@ -108,7 +116,7 @@ def _get_model( ...@@ -108,7 +116,7 @@ def _get_model(
encoder_dropout: float, encoder_dropout: float,
encoder_layer_norm_first: bool, encoder_layer_norm_first: bool,
encoder_layer_drop: float, encoder_layer_drop: float,
encoder_num_out: int, aux_num_out: int,
) -> Wav2Vec2Model: ) -> Wav2Vec2Model:
if extractor_conv_layer_config is None: if extractor_conv_layer_config is None:
extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
...@@ -129,9 +137,12 @@ def _get_model( ...@@ -129,9 +137,12 @@ def _get_model(
dropout=encoder_dropout, dropout=encoder_dropout,
layer_norm_first=encoder_layer_norm_first, layer_norm_first=encoder_layer_norm_first,
layer_drop=encoder_layer_drop, layer_drop=encoder_layer_drop,
num_out=encoder_num_out,
) )
return Wav2Vec2Model(feature_extractor, encoder) aux = torch.nn.Linear(
in_features=encoder_embed_dim,
out_features=aux_num_out,
)
return Wav2Vec2Model(feature_extractor, encoder, aux)
def wav2vec2_base(num_out: int) -> Wav2Vec2Model: def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
...@@ -172,7 +183,7 @@ def wav2vec2_base(num_out: int) -> Wav2Vec2Model: ...@@ -172,7 +183,7 @@ def wav2vec2_base(num_out: int) -> Wav2Vec2Model:
encoder_dropout=0.1, encoder_dropout=0.1,
encoder_layer_norm_first=False, encoder_layer_norm_first=False,
encoder_layer_drop=0.1, encoder_layer_drop=0.1,
encoder_num_out=num_out, aux_num_out=num_out,
) )
...@@ -214,7 +225,7 @@ def wav2vec2_large(num_out: int) -> Wav2Vec2Model: ...@@ -214,7 +225,7 @@ def wav2vec2_large(num_out: int) -> Wav2Vec2Model:
encoder_dropout=0.1, encoder_dropout=0.1,
encoder_layer_norm_first=False, encoder_layer_norm_first=False,
encoder_layer_drop=0.1, encoder_layer_drop=0.1,
encoder_num_out=num_out, aux_num_out=num_out,
) )
...@@ -256,5 +267,5 @@ def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model: ...@@ -256,5 +267,5 @@ def wav2vec2_large_lv60k(num_out: int) -> Wav2Vec2Model:
encoder_dropout=0.0, encoder_dropout=0.0,
encoder_layer_norm_first=True, encoder_layer_norm_first=True,
encoder_layer_drop=0.1, encoder_layer_drop=0.1,
encoder_num_out=num_out, aux_num_out=num_out,
) )
...@@ -46,7 +46,7 @@ def _parse_config(w2v_model, num_out): ...@@ -46,7 +46,7 @@ def _parse_config(w2v_model, num_out):
'encoder_dropout': encoder.layers[0].dropout3.p, 'encoder_dropout': encoder.layers[0].dropout3.p,
'encoder_layer_norm_first': encoder.layer_norm_first, 'encoder_layer_norm_first': encoder.layer_norm_first,
'encoder_layer_drop': encoder.layerdrop, 'encoder_layer_drop': encoder.layerdrop,
'encoder_num_out': num_out, 'aux_num_out': num_out,
} }
return config return config
...@@ -110,7 +110,7 @@ def _map_key(key): ...@@ -110,7 +110,7 @@ def _map_key(key):
match = re.match(r"proj\.(weight|bias)", key) match = re.match(r"proj\.(weight|bias)", key)
# Encoder - Readout layer # Encoder - Readout layer
if match: if match:
return f"encoder.readout.{match.group(1)}" return f"aux.{match.group(1)}"
raise ValueError(f'Unexpected key: {key_}') raise ValueError(f'Unexpected key: {key_}')
......
...@@ -26,7 +26,7 @@ def _get_config(cfg): ...@@ -26,7 +26,7 @@ def _get_config(cfg):
'encoder_dropout': cfg.hidden_dropout, 'encoder_dropout': cfg.hidden_dropout,
'encoder_layer_norm_first': cfg.do_stable_layer_norm, 'encoder_layer_norm_first': cfg.do_stable_layer_norm,
'encoder_layer_drop': cfg.layerdrop, 'encoder_layer_drop': cfg.layerdrop,
'encoder_num_out': cfg.vocab_size, 'aux_num_out': cfg.vocab_size,
} }
return config return config
...@@ -42,7 +42,7 @@ def _build(config, original): ...@@ -42,7 +42,7 @@ def _build(config, original):
imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict()) imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict()) imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
if original.__class__.__name__ == 'Wav2Vec2ForCTC': if original.__class__.__name__ == 'Wav2Vec2ForCTC':
imported.encoder.readout.load_state_dict(original.lm_head.state_dict()) imported.aux.load_state_dict(original.lm_head.state_dict())
return imported return imported
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment