"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "3d2a2de8f75b973927d63b4cab63d9abb1a24722"
Unverified Commit c4e601c7 authored by Netanel Haber's avatar Netanel Haber Committed by GitHub
Browse files

Bugfix: Parakeet: `.conv.pointwise/depthwise_conv1/2.bias weigths` can exist...


Bugfix: Parakeet: `.conv.pointwise/depthwise_conv1/2.bias weigths` can exist even if `convolution_bias=False` (#40007)
Signed-off-by: default avatarNetanel Haber <58652339+netanel-haber@users.noreply.github.com>
parent 29057d3b
...@@ -99,6 +99,8 @@ class ProjectedParakeet(nn.Module): ...@@ -99,6 +99,8 @@ class ProjectedParakeet(nn.Module):
if target is None: if target is None:
target = buffers_dict.get(target_name) target = buffers_dict.get(target_name)
if target is None: if target is None:
if self._can_skip_missing_named_param(target_name):
continue
raise ValueError(f"Unknown weight: {name}") raise ValueError(f"Unknown weight: {name}")
weight_loader = getattr(target, "weight_loader", default_weight_loader) weight_loader = getattr(target, "weight_loader", default_weight_loader)
with torch.no_grad(): with torch.no_grad():
...@@ -107,6 +109,27 @@ class ProjectedParakeet(nn.Module): ...@@ -107,6 +109,27 @@ class ProjectedParakeet(nn.Module):
return loaded_params return loaded_params
def _can_skip_missing_named_param(self, target_name: str) -> bool:
if self.config.convolution_bias:
return False
# In transformers v5 (not v4), `convolution_bias=False` is
# propagated from parakeet config. If `False`, torch.conv1d will
# *skip registering the param*, thus it will be missing in the
# module's named params. *If* you happen to also have the bias
# tensors in the weights, it will cause a mismatch between the
# weights and the params.
# This allows us to have `convolution_bias=False` in the sound config,
# but still allow for the weights to exist.
return target_name.endswith(
(
".conv.pointwise_conv1.bias",
".conv.depthwise_conv.bias",
".conv.pointwise_conv2.bias",
)
)
EPSILON = 1e-5 EPSILON = 1e-5
LOG_ZERO_GUARD_VALUE = 2**-24 LOG_ZERO_GUARD_VALUE = 2**-24
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment