sglang · Commit 8e2ac2e6 (Unverified)

[NPU] fix pp_size>1 (#12195)

Authored Oct 30, 2025 by Makcum888e; committed by GitHub on Oct 30, 2025
Parent commit: 17a57fd8
Showing 4 changed files with 21 additions and 23 deletions (+21, -23):

  python/sglang/srt/distributed/parallel_state.py    +9   -5
  python/sglang/srt/mem_cache/memory_pool.py          +2   -2
  python/sglang/srt/model_executor/model_runner.py    +3   -1
  python/sglang/srt/utils/common.py                   +7   -15
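A note on the shared pattern: parallel_state.py and common.py replace hard-coded torch.cuda calls (or an explicit _is_npu branch) with torch.get_device_module(), which returns the accelerator module for the active backend, e.g. torch.cuda on CUDA builds and torch.npu on Ascend builds; the other two files fix per-stage layer indexing so the Ascend KV pool works with pipeline parallelism. A minimal sketch of the device-module idea, not taken from the commit itself:

import torch

def current_accelerator_stream():
    # Sketch only: torch.get_device_module() resolves to torch.cuda or
    # torch.npu depending on the installed backend, so the same call works
    # on both; an accelerator build is assumed here.
    return torch.get_device_module().current_stream()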
python/sglang/srt/distributed/parallel_state.py

@@ -68,7 +68,7 @@ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
 @dataclass
 class GraphCaptureContext:
-    stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
+    stream: torch.get_device_module().Stream

 @dataclass
@@ -498,7 +498,7 @@ class GroupCoordinator:
             maybe_pynccl_context = nullcontext()
         else:
             maybe_pynccl_context = pynccl_comm.change_state(
-                enable=True, stream=torch.cuda.current_stream()
+                enable=True, stream=torch.get_device_module().current_stream()
             )
         pymscclpp_comm = self.pymscclpp_comm
@@ -555,7 +555,7 @@ class GroupCoordinator:
             and input_.symmetric_memory
         ):
             with self.pynccl_comm.change_state(
-                enable=True, stream=torch.cuda.current_stream()
+                enable=True, stream=torch.get_device_module().current_stream()
             ):
                 self.pynccl_comm.all_reduce(input_)
             return input_
@@ -655,7 +655,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for reduce_scatterv"
@@ -779,7 +781,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for all_gatherv"
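For illustration, the reworked GraphCaptureContext annotation evaluates to the Stream type of whichever backend is active, so construction is also backend-agnostic. A hedged sketch, with the class re-declared locally rather than imported from sglang, and an accelerator build assumed:

from dataclasses import dataclass

import torch

@dataclass
class GraphCaptureContext:
    # Evaluated once at class-definition time: torch.cuda.Stream on CUDA
    # builds, torch.npu.Stream on Ascend builds (after torch_npu is loaded).
    stream: torch.get_device_module().Stream

def make_capture_context():
    # Sketch only: the matching constructor call needs no _is_npu branch.
    return GraphCaptureContext(stream=torch.get_device_module().Stream())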
python/sglang/srt/mem_cache/memory_pool.py

@@ -1137,10 +1137,10 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         torch_npu._npu_reshape_and_cache(
             key=cache_k,
             value=cache_v,
-            key_cache=self.k_buffer[layer_id].view(
+            key_cache=self.k_buffer[layer_id - self.start_layer].view(
                 -1, self.page_size, self.head_num, self.head_dim
             ),
-            value_cache=self.v_buffer[layer_id].view(
+            value_cache=self.v_buffer[layer_id - self.start_layer].view(
                 -1, self.page_size, self.head_num, self.head_dim
             ),
             slot_indices=loc,
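The indexing change matters because, with pp_size > 1, each pipeline stage's KV pool allocates k_buffer/v_buffer only for its own layer range, while layer_id remains a global index. A toy illustration of the offset (names and numbers are illustrative, not sglang's):

# Layers owned by this pipeline stage (illustrative values).
start_layer, end_layer = 16, 32
k_buffer = [f"k_buf_layer_{i}" for i in range(start_layer, end_layer)]

global_layer_id = 20                          # global index from the model
local_index = global_layer_id - start_layer   # 4
assert k_buffer[local_index] == "k_buf_layer_20"
# Indexing with the global id (k_buffer[global_layer_id]) would run past the
# 16 locally allocated buffers, which is the failure this commit fixes on NPU.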
python/sglang/srt/model_executor/model_runner.py

@@ -1659,9 +1659,11 @@ class ModelRunner:
                     get_attention_tp_size()
                 ),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.use_mla_backend and is_nsa_model:
             self.token_to_kv_pool = NSATokenToKVPool(
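Presumably self.num_effective_layers is the number of layers owned by the current pipeline stage rather than the full model depth, which is what the start_layer/end_layer arguments to the Ascend KV pool rely on. A toy calculation under an assumed even split (not necessarily sglang's exact partitioning):

num_hidden_layers = 64      # full model depth
pp_size, pp_rank = 4, 1     # illustrative pipeline configuration

layers_per_stage = num_hidden_layers // pp_size
start_layer = pp_rank * layers_per_stage            # 16
end_layer = start_layer + layers_per_stage          # 32
num_effective_layers = end_layer - start_layer      # 16

# Sizing the pool with num_hidden_layers (the pre-fix behaviour) would
# allocate KV cache for all 64 layers on every stage; passing the per-stage
# values lets the pool allocate only its 16 layers and translate global
# layer ids into local buffer indices.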
python/sglang/srt/utils/common.py

@@ -1239,42 +1239,34 @@ def point_to_point_pyobj(
     dst: int = 1,
 ):
     """Send data from src to dst in group using DeviceToDevice communication."""
+    device = torch.get_device_module().current_device()
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor(
-                [0], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            ).cuda(
-                device=torch.cuda.current_device()
+            ).to(
+                device=device
             )  # Move to GPU
-            tensor_size = torch.tensor(
-                [size], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([size], dtype=torch.long, device=device)
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
     elif rank == dst:
-        tensor_size = torch.tensor(
-            [0], dtype=torch.long, device=torch.cuda.current_device()
-        )
+        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
         if size == 0:
             return []
-        tensor_data = torch.empty(
-            size, dtype=torch.uint8, device=torch.cuda.current_device()
-        )
+        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
         dist.recv(tensor_data, src=src, group=group)
         serialized_data = bytes(
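The common.py change computes the target device once via torch.get_device_module(), so the same pickle-and-send path works on CUDA and NPU. A standalone, simplified sketch of the size-then-payload protocol it implements (error handling and the empty-payload case are omitted; this is not sglang's exact helper and assumes an initialized process group):

import pickle

import numpy as np
import torch
import torch.distributed as dist

def send_pyobj(obj, dst, group=None):
    device = torch.get_device_module().current_device()
    payload = pickle.dumps(obj)
    data = torch.ByteTensor(np.frombuffer(payload, dtype=np.uint8)).to(device=device)
    size = torch.tensor([len(payload)], dtype=torch.long, device=device)
    dist.send(size, dst=dst, group=group)   # 1) announce the byte count
    dist.send(data, dst=dst, group=group)   # 2) ship the pickled bytes

def recv_pyobj(src, group=None):
    device = torch.get_device_module().current_device()
    size = torch.tensor([0], dtype=torch.long, device=device)
    dist.recv(size, src=src, group=group)
    data = torch.empty(size.item(), dtype=torch.uint8, device=device)
    dist.recv(data, src=src, group=group)
    return pickle.loads(bytes(data.cpu().numpy()))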