Unverified Commit 9e00566b authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

Add Wav2Vec2 Adapter Weights to Flax (#15566)

* Add Wav2Vec2 Adapter Weights to Flax

* Suggested changes
parent 1f60bc46
...@@ -766,6 +766,73 @@ class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module): ...@@ -766,6 +766,73 @@ class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module):
return codevectors, perplexity return codevectors, perplexity
class FlaxWav2Vec2Adapter(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
# hidden_states require down-projection if feature dims don't match
if self.config.output_hidden_size != self.config.hidden_size:
self.proj = nn.Dense(
self.config.output_hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
else:
self.proj = self.proj_layer_norm = None
self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
# down-project hidden_states if required
if self.proj is not None and self.proj_layer_norm is not None:
hidden_states = self.proj(hidden_states)
hidden_states = self.proj_layer_norm(hidden_states)
hidden_states = self.layers(hidden_states)
return hidden_states
class FlaxWav2Vec2AdapterLayer(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
features=2 * self.config.output_hidden_size,
kernel_size=(self.config.adapter_kernel_size,),
strides=(self.config.adapter_stride,),
padding=((1, 1),),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
def __call__(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = nn.glu(hidden_states, axis=2)
return hidden_states
class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = [
FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_adapter_layers)
]
def __call__(self, hidden_states):
for conv_layer in self.layers:
hidden_states = conv_layer(hidden_states)
return hidden_states
class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
""" """
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
...@@ -840,7 +907,9 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): ...@@ -840,7 +907,9 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
rngs=rngs, rngs=rngs,
) )
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]): def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
return self.module._get_feat_extract_output_lengths(input_lengths) return self.module._get_feat_extract_output_lengths(input_lengths)
...@@ -860,6 +929,8 @@ class FlaxWav2Vec2Module(nn.Module): ...@@ -860,6 +929,8 @@ class FlaxWav2Vec2Module(nn.Module):
else: else:
raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.") raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
def __call__( def __call__(
self, self,
input_values, input_values,
...@@ -905,6 +976,9 @@ class FlaxWav2Vec2Module(nn.Module): ...@@ -905,6 +976,9 @@ class FlaxWav2Vec2Module(nn.Module):
hidden_states = encoder_outputs[0] hidden_states = encoder_outputs[0]
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict: if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:] return (hidden_states, extract_features) + encoder_outputs[1:]
...@@ -915,11 +989,15 @@ class FlaxWav2Vec2Module(nn.Module): ...@@ -915,11 +989,15 @@ class FlaxWav2Vec2Module(nn.Module):
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
) )
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]): def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
""" """
Computes the output length of the convolutional layers Computes the output length of the convolutional layers
""" """
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride): def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken # 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
...@@ -928,6 +1006,10 @@ class FlaxWav2Vec2Module(nn.Module): ...@@ -928,6 +1006,10 @@ class FlaxWav2Vec2Module(nn.Module):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride) input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths return input_lengths
...@@ -1021,11 +1103,17 @@ class FlaxWav2Vec2ForCTCModule(nn.Module): ...@@ -1021,11 +1103,17 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]): def _get_feat_extract_output_lengths(
self,
input_lengths: Union[jnp.ndarray, int],
add_adapter: Optional[bool] = None,
):
""" """
Computes the output length of the convolutional layers Computes the output length of the convolutional layers
""" """
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride): def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken # 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
...@@ -1034,6 +1122,10 @@ class FlaxWav2Vec2ForCTCModule(nn.Module): ...@@ -1034,6 +1122,10 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride) input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths return input_lengths
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment