Unverified commit aeee7ef9, authored by Rishapveer Singh, committed by GitHub
Browse files

[Bugfix] Fix k_proj's bias for GLM-ASR (#40160)


Signed-off-by: Rishapveer Singh <singhrishapveer@gmail.com>
parent cda19ecf
...@@ -66,7 +66,7 @@ from .interfaces import ( ...@@ -66,7 +66,7 @@ from .interfaces import (
SupportsTranscription, SupportsTranscription,
) )
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .whisper import ISO639_1_SUPPORTED_LANGS from .whisper import ISO639_1_SUPPORTED_LANGS, _create_fake_bias_for_k_proj
class GlmAsrEncoderRotaryEmbedding(nn.Module): class GlmAsrEncoderRotaryEmbedding(nn.Module):
...@@ -499,6 +499,8 @@ class GlmAsrEncoder(nn.Module): ...@@ -499,6 +499,8 @@ class GlmAsrEncoder(nn.Module):
"""Custom weight loading to handle q_proj/k_proj/v_proj -> qkv_proj mapping.""" """Custom weight loading to handle q_proj/k_proj/v_proj -> qkv_proj mapping."""
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
weights = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment