Unverified Commit 6997a25a, authored by qscqesze, committed by GitHub
Browse files

[Model] Remove useless code from MiniMax implementation (#23982)


Signed-off-by: QscQ <qscqesze@gmail.com>
Signed-off-by: qingjun <qingjun@minimaxi.com>
parent 28f350e1
@@ -83,17 +83,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
         variance = tensor_model_parallel_all_reduce(
             variance) / self.tp_world
         x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
-        weight = self.weight
-        if x.size(-1) != self.weight.size(0):
-            if self.weight.size(0) < x.size(-1):
-                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
-                full_weight = self.weight.repeat(repeat_count)
-                weight = full_weight[:x.size(-1)]
-            else:
-                weight = self.weight[:x.size(-1)]
-        x = x.to(orig_dtype) * weight
         return x
     def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment