Unverified commit ad233189, authored by Tyler Michael Smith, committed by GitHub
Browse files

[Bugfix] Fixup Mamba (#10004)


Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent bbc3619d
@@ -39,8 +39,8 @@ class MambaDecoderLayer(nn.Module):
super().__init__()
self.config = config
self.is_falcon_mamba = config.model_type == "falcon_mamba"
-mixer_rms_rps = config.mixer_rms_rps if self.is_falcon_mamba else None
-self.mamba = MambaMixer(hidden_size=config.hidden_size,
+mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
+self.mixer = MambaMixer(hidden_size=config.hidden_size,
ssm_state_size=config.state_size,
conv_kernel_size=config.conv_kernel,
intermediate_size=config.intermediate_size,
@@ -48,7 +48,7 @@ class MambaDecoderLayer(nn.Module):
use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias,
use_rms_norm=self.is_falcon_mamba,
-rms_norm_eps=mixer_rms_rps,
+rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -99,7 +99,6 @@ class MambaModel(nn.Module):
for i in range(config.num_hidden_layers):
decoder_layers.append(
MambaDecoderLayer(config,
-layer_idx=i,
cache_config=cache_config,
quant_config=quant_config))
self.layers = nn.ModuleList(decoder_layers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment