Commit 93c6fb12 (sglang)

Unverified commit, authored Apr 25, 2025 by michael-amd; committed via GitHub on Apr 25, 2025.

Fix: deepseek forward absorb (#5723)

Co-authored-by: ispobock <ispobaoke@163.com>

Parent: 11e27d09
Showing 1 changed file with 41 additions and 4 deletions.

python/sglang/srt/layers/layernorm.py (+41, -4), view file @ 93c6fb12
@@ -20,9 +20,10 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import (
@@ -32,6 +33,8 @@ if _is_cuda:
         rmsnorm,
     )
 
+if _is_hip:
+    from vllm._custom_ops import fused_add_rms_norm, rms_norm
 
 logger = logging.getLogger(__name__)
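The new _is_hip flag mirrors _is_cuda and gates a second kernel source: on ROCm, the fused add+RMSNorm and plain RMSNorm kernels come from vllm._custom_ops rather than sgl_kernel. A minimal inspection sketch (hypothetical code, not part of this commit) for confirming which backend the module selected in a given environment, using only the module-level flags added above:

# Hypothetical inspection snippet; _is_cuda/_is_hip are the module-level
# booleans introduced in the hunks above.
from sglang.srt.layers import layernorm

if layernorm._is_cuda:
    print("RMSNorm kernels: sgl_kernel (CUDA)")
elif layernorm._is_hip:
    print("RMSNorm kernels: vllm._custom_ops (ROCm/HIP)")
else:
    print("RMSNorm kernels: native PyTorch / vLLM fallback")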
@@ -46,23 +49,49 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        elif _is_hip:
+            return self.forward_hip(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            # NOTE: Remove this once the aiter kernel supports non-contiguous input
+            x = x.contiguous()
+        if residual is not None:
+            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = torch.empty_like(x)
+        rms_norm(out, x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            x = x.contiguous()
         orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
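With the new forward() dispatcher, callers use RMSNorm identically on CUDA, ROCm, or CPU, and torch.compile is routed to the pure-PyTorch forward_native so tracing never has to go through a foreign kernel launch. A minimal usage sketch, assuming sglang (plus the kernel backend for the current platform) is installed and the RMSNorm(hidden_size, eps) constructor implied by the context lines above:

# Minimal usage sketch (assumptions: sglang is installed; RMSNorm takes
# hidden_size and eps, as the context lines in this hunk imply).
import torch
from sglang.srt.layers.layernorm import RMSNorm

hidden_size = 4096
norm = RMSNorm(hidden_size, eps=1e-6)

x = torch.randn(2, hidden_size)
residual = torch.randn(2, hidden_size)

# No residual: returns the normalized tensor.
out = norm(x)

# With a residual: returns (rmsnorm(x + residual), x + residual); on the
# CUDA/HIP paths the fused kernel writes both results in place.
out, new_residual = norm(x, residual)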
@@ -88,6 +117,14 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_native(
         self,
         x: torch.Tensor,
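GemmaRMSNorm gains the same compile-aware dispatcher, minus a HIP branch. Its weight is zero-initialized (see the context line above) because the Gemma formulation scales the normalized activations by (1 + weight), so a fresh layer starts as plain RMS normalization. A reference sketch of that computation, written as an assumption about what forward_native does (its body lies outside this hunk):

import torch

def gemma_rms_norm_reference(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Compute in float32 for numerical stability, mirroring the RMSNorm
    # forward_native path shown above.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma scales by (1 + weight), hence the zero-initialized parameter.
    return (x * (1.0 + weight.float())).to(orig_dtype)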
@@ -139,8 +176,8 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not _is_cuda:
+if not (_is_cuda or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
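The widened guard means the vLLM fallback is imported only when neither a CUDA nor a HIP kernel path is available; on such platforms the module-level names RMSNorm and GemmaRMSNorm are rebound to vLLM's implementations. A small hypothetical check (not part of the commit) to see which class is actually in effect:

# Hypothetical check; relies only on the conditional re-import shown above.
from sglang.srt.layers.layernorm import RMSNorm

# "sglang.srt.layers.layernorm" when the CUDA or HIP path is active,
# "vllm.model_executor.layers.layernorm" when the fallback import ran.
print(RMSNorm.__module__)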