sglang · Commits

Commit 5dd8c644 (unverified)

[Bug fix] Fix Gemma 2 and fix Gemma 3 multimodal with bs > 1 on NPU (#9871)

Authored Sep 08, 2025 by ssshinigami; committed by GitHub on Sep 08, 2025
Co-authored-by: Maksim <makcum888e@mail.ru>
Parent: ee21817c

Showing 3 changed files with 11 additions and 9 deletions:

    python/sglang/srt/layers/layernorm.py          +1 -6
    python/sglang/srt/layers/logits_processor.py   +9 -2
    python/sglang/srt/managers/schedule_policy.py  +1 -1
python/sglang/srt/layers/layernorm.py
```diff
@@ -288,16 +288,11 @@ class GemmaRMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        orig_dtype = x.dtype
         if residual is not None:
             x = x + residual
             residual = x
-        x = x.float()
-        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
-        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
+        x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
```
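The NPU path previously recomputed Gemma's RMSNorm op by op (mean of squares, rsqrt, and the `(1 + weight)` scale) in float32; the commit replaces that sequence with a single fused `torch_npu.npu_gemma_rms_norm` call, taking the first element of its output tuple. As a minimal sketch of the math the removed lines implemented (and what the fused kernel's first output is expected to match), here is a plain-PyTorch reference using standard `torch` ops in place of the `torch_npu` ones; `gemma_rms_norm_ref` is a name introduced here for illustration, not part of the commit:

```python
import torch

def gemma_rms_norm_ref(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Compute RMSNorm in float32 for numerical stability, as the
    # removed eager NPU code did, then cast back to the input dtype.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma scales by (1 + weight) rather than the usual bare weight.
    x = x * (1.0 + weight.float())
    return x.to(orig_dtype)
```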
python/sglang/srt/layers/logits_processor.py
```diff
@@ -46,10 +46,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file, use_intel_amx_backend
+from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend

 logger = logging.getLogger(__name__)

+_is_npu = is_npu()
+

 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -517,7 +519,12 @@ class LogitsProcessor(nn.Module):
             logits = logits[:, : self.config.vocab_size].float()

             if self.final_logit_softcapping:
-                fused_softcap(logits, self.final_logit_softcapping)
+                if not _is_npu:
+                    fused_softcap(logits, self.final_logit_softcapping)
+                else:
+                    logits = self.final_logit_softcapping * torch.tanh(
+                        logits / self.final_logit_softcapping
+                    )

             return logits
```
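The `fused_softcap` kernel is not available on NPU, so the commit caches `is_npu()` once at module import (keeping the platform check off the per-token hot path) and falls back to the eager formulation there: final logit softcapping computes `softcap * tanh(logits / softcap)`, which smoothly bounds every logit to the open interval `(-softcap, softcap)` while leaving small logits nearly unchanged. A standalone sketch of that fallback, using a hypothetical `softcap_eager` helper name introduced here for illustration:

```python
import torch

def softcap_eager(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # tanh saturates at +/-1, so outputs stay strictly inside
    # (-softcap, softcap); near zero, tanh(z) ~ z, so small logits
    # pass through almost unchanged.
    return softcap * torch.tanh(logits / softcap)

# Gemma 2's config sets final_logit_softcapping to 30.0: extreme
# logits are squashed toward +/-30, small ones are barely touched.
logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(softcap_eager(logits, 30.0))
```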
python/sglang/srt/managers/schedule_policy.py
```diff
@@ -550,7 +550,7 @@ class PrefillAdder:
                 )
             else:
                 # Make sure at least one page is available
-                trunc_len = self.rem_chunk_tokens - self.page_size + 1
+                trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
                 if trunc_len <= 0:
                     return AddReqResult.OTHER
```
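The old expression `rem_chunk_tokens - page_size + 1` kept room for at least one page but could yield a truncation length that is not a multiple of the page size; the new expression floor-divides by `page_size` and multiplies back, so the chunk is truncated to a whole number of pages, and the existing `trunc_len <= 0` check now fires exactly when less than one full page of budget remains. A toy comparison of the two formulas, with illustrative numbers not taken from the commit:

```python
rem_chunk_tokens, page_size = 300, 128

old_trunc = rem_chunk_tokens - page_size + 1           # 173: not a multiple of 128
new_trunc = rem_chunk_tokens // page_size * page_size  # 256: exactly 2 full pages

assert new_trunc % page_size == 0
# If rem_chunk_tokens < page_size, new_trunc is 0, and the scheduler
# returns AddReqResult.OTHER rather than scheduling a partial page.
```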