Unverified Commit a24cb916 authored by qscqesze's avatar qscqesze Committed by GitHub
Browse files

[Model] Fix minimax model cache & lm_head precision (#19592)


Signed-off-by: default avatarqingjun <qingjun@minimaxi.com>
parent 7e8d97dd
...@@ -856,7 +856,7 @@ class MiniMaxText01Model(nn.Module): ...@@ -856,7 +856,7 @@ class MiniMaxText01Model(nn.Module):
self._dtype = _dummy.dtype self._dtype = _dummy.dtype
del _dummy del _dummy
self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
cache_shape=self.cache_shape) cache_shape=self.cache_shape)
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
...@@ -1021,7 +1021,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1021,7 +1021,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.lm_head.float()
flash_layer_count = sum(1 for attn_type in self.config.attn_type_list flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
if attn_type == 1) if attn_type == 1)
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)] self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
...@@ -1054,7 +1054,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1054,7 +1054,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states.float(),
sampling_metadata) sampling_metadata)
return logits return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment