Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
968ef515
Unverified
Commit
968ef515
authored
Apr 21, 2025
by
michael-amd
Committed by
GitHub
Apr 21, 2025
Browse files
Support aiter RMSNorm in AMD (#5510)
Co-authored-by:
JieXin Liang
<
Alcanderian@users.noreply.github.com
>
parent
13432002
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
3 deletions
+18
-3
python/sglang/srt/layers/layernorm.py
python/sglang/srt/layers/layernorm.py
+18
-3
No files found.
python/sglang/srt/layers/layernorm.py
View file @
968ef515
...
@@ -20,9 +20,12 @@ import torch
...
@@ -20,9 +20,12 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
sglang.srt.custom_op
import
CustomOp
from
sglang.srt.custom_op
import
CustomOp
from
sglang.srt.utils
import
is_cuda
from
sglang.srt.utils
import
is_cuda
,
is_hip
logger
=
logging
.
getLogger
(
__name__
)
_is_cuda
=
is_cuda
()
_is_cuda
=
is_cuda
()
_is_hip
=
is_hip
()
if
_is_cuda
:
if
_is_cuda
:
from
sgl_kernel
import
(
from
sgl_kernel
import
(
...
@@ -32,8 +35,20 @@ if _is_cuda:
...
@@ -32,8 +35,20 @@ if _is_cuda:
rmsnorm
,
rmsnorm
,
)
)
if
_is_hip
:
logger
=
logging
.
getLogger
(
__name__
)
from
aiter.ops.rmsnorm
import
rms_norm
,
rmsnorm2d_fwd_with_add
rmsnorm
=
rms_norm
def
fused_add_rmsnorm
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
eps
:
float
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
rmsnorm2d_fwd_with_add
(
x
,
x
,
residual
,
residual
,
w
,
eps
)
return
x
,
residual
class
RMSNorm
(
CustomOp
):
class
RMSNorm
(
CustomOp
):
...
@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module):
...
@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module):
return
f
"
{
tuple
(
self
.
weight
.
shape
)
}
, eps=
{
self
.
eps
}
"
return
f
"
{
tuple
(
self
.
weight
.
shape
)
}
, eps=
{
self
.
eps
}
"
if
not
_is_cuda
:
if
not
(
_is_cuda
or
_is_hip
)
:
logger
.
info
(
logger
.
info
(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment