sglang · Commit 4a6e7a66 (Unverified)
Authored Aug 01, 2025 by kk · Committed by GitHub on Jul 31, 2025
Fix nan value generated after custom all reduce (#8532)
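The commit swaps torch.empty / torch.empty_like for torch.zeros / torch.zeros_like in the custom all-reduce buffers. As a minimal, self-contained sketch of the underlying issue (illustrative only, not code from this repository): torch.empty hands back whatever bytes already sit in the allocation, and those bits can decode to NaN or Inf when the buffer is later read as floating point, whereas torch.zeros guarantees a defined starting value.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # torch.empty returns an uninitialized buffer: its contents are whatever bytes
    # were left in the allocation, which may decode to NaN/Inf when read as floats.
    uninit = torch.empty(1 << 20, dtype=torch.float16, device=device)
    print("uninitialized buffer may contain NaN:", torch.isnan(uninit).any().item())

    # torch.zeros performs an explicit fill, so every element starts as 0.0.
    zeroed = torch.zeros(1 << 20, dtype=torch.float16, device=device)
    print("zero-initialized buffer contains NaN:", torch.isnan(zeroed).any().item())  # always False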
Parent: 4b04998d
Showing 1 changed file with 7 additions and 7 deletions.

python/sglang/srt/distributed/device_communicators/custom_all_reduce.py (+7, -7)
@@ -184,7 +184,7 @@ class CustomAllreduce:
             # 8*world_size bytes where world_size is at most 8. Allocating 8MB
             # is enough for 131072 such tuples. The largest model I've seen only
             # needs less than 10000 of registered tuples.
-            self.rank_data = torch.empty(
+            self.rank_data = torch.zeros(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
@@ -194,14 +194,14 @@ class CustomAllreduce:
         else:
             # meta data buffers need to be "uncached" for signal on MI200
             self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
-            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
+            self.buffer = torch.zeros(max_size, dtype=torch.uint8, device=self.device)
             handle = ops.get_meta_buffer_ipc_handle(self.meta)
             shard_data = (
                 bytes(handle),  # ipc handle to base ptr
                 0,  # offset of base ptr
             )
             handles, offsets = self._gather_ipc_meta(shard_data)
-            self.rank_data = torch.empty(
+            self.rank_data = torch.zeros(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
@@ -350,14 +350,14 @@ class CustomAllreduce:
     # or, in the context of cuda graphs, register_graph_buffers
     def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.empty_like(inp)
+            out = torch.zeros_like(inp)
         ops.all_reduce_reg(self._ptr, inp, out)
         return out

     # all reduce, assuming inp tensor is NOT IPC registered
     def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.empty_like(inp)
+            out = torch.zeros_like(inp)
         ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
         return out
@@ -375,7 +375,7 @@ class CustomAllreduce:
         buffer.
         """
         if out is None:
-            out = torch.empty_like(inp)
+            out = torch.zeros_like(inp)
         if registered:
             ops.all_reduce(self._ptr, inp, out, 0, 0)
         else:
@@ -398,7 +398,7 @@ class CustomAllreduce:
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
-                return torch.empty_like(input)
+                return torch.zeros_like(input)
         else:
             if _is_hip:
                 # note: outside of cuda graph context,
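A possible reading of the last hunk (@@ -398,7 +398,7 @@), sketched below under the assumption that during cuda-graph warm-up the custom all-reduce kernel is skipped and only the output allocation pattern is mimicked: the placeholder tensor is returned without ever being written, so any garbage bits already in it (including NaN) flow into downstream computation, while zeros_like makes the placeholder deterministic. The function name here is hypothetical, not from the repository.

    import torch

    def warmup_allreduce_placeholder(input: torch.Tensor) -> torch.Tensor:
        # Hypothetical simplification of the warm-up branch in
        # CustomAllreduce.custom_all_reduce: custom allreduce is out-of-place,
        # so warm-up only mimics the allocation and never runs the kernel.
        # Before this commit the placeholder was torch.empty_like(input),
        # i.e. uninitialized memory; zero-initializing keeps warm-up outputs NaN-free.
        return torch.zeros_like(input)

The trade-off is an extra fill per allocation: a one-time cost for the 8 MB rank_data and staging buffers at init, and a per-call cost for the *_like output tensors, which is small compared to NaN values silently propagating through the model.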