Unverified Commit 968ef515 authored by michael-amd's avatar michael-amd Committed by GitHub
Browse files

Support aiter RMSNorm in AMD (#5510)


Co-authored-by: default avatarJieXin Liang <Alcanderian@users.noreply.github.com>
parent 13432002
...@@ -20,9 +20,12 @@ import torch ...@@ -20,9 +20,12 @@ import torch
import torch.nn as nn import torch.nn as nn
from sglang.srt.custom_op import CustomOp from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda: if _is_cuda:
from sgl_kernel import ( from sgl_kernel import (
...@@ -32,8 +35,20 @@ if _is_cuda: ...@@ -32,8 +35,20 @@ if _is_cuda:
rmsnorm, rmsnorm,
) )
if _is_hip:
logger = logging.getLogger(__name__) from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
rmsnorm = rms_norm
def fused_add_rmsnorm(
x: torch.Tensor,
residual: torch.Tensor,
w: torch.Tensor,
eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
return x, residual
class RMSNorm(CustomOp): class RMSNorm(CustomOp):
...@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module): ...@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}" return f"{tuple(self.weight.shape)}, eps={self.eps}"
if not _is_cuda: if not (_is_cuda or _is_hip):
logger.info( logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment