Unverified Commit 550b2801 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[CPU][Bugfix] Using custom allreduce for CPU backend (#15934)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent cefb9e5a
...@@ -197,6 +197,7 @@ set(VLLM_EXT_SRC ...@@ -197,6 +197,7 @@ set(VLLM_EXT_SRC
if (AVX512_FOUND AND NOT AVX512_DISABLED) if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/cpu/quant.cpp" "csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp"
${VLLM_EXT_SRC}) ${VLLM_EXT_SRC})
endif() endif()
......
...@@ -78,9 +78,14 @@ struct FP16Vec16 : public Vec<FP16Vec16> { ...@@ -78,9 +78,14 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
__m256i reg; __m256i reg;
// normal load
explicit FP16Vec16(const void* ptr) explicit FP16Vec16(const void* ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
// non-temproal load
explicit FP16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
explicit FP16Vec16(const FP32Vec16&); explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
...@@ -110,9 +115,14 @@ struct BF16Vec16 : public Vec<BF16Vec16> { ...@@ -110,9 +115,14 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
__m256i reg; __m256i reg;
// normal load
explicit BF16Vec16(const void* ptr) explicit BF16Vec16(const void* ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
// non-temproal load
explicit BF16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
explicit BF16Vec16(const FP32Vec16&); explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
...@@ -313,8 +323,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -313,8 +323,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
// normal load
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {} explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
// non-temproal load
explicit FP32Vec16(bool, void* ptr)
: reg((__m512)_mm512_stream_load_si512(ptr)) {}
explicit FP32Vec16(__m512 data) : reg(data) {} explicit FP32Vec16(__m512 data) : reg(data) {}
explicit FP32Vec16(const FP32Vec4& data) explicit FP32Vec16(const FP32Vec4& data)
...@@ -547,6 +562,33 @@ struct INT8Vec16 : public Vec<INT8Vec16> { ...@@ -547,6 +562,33 @@ struct INT8Vec16 : public Vec<INT8Vec16> {
_mm_mask_storeu_epi8(ptr, mask, reg); _mm_mask_storeu_epi8(ptr, mask, reg);
} }
}; };
struct INT8Vec64 : public Vec<INT8Vec64> {
constexpr static int VEC_ELEM_NUM = 64;
union AliasReg {
__m512i reg;
int8_t values[VEC_ELEM_NUM];
};
__m512i reg;
// normal load
explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {}
// non-temproal load
explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {}
void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); }
void save(int8_t* ptr, const int elem_num) const {
constexpr uint64_t M = 0xFFFFFFFFFFFFFFFF;
__mmask64 mask = _cvtu64_mask64(M >> (64 - elem_num));
_mm512_mask_storeu_epi8(ptr, mask, reg);
}
// non-temproal save
void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); }
};
#endif #endif
template <typename T> template <typename T>
...@@ -657,6 +699,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { ...@@ -657,6 +699,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); } inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
#ifdef __AVX512F__
inline void non_temporal_save(FP16Vec16& vec, void* ptr) {
_mm256_stream_si256((__m256i*)ptr, vec.reg);
}
inline void non_temporal_save(BF16Vec32& vec, void* ptr) {
_mm512_stream_si512((__m512i*)ptr, vec.reg);
}
inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
_mm256_stream_si256((__m256i*)ptr, vec.reg);
}
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
_mm512_stream_ps((float*)ptr, vec.reg);
}
#endif
inline void mem_barrier() { _mm_mfence(); }
}; // namespace vec_op }; // namespace vec_op
#endif #endif
This diff is collapsed.
...@@ -22,6 +22,26 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, ...@@ -22,6 +22,26 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale, torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens); torch::Tensor& block_tables, torch::Tensor& seq_lens);
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
const int64_t rank);
std::string join_shm_manager(int64_t handle, const std::string& name);
void shm_allreduce(int64_t handle, torch::Tensor& data);
void shm_gather(int64_t handle, torch::Tensor& data,
const std::optional<std::vector<torch::Tensor>>& outputs,
int64_t dst);
void shm_all_gather(int64_t handle, const torch::Tensor& data,
torch::Tensor& output);
void shm_send_tensor_list(int64_t handle,
const std::vector<torch::Tensor>& tensor_list,
int64_t dst);
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
...@@ -131,6 +151,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -131,6 +151,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor? azp, Tensor? bias) -> ()"); " Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif #endif
// SHM CCL
#ifdef __AVX512F__
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
&init_shm_manager);
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
ops.def(
"shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
"()");
ops.impl("shm_gather", torch::kCPU, &shm_gather);
ops.def(
"shm_all_gather(int handle, Tensor data, Tensor! output) -> "
"()");
ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather);
ops.def(
"shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
"()");
ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
&shm_recv_tensor_list);
#endif
} }
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
......
...@@ -18,7 +18,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { ...@@ -18,7 +18,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
#ifndef VLLM_NUMA_DISABLED #ifndef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) { std::string init_cpu_threads_env(const std::string& cpu_ids) {
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str());
TORCH_CHECK(omp_cpu_mask->size > 0); TORCH_CHECK(omp_cpu_mask->size > 0);
std::vector<int> omp_cpu_ids; std::vector<int> omp_cpu_ids;
omp_cpu_ids.reserve(omp_cpu_mask->size); omp_cpu_ids.reserve(omp_cpu_mask->size);
......
...@@ -272,12 +272,14 @@ $ python examples/offline_inference/basic/basic.py ...@@ -272,12 +272,14 @@ $ python examples/offline_inference/basic/basic.py
- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance. - Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance.
- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, two optimizations are to recommended: Tensor Parallel or Data Parallel. - On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is a option for better performance.
- Using Tensor Parallel for a latency constraints deployment: following GPU backend design, a Megatron-LM's parallel algorithm will be used to shard the model, based on the number of NUMA nodes (e.g. TP = 2 for a two NUMA node system). With [TP feature on CPU](gh-pr:6125) merged, Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving: - Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:
```console ```console
VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp
``` ```
- Using Data Parallel for maximum throughput: to launch an LLM serving endpoint on each NUMA node along with one additional load balancer to dispatch the requests to those endpoints. Common solutions like [Nginx](#nginxloadbalancer) or HAProxy are recommended. Anyscale Ray project provides the feature on LLM [serving](https://docs.ray.io/en/latest/serve/index.html). Here is the example to setup a scalable LLM serving with [Ray Serve](https://github.com/intel/llm-on-ray/blob/main/docs/setup.inc.md). - For each thread id list in `VLLM_CPU_OMP_THREADS_BIND`, users should guarantee threads in the list belong to a same NUMA node.
- Meanwhile, users should also take care of memory capacity of each NUMA node. The memory usage of each TP rank is the sum of `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`, if it exceeds the capacity of a single NUMA node, TP worker will be killed due to out-of-memory.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional import os
from typing import List, Optional
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
...@@ -16,19 +20,120 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -16,19 +20,120 @@ class CpuCommunicator(DeviceCommunicatorBase):
device_group: Optional[ProcessGroup] = None, device_group: Optional[ProcessGroup] = None,
unique_name: str = ""): unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name) super().__init__(cpu_group, device, device_group, unique_name)
self.ipex_available = False
self.dist_module = torch.distributed self.dist_module = torch.distributed
try:
import intel_extension_for_pytorch as ipex if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
self.ipex_available = True self.dist_module = _CPUSHMDistributed(self)
self.dist_module = ipex.distributed
except ImportError:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU (e.g. MacOS)
"""
pass
def all_reduce(self, input_): def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group) self.dist_module.all_reduce(input_, group=self.device_group)
return input_ return input_
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
self.dist_module.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
self.dist_module.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
instance_identifier = os.environ["VLLM_DIST_IDENT"]
self.communicator = communicator
group_ranks = [str(rank) for rank in self.communicator.ranks]
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
self.handle = self._init_cpu_shm()
def _init_cpu_shm(self) -> int:
handle = torch.ops._C.init_shm_manager(
self.group_name,
self.communicator.world_size,
self.communicator.rank,
)
torch.distributed.barrier(self.communicator.device_group)
torch.ops._C.join_shm_manager(
handle,
self.group_name,
)
torch.distributed.barrier(self.communicator.device_group)
return handle
def all_reduce(self,
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
torch.ops._C.shm_allreduce(self.handle, input)
def gather(self,
input: torch.Tensor,
gather_list: Optional[List[torch.Tensor]],
dst: int = -1,
group: Optional[ProcessGroup] = None) -> None:
# Note: different from the torch gather, here we use local dst rank.
torch.ops._C.shm_gather(self.handle, input, gather_list,
torch.distributed.get_group_rank(group, dst))
def all_gather_into_tensor(self,
output: torch.Tensor,
input: torch.Tensor,
group: Optional[ProcessGroup] = None) -> None:
torch.ops._C.shm_all_gather(self.handle, input, output)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""A CPU worker class.""" """A CPU worker class."""
import os
from typing import Dict, List, Optional, Set, Tuple, Type from typing import Dict, List, Optional, Set, Tuple, Type
import torch import torch
...@@ -139,6 +140,8 @@ class CPUWorker(LocalOrDistributedWorkerBase): ...@@ -139,6 +140,8 @@ class CPUWorker(LocalOrDistributedWorkerBase):
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
vllm_config.parallel_config.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
...@@ -217,6 +220,10 @@ class CPUWorker(LocalOrDistributedWorkerBase): ...@@ -217,6 +220,10 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret: if ret:
logger.info(ret) logger.info(ret)
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
":")[-1]
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.init_distributed_environment() self.init_distributed_environment()
# Set random seed. # Set random seed.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment