Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6005ecee
Unverified
Commit
6005ecee
authored
Jul 01, 2025
by
Chunyuan WU
Committed by
GitHub
Jun 30, 2025
Browse files
[CPU] remove process_group from inputs of shm_allreduce and shm_allgather (#7486)
parent
ff2e9c94
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
57 deletions
+9
-57
sgl-kernel/csrc/cpu/interface.cpp
sgl-kernel/csrc/cpu/interface.cpp
+5
-50
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+4
-7
No files found.
sgl-kernel/csrc/cpu/interface.cpp
View file @
6005ecee
...
@@ -47,71 +47,26 @@ void initialize(int64_t size, int64_t rank) {
...
@@ -47,71 +47,26 @@ void initialize(int64_t size, int64_t rank) {
}
}
}
}
void
shm_allreduce
(
void
shm_allreduce
(
torch
::
Tensor
&
data
,
int64_t
op
)
{
torch
::
Tensor
&
data
,
c10
::
intrusive_ptr
<
c10d
::
ProcessGroup
>
process_group
,
c10
::
intrusive_ptr
<
c10d
::
ReduceOp
>
op
)
{
RECORD_FUNCTION
(
"sgl-kernel::shm_allreduce"
,
std
::
vector
<
c10
::
IValue
>
({
data
}));
RECORD_FUNCTION
(
"sgl-kernel::shm_allreduce"
,
std
::
vector
<
c10
::
IValue
>
({
data
}));
TORCH_CHECK
(
op
==
c10d
::
ReduceOp
::
SUM
,
"Only torch.distributed.ReduceOp.SUM is supported"
);
TORCH_CHECK
(
op
==
c10d
::
ReduceOp
::
SUM
,
"Only torch.distributed.ReduceOp.SUM is supported"
);
auto
numel
=
data
.
numel
();
auto
numel
=
data
.
numel
();
int
data_size
=
numel
*
data
.
element_size
();
int
data_size
=
0
;
all_reduce_outer_loop
(
data
,
numel
,
data_size
);
bool
data_type_fallback
=
false
;
switch
(
data
.
scalar_type
())
{
case
c10
::
ScalarType
::
BFloat16
:
data_size
=
numel
*
2
;
break
;
case
c10
::
ScalarType
::
Float
:
data_size
=
numel
*
4
;
break
;
default:
data_type_fallback
=
true
;
}
if
(
data_type_fallback
||
!
all_ranks_local_p
)
{
// Fallback to torch distributed allreduce
std
::
vector
<
torch
::
Tensor
>
tensors
=
{
data
};
process_group
->
allreduce
(
tensors
)
->
wait
();
}
else
{
all_reduce_outer_loop
(
data
,
numel
,
data_size
);
}
return
;
return
;
}
}
torch
::
Tensor
shm_allgather
(
torch
::
Tensor
&
data
,
c10
::
intrusive_ptr
<
c10d
::
ProcessGroup
>
process_group
,
int64_t
dim
)
{
torch
::
Tensor
shm_allgather
(
torch
::
Tensor
&
data
,
int64_t
dim
)
{
RECORD_FUNCTION
(
"sgl-kernel::shm_allgather"
,
std
::
vector
<
c10
::
IValue
>
({
data
}));
RECORD_FUNCTION
(
"sgl-kernel::shm_allgather"
,
std
::
vector
<
c10
::
IValue
>
({
data
}));
auto
numel
=
data
.
numel
();
auto
numel
=
data
.
numel
();
int
data_size
=
numel
*
data
.
element_size
();
int
data_size
=
0
;
bool
data_type_fallback
=
false
;
switch
(
data
.
scalar_type
())
{
case
c10
::
ScalarType
::
BFloat16
:
data_size
=
numel
*
2
;
break
;
case
c10
::
ScalarType
::
Float
:
data_size
=
numel
*
4
;
break
;
default:
data_type_fallback
=
true
;
}
if
(
dim
<
0
)
{
if
(
dim
<
0
)
{
dim
+=
data
.
dim
();
dim
+=
data
.
dim
();
}
}
if
(
data_type_fallback
||
!
all_ranks_local_p
)
{
// Fallback to torch distributed allreduce
std
::
vector
<
std
::
vector
<
torch
::
Tensor
>>
output_tensors
(
1
);
auto
world_size
=
process_group
->
getSize
();
for
(
int
i
=
0
;
i
<
world_size
;
i
++
)
{
output_tensors
[
0
].
push_back
(
torch
::
empty_like
(
data
));
}
std
::
vector
<
torch
::
Tensor
>
input_tensors
=
{
data
};
process_group
->
allgather
(
output_tensors
,
input_tensors
)
->
wait
();
return
torch
::
cat
(
output_tensors
[
0
],
dim
).
contiguous
();
}
std
::
vector
<
int64_t
>
result_shape
=
data
.
sizes
().
vec
();
std
::
vector
<
int64_t
>
result_shape
=
data
.
sizes
().
vec
();
result_shape
[
dim
]
*=
world_size
;
result_shape
[
dim
]
*=
world_size
;
torch
::
Tensor
result_tensor
=
torch
::
empty
(
result_shape
,
data
.
options
());
torch
::
Tensor
result_tensor
=
torch
::
empty
(
result_shape
,
data
.
options
());
...
...
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
View file @
6005ecee
...
@@ -212,11 +212,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
...
@@ -212,11 +212,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
void
initialize
(
int64_t
size
,
int64_t
rank
);
void
initialize
(
int64_t
size
,
int64_t
rank
);
// shared mmeory all_reduce
// shared mmeory all_reduce
void
shm_allreduce
(
void
shm_allreduce
(
at
::
Tensor
&
data
,
int64_t
op
);
at
::
Tensor
&
data
,
c10
::
intrusive_ptr
<
c10d
::
ProcessGroup
>
process_group
,
c10
::
intrusive_ptr
<
c10d
::
ReduceOp
>
op
);
// shared memory all_gather
// shared memory all_gather
at
::
Tensor
shm_allgather
(
at
::
Tensor
&
data
,
c10
::
intrusive_ptr
<
c10d
::
ProcessGroup
>
process_group
,
int64_t
dim
);
at
::
Tensor
shm_allgather
(
at
::
Tensor
&
data
,
int64_t
dim
);
// rope
// rope
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
rotary_embedding_cpu
(
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
rotary_embedding_cpu
(
...
@@ -344,11 +343,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
...
@@ -344,11 +343,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// all reduce
// Schema registrations for the shared-memory collectives (CPU backend).
// Post-change schemas: shm_allreduce/shm_allgather no longer take a
// __torch__.torch.classes.c10d.ProcessGroup (or ReduceOp) custom-class
// argument; the reduce op is a plain int.
m.def("initialize(int size, int rank) -> ()");
m.impl("initialize", torch::kCPU, &initialize);
m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
m.def("shm_allgather(Tensor data, int dim) -> Tensor");
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
// rope
// rope
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment