Unverified Commit e9fda392 authored by YiYi Xu, committed by GitHub

remove F.rms_norm for now (#11126)

up
parent 2c1ed50f
@@ -550,16 +550,6 @@ class RMSNorm(nn.Module):
             hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
-        elif is_torch_version(">=", "2.4"):
-            if self.weight is not None:
-                # convert into half-precision if necessary
-                if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                    hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = nn.functional.rms_norm(
-                hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
-            )
-            if self.bias is not None:
-                hidden_states = hidden_states + self.bias
         else:
             input_dtype = hidden_states.dtype
             variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
...
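For reference, after this change non-NPU devices always take the manual fallback path that begins in the context lines above: compute the mean of squares in float32, scale by the reciprocal root, then apply the optional affine weight. Below is a minimal self-contained sketch of that computation; the SimpleRMSNorm class and its constructor parameters are illustrative stand-ins, not the diffusers RMSNorm API.

import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    # Hypothetical sketch mirroring the fallback branch kept by this commit;
    # not the actual diffusers implementation.
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Accumulate the mean of squares in float32 for numerical stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        if self.weight is not None:
            # Convert into half precision if the weight is fp16/bf16,
            # matching the dtype handling in the diff above.
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
        else:
            hidden_states = hidden_states.to(input_dtype)
        return hidden_states


# Usage example:
x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
norm = SimpleRMSNorm(dim=64)
print(norm(x).shape)  # torch.Size([2, 8, 64])

The removed branch instead delegated to torch.nn.functional.rms_norm, which is available since PyTorch 2.4; this commit drops that fast path "for now" and keeps only the torch_npu path and the manual float32 fallback.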