Unverified commit 8f844db6 authored by Chunyuan WU, committed by GitHub

[CPU] fix all_reduce and all_gather (#6770)


Co-authored-by: blzheng <beilei.zheng@intel.com>
parent 36cc3ffd
@@ -42,8 +42,10 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
+    get_int_env_var,
     is_cuda_alike,
     is_npu,
+    is_shm_available,
     supports_custom_op,
 )
@@ -222,6 +224,7 @@ class GroupCoordinator:
         self.local_rank = local_rank
         self.device_group = None
         self.cpu_group = None
+        self.local_size = get_int_env_var("LOCAL_SIZE", 0)

         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
@@ -440,9 +443,12 @@ class GroupCoordinator:
             return input_

         if input_.is_cpu:
-            import intel_extension_for_pytorch as ipex
-
-            ipex.distributed.all_reduce(input_, group=self.device_group)
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                torch.ops.sgl_kernel.shm_allreduce(
+                    input_, torch.distributed.ReduceOp.SUM
+                )
+            else:
+                torch.distributed.all_reduce(input_, group=self.device_group)
             return input_

         if not supports_custom_op():
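
Note: for orientation, a minimal standalone sketch of the dispatch this hunk introduces. Names follow the diff; shm_allreduce is assumed to reduce in place, as the call pattern above suggests.

import torch
from sglang.srt.utils import is_shm_available

def cpu_all_reduce(input_: torch.Tensor, world_size: int, local_size: int, device_group) -> torch.Tensor:
    if is_shm_available(input_.dtype, world_size, local_size):
        # Fast path: shared-memory all-reduce from sgl_kernel, usable when
        # every rank of the group lives on the same node (see is_shm_available).
        torch.ops.sgl_kernel.shm_allreduce(input_, torch.distributed.ReduceOp.SUM)
    else:
        # Fallback: plain torch.distributed all-reduce over the CPU process group.
        torch.distributed.all_reduce(input_, group=device_group)
    return input_

This replaces the old path that imported intel_extension_for_pytorch and called ipex.distributed.all_reduce, so IPEX is no longer a hard dependency for the CPU all-reduce.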
@@ -570,6 +576,16 @@ class GroupCoordinator:
         output_tensor = torch.empty(
             output_size, dtype=input_.dtype, device=input_.device
         )
+
+        if input_.is_cpu:
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
+                return output_tensor
+
         # All-gather.
         self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
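
Note: a sketch of the fallback branch's shape contract. Illustrative only; the dim-0 scaling of output_size reproduces context not shown in the hunk and is an assumption.

import torch

def cpu_all_gather_fallback(input_: torch.Tensor, world_size: int, device_group) -> torch.Tensor:
    # all_gather_into_tensor stacks each rank's contribution along dim 0,
    # so the output is world_size times larger in that dimension.
    output_size = (input_.shape[0] * world_size,) + tuple(input_.shape[1:])
    output_tensor = torch.empty(output_size, dtype=input_.dtype, device=input_.device)
    torch.distributed.all_gather_into_tensor(output_tensor, input_, group=device_group)
    return output_tensor

By contrast, shm_allgather is assumed to return the result already concatenated along dim, which is why that branch returns directly and skips the reshape step below.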
...
@@ -506,9 +506,13 @@ class ModelRunner:
         if _is_cpu_amx_available:
             # Bind OpenMP threads to CPU cores
             torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
+
+            # Set local size to hint SGLang to use the shared-memory-based AllReduce
+            os.environ["LOCAL_SIZE"] = str(self.tp_size)
+            torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
         else:
             logger.warning(
-                "init_cpu_threads_env is skipped since intel amx backend is not available"
+                "init_cpu_threads_env and shared-memory-based AllReduce are disabled since the intel amx backend is not available"
             )

         # Only initialize the distributed environment on the target model worker.
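
Note: the LOCAL_SIZE environment variable is the handshake between ModelRunner and GroupCoordinator. The runner sets it only on the AMX path; the coordinator reads it with a default of 0 (see the earlier hunk), so the shared-memory path stays off unless this line ran. A sketch, with tp_size as an example value:

import os
from sglang.srt.utils import get_int_env_var

# Producer side (ModelRunner, AMX available): all tp_size ranks share one node.
tp_size = 4  # example value
os.environ["LOCAL_SIZE"] = str(tp_size)

# Consumer side (GroupCoordinator.__init__): defaults to 0 when the hint is
# absent, which makes is_shm_available() return False and keeps the fallback.
local_size = get_int_env_var("LOCAL_SIZE", 0)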
...
@@ -2612,3 +2612,12 @@ def get_cpu_ids_by_node():
     # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
     return cpu_ids
+
+
+def is_shm_available(dtype, world_size, local_size):
+    return (
+        cpu_has_amx_support()
+        and dtype in [torch.bfloat16, torch.float]
+        and world_size >= 1
+        and world_size == local_size
+    )
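
Note: taken together, the predicate enables the shared-memory path only for single-node groups with AMX support and a supported dtype. Illustrative calls, assuming cpu_has_amx_support() returns True on the host:

import torch
from sglang.srt.utils import is_shm_available

is_shm_available(torch.bfloat16, world_size=4, local_size=4)  # True: single-node bf16
is_shm_available(torch.float16, world_size=4, local_size=4)   # False: fp16 not in the allowlist
is_shm_available(torch.float, world_size=8, local_size=4)     # False: group spans nodes
is_shm_available(torch.float, world_size=4, local_size=0)     # False: LOCAL_SIZE hint never set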