Unverified Commit 8f844db6 authored by Chunyuan WU's avatar Chunyuan WU Committed by GitHub
Browse files

[CPU] fix all_reduce and all_gather (#6770)


Co-authored-by: default avatarblzheng <beilei.zheng@intel.com>
parent 36cc3ffd
......@@ -42,8 +42,10 @@ from torch.distributed import Backend, ProcessGroup
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
get_int_env_var,
is_cuda_alike,
is_npu,
is_shm_available,
supports_custom_op,
)
......@@ -222,6 +224,7 @@ class GroupCoordinator:
self.local_rank = local_rank
self.device_group = None
self.cpu_group = None
self.local_size = get_int_env_var("LOCAL_SIZE", 0)
for ranks in group_ranks:
device_group = torch.distributed.new_group(
......@@ -440,9 +443,12 @@ class GroupCoordinator:
return input_
if input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
if is_shm_available(input_.dtype, self.world_size, self.local_size):
torch.ops.sgl_kernel.shm_allreduce(
input_, torch.distributed.ReduceOp.SUM
)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
if not supports_custom_op():
......@@ -570,6 +576,16 @@ class GroupCoordinator:
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
if input_.is_cpu:
if is_shm_available(input_.dtype, self.world_size, self.local_size):
return torch.ops.sgl_kernel.shm_allgather(input_, dim)
else:
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
return output_tensor
# All-gather.
self.all_gather_into_tensor(output_tensor, input_)
# Reshape
......
......@@ -506,9 +506,13 @@ class ModelRunner:
if _is_cpu_amx_available:
# Bind OpenMP threads to CPU cores
torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
# Set local size to hint SGLang to use shared memory based AllReduce
os.environ["LOCAL_SIZE"] = str(self.tp_size)
torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
else:
logger.warning(
"init_cpu_threads_env is skipped since intel amx backend is not available"
"init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
)
# Only initialize the distributed environment on the target model worker.
......
......@@ -2612,3 +2612,12 @@ def get_cpu_ids_by_node():
# ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
return cpu_ids
def is_shm_available(dtype, world_size, local_size):
return (
cpu_has_amx_support()
and dtype in [torch.bfloat16, torch.float]
and world_size >= 1
and world_size == local_size
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment