Unverified Commit 357921aa authored by Xinyuan Tong's avatar Xinyuan Tong Committed by GitHub
Browse files

Fix: Minicpm (#7612)


Signed-off-by: default avatarXinyuan Tong <justinning0323@outlook.com>
parent c071198c
......@@ -32,7 +32,7 @@ from transformers.activations import ACT2FN
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.models.whisper.modeling_whisper import (
WHISPER_ATTENTION_CLASSES,
WhisperAttention,
WhisperConfig,
WhisperEncoder,
)
......@@ -1090,7 +1090,7 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig, layer_idx: int = None):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
self.self_attn = WhisperAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment