"docs/source/vscode:/vscode.git/clone" did not exist on "358fc18613a737f3fbcebc5b2abed43386ff9cbc"
Unverified commit 6adefba3, authored by Sanchit Gandhi and committed by GitHub
Browse files

[FlaxSpeechEncoderDecoder] Fix input shape bug in weights init (#16728)

* [FlaxSpeechEncoderDecoder] Fix input shape bug in weights init

* make style
parent 1bac40db
......@@ -226,11 +226,15 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
else:
self.enc_to_dec_proj = None
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
......@@ -239,6 +243,10 @@ class FlaxSpeechEncoderDecoderModule(nn.Module):
for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
if add_adapter:
for _ in range(self.config.encoder.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
return input_lengths
def _get_encoder_module(self):
......@@ -432,8 +440,10 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
)
return unfreeze(init_variables["cache"])
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
return self.module._get_feat_extract_output_lengths(input_lengths)
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
@add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment