Unverified Commit cecada52 authored by Leo Jiang, committed by GitHub

NPU adaption for RMSNorm (#10534)



* NPU adaption for RMSNorm

* NPU adaption for RMSNorm

---------
Co-authored-by: J石页 <jiangshuo9@h-partners.com>
parent 17d99c4d
@@ -20,7 +20,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..utils import is_torch_version
+from ..utils import is_torch_npu_available, is_torch_version
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
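
For context, the new utils import enables the usual lazy-import gate: torch_npu is only imported behind the availability check, so CUDA/CPU installs never attempt to load the Ascend extension. A minimal sketch, assuming only that diffusers.utils exposes is_torch_npu_available (which this hunk adds to the module's imports):

import torch

from diffusers.utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # noqa: F401  # registers the "npu" device backend with torch

    device = torch.device("npu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
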
@@ -505,19 +505,30 @@ class RMSNorm(nn.Module):
                 self.bias = nn.Parameter(torch.zeros(dim))
 
     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-
-        if self.weight is not None:
-            # convert into half-precision if necessary
-            if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = hidden_states * self.weight
-            if self.bias is not None:
-                hidden_states = hidden_states + self.bias
-        else:
-            hidden_states = hidden_states.to(input_dtype)
+        if is_torch_npu_available():
+            import torch_npu
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
+            if self.bias is not None:
+                hidden_states = hidden_states + self.bias
+        else:
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+                hidden_states = hidden_states * self.weight
+                if self.bias is not None:
+                    hidden_states = hidden_states + self.bias
+            else:
+                hidden_states = hidden_states.to(input_dtype)
 
         return hidden_states
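
Read as a whole, the patched forward dispatches to the fused torch_npu.npu_rms_norm kernel on Ascend hardware and keeps the original float32 reference path everywhere else. Below is a minimal, self-contained sketch of the same logic that runs on plain PyTorch (the NPU branch only activates when torch_npu imports); RMSNormSketch and _npu_available are hypothetical names, the latter standing in for diffusers' is_torch_npu_available:

import torch
import torch.nn as nn


def _npu_available() -> bool:
    # Hypothetical stand-in for diffusers.utils.is_torch_npu_available.
    try:
        import torch_npu  # noqa: F401

        return hasattr(torch, "npu") and torch.npu.is_available()
    except ImportError:
        return False


class RMSNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned scale, as in the diff
        self.bias = None  # diffusers only allocates a bias when requested

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if _npu_available():
            import torch_npu

            # Fused Ascend kernel: one call replaces the pow/mean/rsqrt chain.
            # npu_rms_norm returns a tuple; [0] is the normalized tensor.
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
        else:
            # Reference path: compute the variance in float32 for numerical
            # stability, cast back to the weight dtype, then apply the scale.
            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
        if self.bias is not None:
            hidden_states = hidden_states + self.bias
        return hidden_states


# Quick shape/finiteness check on the fallback path:
x = torch.randn(2, 8, 64)
out = RMSNormSketch(64)(x)
assert out.shape == x.shape and torch.isfinite(out).all()

On the fallback path this matches the textbook definition x / sqrt(mean(x^2) + eps) * w; the NPU branch computes the same quantity in a single fused kernel, which is the point of the adaptation.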