"vscode:/vscode.git/clone" did not exist on "7ee5d5093b369d5c55199bc4613c9afdecabe0b7"
Unverified Commit d6704dd0 authored by Roger Young's avatar Roger Young Committed by GitHub
Browse files

Fix MiniMax-M2 rmsnorm precision and remove useless code (#27627)


Signed-off-by: default avatarxuebi <xuebi@minimaxi.com>
Co-authored-by: default avatarxuebi <xuebi@minimaxi.com>
parent ecca3fee
...@@ -77,7 +77,7 @@ class MiniMaxText01RMSNormTP(CustomOp): ...@@ -77,7 +77,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
if self.tp_world > 1: if self.tp_world > 1:
variance = tensor_model_parallel_all_reduce(variance) / self.tp_world variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon) x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight x = (x * self.weight).to(orig_dtype)
return x return x
def forward( def forward(
......
...@@ -263,23 +263,6 @@ class MiniMaxM2DecoderLayer(nn.Module): ...@@ -263,23 +263,6 @@ class MiniMaxM2DecoderLayer(nn.Module):
# with the layer's index. # with the layer's index.
layer_idx = int(prefix.split(sep=".")[-1]) layer_idx = int(prefix.split(sep=".")[-1])
# TODO: support MTP
attn_window_size = getattr(config, "attn_window_size", None)
if attn_window_size is not None:
if isinstance(attn_window_size, list):
attn_window_size = attn_window_size[layer_idx]
elif isinstance(attn_window_size, int):
attn_window_size = attn_window_size
else:
raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
attn_window_size = None if attn_window_size <= 0 else attn_window_size
# different rope theta for full layer and swa layer
swa_rope_theta = getattr(config, "swa_rope_theta", -1)
# default to full rope theta
swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta
rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.self_attn = MiniMaxM2Attention( self.self_attn = MiniMaxM2Attention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -288,7 +271,6 @@ class MiniMaxM2DecoderLayer(nn.Module): ...@@ -288,7 +271,6 @@ class MiniMaxM2DecoderLayer(nn.Module):
rotary_dim=config.rotary_dim, rotary_dim=config.rotary_dim,
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
attn_window_size=attn_window_size,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps, rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False), qkv_bias=getattr(config, "attention_bias", False),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment