Unverified Commit 6005ecee authored by Chunyuan WU, committed by GitHub

[CPU] remove process_group from inputs of shm_allreduce and shm_allgather (#7486)

parent ff2e9c94
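
This commit drops the process_group argument from both shared-memory collectives, so callers now pass only the tensor plus a plain integer (the reduce-op code for shm_allreduce, the gather dimension for shm_allgather). Below is a minimal caller-side sketch of the new entry points, assuming a single-node setup and that the integer value 0 corresponds to SUM (the only reduce op the kernel accepts); the declarations are taken from the diff below, everything else is illustrative.

#include <torch/torch.h>

// Declarations copied from the new signatures in the diff below.
void initialize(int64_t size, int64_t rank);
void shm_allreduce(torch::Tensor& data, int64_t op);
torch::Tensor shm_allgather(torch::Tensor& data, int64_t dim);

int main() {
  // Assumed setup: 4 ranks on one node, this process is rank 0.
  initialize(/*size=*/4, /*rank=*/0);

  torch::Tensor t = torch::ones({8, 16}, torch::kBFloat16);

  // No process_group handle anymore; 0 is assumed to map to ReduceOp::SUM.
  shm_allreduce(t, /*op=*/0);

  // Gather along dim 0: with 4 ranks the result shape is expected to be {32, 16}.
  torch::Tensor gathered = shm_allgather(t, /*dim=*/0);
  return 0;
}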
@@ -47,71 +47,26 @@ void initialize(int64_t size, int64_t rank) {
   }
 }
-void shm_allreduce(
-    torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op) {
+void shm_allreduce(torch::Tensor& data, int64_t op) {
   RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
   TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");
   auto numel = data.numel();
-  int data_size = 0;
-  bool data_type_fallback = false;
-  switch (data.scalar_type()) {
-    case c10::ScalarType::BFloat16:
-      data_size = numel * 2;
-      break;
-    case c10::ScalarType::Float:
-      data_size = numel * 4;
-      break;
-    default:
-      data_type_fallback = true;
-  }
-  if (data_type_fallback || !all_ranks_local_p) {
-    // Fallback to torch distributed allreduce
-    std::vector<torch::Tensor> tensors = {data};
-    process_group->allreduce(tensors)->wait();
-  } else {
-    all_reduce_outer_loop(data, numel, data_size);
-  }
+  int data_size = numel * data.element_size();
+  all_reduce_outer_loop(data, numel, data_size);
   return;
 }
-torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim) {
+torch::Tensor shm_allgather(torch::Tensor& data, int64_t dim) {
   RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
   auto numel = data.numel();
-  int data_size = 0;
-  bool data_type_fallback = false;
-  switch (data.scalar_type()) {
-    case c10::ScalarType::BFloat16:
-      data_size = numel * 2;
-      break;
-    case c10::ScalarType::Float:
-      data_size = numel * 4;
-      break;
-    default:
-      data_type_fallback = true;
-  }
+  int data_size = numel * data.element_size();
   if (dim < 0) {
     dim += data.dim();
   }
-  if (data_type_fallback || !all_ranks_local_p) {
-    // Fallback to torch distributed allreduce
-    std::vector<std::vector<torch::Tensor>> output_tensors(1);
-    auto world_size = process_group->getSize();
-    for (int i = 0; i < world_size; i++) {
-      output_tensors[0].push_back(torch::empty_like(data));
-    }
-    std::vector<torch::Tensor> input_tensors = {data};
-    process_group->allgather(output_tensors, input_tensors)->wait();
-    return torch::cat(output_tensors[0], dim).contiguous();
-  }
   std::vector<int64_t> result_shape = data.sizes().vec();
   result_shape[dim] *= world_size;
   torch::Tensor result_tensor = torch::empty(result_shape, data.options());
...
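
The byte-count bookkeeping is the other simplification in this hunk: the per-dtype switch (2 bytes for BFloat16, 4 bytes for Float, otherwise fall back) is replaced by numel * element_size(), which yields the same values and is defined for every dtype. A small standalone check of that equivalence, plus the allgather output-shape arithmetic; the world size of 4 is an arbitrary example value, not something fixed by the kernel.

#include <torch/torch.h>
#include <cassert>
#include <vector>

int main() {
  auto bf16 = torch::zeros({16}, torch::kBFloat16);
  auto f32 = torch::zeros({16}, torch::kFloat);

  // element_size() reproduces the old switch: 2 bytes per BFloat16 element,
  // 4 bytes per Float element, and it also covers every other dtype.
  assert(bf16.numel() * bf16.element_size() == 16 * 2);
  assert(f32.numel() * f32.element_size() == 16 * 4);

  // shm_allgather's result shape keeps every dimension except `dim`,
  // which is multiplied by the world size.
  int64_t world_size = 4;
  int64_t dim = 0;
  std::vector<int64_t> result_shape = bf16.sizes().vec();
  result_shape[dim] *= world_size;  // {16} -> {64}
  torch::Tensor result = torch::empty(result_shape, bf16.options());
  assert(result.size(0) == 64);
  return 0;
}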
@@ -212,11 +212,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
 void initialize(int64_t size, int64_t rank);
 // shared mmeory all_reduce
-void shm_allreduce(
-    at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
+void shm_allreduce(at::Tensor& data, int64_t op);
 // shared memory all_gather
-at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
+at::Tensor shm_allgather(at::Tensor& data, int64_t dim);
 // rope
 std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
@@ -344,11 +343,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // all reduce
   m.def("initialize(int size, int rank) -> ()");
   m.impl("initialize", torch::kCPU, &initialize);
-  m.def(
-      "shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
-      "__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
+  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
   m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
-  m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
+  m.def("shm_allgather(Tensor data, int dim) -> Tensor");
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
   // rope
...
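
Because the schema strings no longer reference __torch__.torch.classes.c10d custom classes, the registration reduces to plain "Tensor" and "int" arguments, which bind to at::Tensor and int64_t on the C++ side. Below is a self-contained sketch of the same registration pattern under a throwaway toy_kernel namespace with stub bodies; both the namespace and the stubs are hypothetical and exist only to show the schema-to-C++ type mapping.

#include <torch/torch.h>
#include <torch/library.h>

namespace {

// Stub bodies only; the real kernels perform shared-memory collectives.
void toy_shm_allreduce(at::Tensor& data, int64_t reduce_op) {
  (void)reduce_op;
  data.add_(0);  // placeholder in-place op
}

at::Tensor toy_shm_allgather(at::Tensor& data, int64_t dim) {
  return data.clone().unsqueeze(dim);  // placeholder, not a real gather
}

}  // namespace

TORCH_LIBRARY_FRAGMENT(toy_kernel, m) {
  // "Tensor" binds to at::Tensor and "int" binds to int64_t; no custom
  // c10d class arguments remain in either schema.
  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
  m.impl("shm_allreduce", torch::kCPU, &toy_shm_allreduce);
  m.def("shm_allgather(Tensor data, int dim) -> Tensor");
  m.impl("shm_allgather", torch::kCPU, &toy_shm_allgather);
}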