"...python/git@developer.sourcefind.cn:change/sglang.git" did not exist on "6840a7bbb2e6f3c5b00967f02d908648f9bd72fb"
Unverified commit 8d114f25, authored by sogalin and committed via GitHub
Browse files

Fix RMSNorm API call mismatch issue. (#10032)


Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
parent 0e78c63c
...@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union ...@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from packaging.version import Version
from sglang.srt.custom_op import CustomOp from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import ( from sglang.srt.utils import (
...@@ -49,8 +50,11 @@ if _use_aiter: ...@@ -49,8 +50,11 @@ if _use_aiter:
from aiter import rmsnorm2d_fwd as rms_norm from aiter import rmsnorm2d_fwd as rms_norm
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
elif _is_hip: elif _is_hip:
import vllm
from vllm._custom_ops import fused_add_rms_norm, rms_norm from vllm._custom_ops import fused_add_rms_norm, rms_norm
_vllm_version = Version(vllm.__version__)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if _is_npu: if _is_npu:
...@@ -127,8 +131,21 @@ class RMSNorm(CustomOp): ...@@ -127,8 +131,21 @@ class RMSNorm(CustomOp):
# NOTE: Remove this if aiter kernel supports discontinuous input # NOTE: Remove this if aiter kernel supports discontinuous input
x = x.contiguous() x = x.contiguous()
if residual is not None: if residual is not None:
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon) if _vllm_version < Version("0.9"):
return x, residual fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
else:
residual_out = torch.empty_like(x)
output = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
residual_out,
residual,
self.weight.data,
self.variance_epsilon,
)
return output, residual_out
out = torch.empty_like(x) out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon) rms_norm(out, x, self.weight.data, self.variance_epsilon)
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment