Unverified Commit b0feda09 authored by HAI's avatar HAI Committed by GitHub
Browse files

Revert "Support aiter RMSNorm in AMD" (#5646)

parent 6b6e7487
......@@ -20,12 +20,9 @@ import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
from sgl_kernel import (
......@@ -35,20 +32,8 @@ if _is_cuda:
rmsnorm,
)
if _is_hip:
from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
rmsnorm = rms_norm
def fused_add_rmsnorm(
x: torch.Tensor,
residual: torch.Tensor,
w: torch.Tensor,
eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
return x, residual
logger = logging.getLogger(__name__)
class RMSNorm(CustomOp):
......@@ -154,7 +139,7 @@ class Gemma3RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
if not (_is_cuda or _is_hip):
if not _is_cuda:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment