"tests/vscode:/vscode.git/clone" did not exist on "5d5fd243da84beea19ecb66cbce0fdfa449f5a33"
Unverified Commit f5d412ba authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[BugFix] Fix regression caused by mamba state dtype PR (#22998)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent 177e55e3
......@@ -650,8 +650,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
num_mamba_layers = self.config.num_hidden_layers \
// 2 // self.config.mb_per_layer + 1
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
self.vllm_config,
num_mamba_layers,
*self._get_mamba_cache_shape(),
self.lm_head.weight.dtype,
self.lm_head.weight.dtype,
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
attn_metadata = get_forward_context().attn_metadata
......
......@@ -767,8 +767,12 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP,
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
self.vllm_config,
num_mamba_layers,
*self._get_mamba_cache_shape(),
self.lm_head.weight.dtype,
self.lm_head.weight.dtype,
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment