Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d25398cb
Unverified
Commit
d25398cb
authored
May 07, 2025
by
Xiaoyu Zhang
Committed by
GitHub
May 06, 2025
Browse files
fix custom_allreduce namespace (#6039)
parent
8a828666
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
24 deletions
+24
-24
sgl-kernel/csrc/allreduce/custom_all_reduce.cu
sgl-kernel/csrc/allreduce/custom_all_reduce.cu
+9
-9
sgl-kernel/csrc/allreduce/custom_all_reduce.cuh
sgl-kernel/csrc/allreduce/custom_all_reduce.cuh
+3
-3
sgl-kernel/csrc/allreduce/custom_all_reduce.hip
sgl-kernel/csrc/allreduce/custom_all_reduce.hip
+8
-8
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
+4
-4
No files found.
sgl-kernel/csrc/allreduce/custom_all_reduce.cu
View file @
d25398cb
...
@@ -18,11 +18,11 @@ init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_dat
...
@@ -18,11 +18,11 @@ init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_dat
if
(
world_size
%
2
!=
0
)
throw
std
::
invalid_argument
(
"Odd num gpus is not supported for now"
);
if
(
world_size
%
2
!=
0
)
throw
std
::
invalid_argument
(
"Odd num gpus is not supported for now"
);
if
(
rank
<
0
||
rank
>=
world_size
)
throw
std
::
invalid_argument
(
"invalid rank passed in"
);
if
(
rank
<
0
||
rank
>=
world_size
)
throw
std
::
invalid_argument
(
"invalid rank passed in"
);
vllm
::
Signal
*
ipc_ptrs
[
8
];
sglang
::
Signal
*
ipc_ptrs
[
8
];
for
(
int
i
=
0
;
i
<
world_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
world_size
;
i
++
)
{
ipc_ptrs
[
i
]
=
reinterpret_cast
<
vllm
::
Signal
*>
(
fake_ipc_ptrs
[
i
]);
ipc_ptrs
[
i
]
=
reinterpret_cast
<
sglang
::
Signal
*>
(
fake_ipc_ptrs
[
i
]);
}
}
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
return
(
fptr_t
)
new
sglang
::
CustomAllreduce
(
ipc_ptrs
,
rank_data
.
data_ptr
(),
rank_data
.
numel
(),
rank
,
world_size
,
full_nvlink
);
ipc_ptrs
,
rank_data
.
data_ptr
(),
rank_data
.
numel
(),
rank
,
world_size
,
full_nvlink
);
}
}
...
@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
...
@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
* copied into _reg_buffer.
* copied into _reg_buffer.
*/
*/
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
_reg_buffer
,
int64_t
reg_buffer_sz_bytes
)
{
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
_reg_buffer
,
int64_t
reg_buffer_sz_bytes
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
sglang
::
CustomAllreduce
*>
(
_fa
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
...
@@ -98,15 +98,15 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_
...
@@ -98,15 +98,15 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_
}
}
void
dispose
(
fptr_t
_fa
)
{
void
dispose
(
fptr_t
_fa
)
{
delete
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
delete
reinterpret_cast
<
sglang
::
CustomAllreduce
*>
(
_fa
);
}
}
int64_t
meta_size
()
{
int64_t
meta_size
()
{
return
sizeof
(
vllm
::
Signal
);
return
sizeof
(
sglang
::
Signal
);
}
}
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
fptr_t
>&
fake_ipc_ptrs
)
{
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
fptr_t
>&
fake_ipc_ptrs
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
sglang
::
CustomAllreduce
*>
(
_fa
);
TORCH_CHECK
(
fake_ipc_ptrs
.
size
()
==
fa
->
world_size_
);
TORCH_CHECK
(
fake_ipc_ptrs
.
size
()
==
fa
->
world_size_
);
void
*
ipc_ptrs
[
8
];
void
*
ipc_ptrs
[
8
];
for
(
int
i
=
0
;
i
<
fake_ipc_ptrs
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
fake_ipc_ptrs
.
size
();
i
++
)
{
...
@@ -117,7 +117,7 @@ void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
...
@@ -117,7 +117,7 @@ void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
// Use vector<int64_t> to represent byte data for python binding compatibility.
// Use vector<int64_t> to represent byte data for python binding compatibility.
std
::
tuple
<
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
)
{
std
::
tuple
<
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
sglang
::
CustomAllreduce
*>
(
_fa
);
auto
[
handle
,
offsets
]
=
fa
->
get_graph_buffer_ipc_meta
();
auto
[
handle
,
offsets
]
=
fa
->
get_graph_buffer_ipc_meta
();
std
::
vector
<
int64_t
>
bytes
(
handle
.
begin
(),
handle
.
end
());
std
::
vector
<
int64_t
>
bytes
(
handle
.
begin
(),
handle
.
end
());
return
std
::
make_tuple
(
bytes
,
offsets
);
return
std
::
make_tuple
(
bytes
,
offsets
);
...
@@ -126,7 +126,7 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
...
@@ -126,7 +126,7 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
// Use vector<int64_t> to represent byte data for python binding compatibility.
// Use vector<int64_t> to represent byte data for python binding compatibility.
void
register_graph_buffers
(
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
fptr_t
_fa
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
sglang
::
CustomAllreduce
*>
(
_fa
);
std
::
vector
<
std
::
string
>
bytes
;
std
::
vector
<
std
::
string
>
bytes
;
bytes
.
reserve
(
handles
.
size
());
bytes
.
reserve
(
handles
.
size
());
for
(
int
i
=
0
;
i
<
handles
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
handles
.
size
();
i
++
)
{
...
...
sgl-kernel/csrc/allreduce/custom_all_reduce.cuh
View file @
d25398cb
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#include "utils.h"
#include "utils.h"
namespace
vllm
{
namespace
sglang
{
constexpr
int
kMaxBlocks
=
36
;
constexpr
int
kMaxBlocks
=
36
;
// Counter may overflow, but it's fine since unsigned int overflow is
// Counter may overflow, but it's fine since unsigned int overflow is
...
@@ -483,7 +483,7 @@ class CustomAllreduce {
...
@@ -483,7 +483,7 @@ class CustomAllreduce {
/**
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
a template instantiation:
* template void
vllm
::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
* template void
sglang
::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
half *, int, int, int);
*/
*/
}
// namespace
vllm
}
// namespace
sglang
sgl-kernel/csrc/allreduce/custom_all_reduce.hip
View file @
d25398cb
...
@@ -29,8 +29,8 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
...
@@ -29,8 +29,8 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
for (int i = 0; i < world_size; i++) {
for (int i = 0; i < world_size; i++) {
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
}
}
return (fptr_t) new
vllm
::CustomAllreduce(
return (fptr_t) new
sglang
::CustomAllreduce(
reinterpret_cast<
vllm
::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
reinterpret_cast<
sglang
::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
}
...
@@ -58,7 +58,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
...
@@ -58,7 +58,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
hipStream_t stream) {
hipStream_t stream) {
auto fa = reinterpret_cast<
vllm
::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<
sglang
::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) {
switch (out.scalar_type()) {
case at::ScalarType::Float: {
case at::ScalarType::Float: {
...
@@ -110,22 +110,22 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
...
@@ -110,22 +110,22 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
}
}
void dispose(fptr_t _fa) {
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<
vllm
::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<
sglang
::CustomAllreduce*>(_fa);
delete fa;
delete fa;
}
}
int64_t meta_size() { return sizeof(
vllm
::Signal); }
int64_t meta_size() { return sizeof(
sglang
::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t,
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets) {
const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<
vllm
::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<
sglang
::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
fa->register_buffer(handles, offsets, t.data_ptr());
}
}
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
fptr_t _fa) {
auto fa = reinterpret_cast<
vllm
::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<
sglang
::CustomAllreduce*>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
...
@@ -137,7 +137,7 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
...
@@ -137,7 +137,7 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<
vllm
::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<
sglang
::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets);
fa->register_graph_buffers(handles, offsets);
}
}
...
...
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
View file @
d25398cb
...
@@ -26,7 +26,7 @@ typedef __hip_bfloat16 nv_bfloat16;
...
@@ -26,7 +26,7 @@ typedef __hip_bfloat16 nv_bfloat16;
} \
} \
} while (0)
} while (0)
namespace
vllm
{
namespace
sglang
{
constexpr
int
kMaxBlocks
=
64
;
constexpr
int
kMaxBlocks
=
64
;
// note: we don't want to use atomics for signals because peer atomics are no
// note: we don't want to use atomics for signals because peer atomics are no
...
@@ -572,11 +572,11 @@ class CustomAllreduce {
...
@@ -572,11 +572,11 @@ class CustomAllreduce {
CUDACHECK
(
hipIpcCloseMemHandle
(
ptr
));
CUDACHECK
(
hipIpcCloseMemHandle
(
ptr
));
}
}
}
}
};
// namespace
vllm
};
// namespace
sglang
/**
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
a template instantiation:
* template void
vllm
::CustomAllreduce::allreduce<half>(hipStream_t, half *,
* template void
sglang
::CustomAllreduce::allreduce<half>(hipStream_t, half *,
half *, int, int, int);
half *, int, int, int);
*/
*/
}
// namespace
vllm
}
// namespace
sglang
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment