"vscode:/vscode.git/clone" did not exist on "cb2bc69b5538bf1f3d1b999987b28b5d1ac90ba0"
Unverified Commit 968ef515 authored by michael-amd's avatar michael-amd Committed by GitHub
Browse files

Support aiter RMSNorm in AMD (#5510)


Co-authored-by: default avatarJieXin Liang <Alcanderian@users.noreply.github.com>
parent 13432002
......@@ -20,9 +20,12 @@ import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
from sgl_kernel import (
......@@ -32,8 +35,20 @@ if _is_cuda:
rmsnorm,
)
if _is_hip:
logger = logging.getLogger(__name__)
from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
rmsnorm = rms_norm
def fused_add_rmsnorm(
x: torch.Tensor,
residual: torch.Tensor,
w: torch.Tensor,
eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
return x, residual
class RMSNorm(CustomOp):
......@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
if not _is_cuda:
if not (_is_cuda or _is_hip):
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment