Unverified Commit 90805ff4 authored by Ma Jian's avatar Ma Jian Committed by GitHub
Browse files

[CI/Build] CPU release supports both of AVX2 and AVX512 (#35466)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
Co-authored-by: default avatarjiang1.li <jiang1.li@intel.com>
parent 2562e027
...@@ -13,28 +13,16 @@ endif() ...@@ -13,28 +13,16 @@ endif()
# #
# Define environment variables for special configurations # Define environment variables for special configurations
# #
set(ENABLE_AVX2 $ENV{VLLM_CPU_AVX2}) set(ENABLE_X86_ISA $ENV{VLLM_CPU_X86})
set(ENABLE_AVX512 $ENV{VLLM_CPU_AVX512})
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
set(ENABLE_ARM_BF16 $ENV{VLLM_CPU_ARM_BF16}) set(ENABLE_ARM_BF16 $ENV{VLLM_CPU_ARM_BF16})
include_directories("${CMAKE_SOURCE_DIR}/csrc") include_directories("${CMAKE_SOURCE_DIR}/csrc")
set (ENABLE_NUMA TRUE) set (ENABLE_NUMA TRUE)
# #
# Check the compile flags # Check the compile flags
# #
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
list(APPEND CXX_COMPILE_FLAGS
"-mf16c"
)
endif()
if(MACOSX_FOUND) if(MACOSX_FOUND)
list(APPEND CXX_COMPILE_FLAGS list(APPEND CXX_COMPILE_FLAGS
"-DVLLM_CPU_EXTENSION") "-DVLLM_CPU_EXTENSION")
...@@ -78,18 +66,6 @@ function(check_sysctl TARGET OUT) ...@@ -78,18 +66,6 @@ function(check_sysctl TARGET OUT)
endif() endif()
endfunction() endfunction()
function (is_avx512_disabled OUT)
set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512})
if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true")
set(${OUT} ON PARENT_SCOPE)
else()
set(${OUT} OFF PARENT_SCOPE)
endif()
endfunction()
is_avx512_disabled(AVX512_DISABLED)
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Apple Silicon Detected") message(STATUS "Apple Silicon Detected")
set(APPLE_SILICON_FOUND TRUE) set(APPLE_SILICON_FOUND TRUE)
...@@ -97,8 +73,6 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") ...@@ -97,8 +73,6 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
check_sysctl(hw.optional.neon ASIMD_FOUND) check_sysctl(hw.optional.neon ASIMD_FOUND)
check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND) check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND)
else() else()
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
find_isa(${CPUINFO} "Power11" POWER11_FOUND) find_isa(${CPUINFO} "Power11" POWER11_FOUND)
find_isa(${CPUINFO} "POWER10" POWER10_FOUND) find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
find_isa(${CPUINFO} "POWER9" POWER9_FOUND) find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
...@@ -108,77 +82,32 @@ else() ...@@ -108,77 +82,32 @@ else()
find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support
# Support cross-compilation by allowing override via environment variables # Support cross-compilation by allowing override via environment variables
if (ENABLE_AVX2)
set(AVX2_FOUND ON)
message(STATUS "AVX2 support enabled via VLLM_CPU_AVX2 environment variable")
endif()
if (ENABLE_AVX512)
set(AVX512_FOUND ON)
message(STATUS "AVX512 support enabled via VLLM_CPU_AVX512 environment variable")
endif()
if (ENABLE_ARM_BF16) if (ENABLE_ARM_BF16)
set(ARM_BF16_FOUND ON) set(ARM_BF16_FOUND ON)
message(STATUS "ARM BF16 support enabled via VLLM_CPU_ARM_BF16 environment variable") message(STATUS "ARM BF16 support enabled via VLLM_CPU_ARM_BF16 environment variable")
endif() endif()
endif() endif()
if (AVX512_FOUND AND NOT AVX512_DISABLED) if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64" OR ENABLE_X86_ISA)
list(APPEND CXX_COMPILE_FLAGS set(ENABLE_X86_ISA ON)
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3))
message(FATAL_ERROR "X86 backend requires gcc/g++ >= 12.3")
endif()
list(APPEND CXX_COMPILE_FLAGS "-mf16c")
list(APPEND CXX_COMPILE_FLAGS_AVX512 ${CXX_COMPILE_FLAGS})
list(APPEND CXX_COMPILE_FLAGS_AVX2 ${CXX_COMPILE_FLAGS})
list(APPEND CXX_COMPILE_FLAGS_AVX512
"-mavx512f" "-mavx512f"
"-mavx512vl" "-mavx512vl"
"-mavx512bw" "-mavx512bw"
"-mavx512dq") "-mavx512dq"
"-mavx512bf16"
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) "-mavx512vnni"
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) "-mamx-bf16"
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "-mamx-tile")
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS_AVX2
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") "-mavx2")
set(ENABLE_AVX512BF16 ON)
else()
set(ENABLE_AVX512BF16 OFF)
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AVX512BF16 OFF)
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
endif()
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
set(ENABLE_AVX512VNNI ON)
else()
set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
endif()
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
set(ENABLE_AMXBF16 ON)
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
endif()
elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
message(WARNING "vLLM CPU backend using AVX2 ISA")
elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
message(STATUS "PowerPC detected") message(STATUS "PowerPC detected")
if (POWER9_FOUND) if (POWER9_FOUND)
...@@ -219,12 +148,12 @@ elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64") ...@@ -219,12 +148,12 @@ elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64")
list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc") list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc")
endif() endif()
else() else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") message(FATAL_ERROR "vLLM CPU backend requires X86, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
endif() endif()
# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms) # Build oneDNN for GEMM kernels
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) if (ENABLE_X86_ISA OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
# Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64 # Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
# TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN # TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "") set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
...@@ -329,13 +258,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON ...@@ -329,13 +258,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ONEDNN_BUILD_GRAPH "OFF") set(ONEDNN_BUILD_GRAPH "OFF")
set(ONEDNN_ENABLE_JIT_PROFILING "OFF") set(ONEDNN_ENABLE_JIT_PROFILING "ON")
set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "ON")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "ON")
set(ONEDNN_VERBOSE "OFF") set(ONEDNN_VERBOSE "ON")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
# TODO: Refactor this
if (ENABLE_X86_ISA)
# Note: only enable oneDNN for AVX512
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512})
else()
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS})
endif()
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE}) set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
FetchContent_MakeAvailable(oneDNN) FetchContent_MakeAvailable(oneDNN)
...@@ -348,14 +285,20 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON ...@@ -348,14 +285,20 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
PRIVATE ${oneDNN_SOURCE_DIR}/src PRIVATE ${oneDNN_SOURCE_DIR}/src
) )
target_link_libraries(dnnl_ext dnnl torch) target_link_libraries(dnnl_ext dnnl torch)
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) target_compile_options(dnnl_ext PRIVATE ${DNNL_COMPILE_FLAGS} -fPIC)
list(APPEND LIBS dnnl_ext) list(APPEND LIBS dnnl_ext)
set(USE_ONEDNN ON) set(USE_ONEDNN ON)
else() else()
set(USE_ONEDNN OFF) set(USE_ONEDNN OFF)
endif() endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") # TODO: Refactor this
if (ENABLE_X86_ISA)
message(STATUS "CPU extension (AVX512) compile flags: ${CXX_COMPILE_FLAGS_AVX512}")
message(STATUS "CPU extension (AVX2) compile flags: ${CXX_COMPILE_FLAGS_AVX2}")
else()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
endif()
if(ENABLE_NUMA) if(ENABLE_NUMA)
list(APPEND LIBS numa) list(APPEND LIBS numa)
...@@ -390,44 +333,86 @@ set(VLLM_EXT_SRC ...@@ -390,44 +333,86 @@ set(VLLM_EXT_SRC
"csrc/cpu/cpu_attn.cpp" "csrc/cpu/cpu_attn.cpp"
"csrc/cpu/torch_bindings.cpp") "csrc/cpu/torch_bindings.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED) if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp" "csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
${VLLM_EXT_SRC}) ${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) endif()
if(USE_ONEDNN)
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/cpu/dnnl_kernels.cpp"
${VLLM_EXT_SRC})
endif()
if (ENABLE_X86_ISA)
set(VLLM_EXT_SRC_AVX512
"csrc/cpu/sgl-kernels/gemm.cpp" "csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp" "csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp" "csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/moe.cpp" "csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp" "csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp" "csrc/cpu/sgl-kernels/moe_fp8.cpp"
${VLLM_EXT_SRC})
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
endif()
endif()
if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp" "csrc/cpu/shm.cpp"
${VLLM_EXT_SRC}) "csrc/cpu/cpu_wna16.cpp"
endif() "csrc/cpu/cpu_fused_moe.cpp"
"csrc/cpu/utils.cpp"
if(USE_ONEDNN) "csrc/cpu/cpu_attn.cpp"
set(VLLM_EXT_SRC
"csrc/cpu/dnnl_kernels.cpp" "csrc/cpu/dnnl_kernels.cpp"
${VLLM_EXT_SRC}) "csrc/cpu/torch_bindings.cpp"
endif() # TODO: Remove these files
"csrc/cpu/activation.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
set(VLLM_EXT_SRC_AVX2
"csrc/cpu/utils.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/torch_bindings.cpp"
# TODO: Remove these files
"csrc/cpu/activation.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") message(STATUS "CPU extension (AVX512) source files: ${VLLM_EXT_SRC_AVX512}")
message(STATUS "CPU extension (AVX2) source files: ${VLLM_EXT_SRC_AVX2}")
# define_extension_target(
# Define extension targets _C
# DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC_AVX512}
LIBRARIES ${LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512}
USE_SABI 3
WITH_SOABI
)
# For SGL kernels
target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AVX512")
# For AMX kernels
target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AMXBF16")
define_extension_target( define_extension_target(
_C_AVX2
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC_AVX2}
LIBRARIES ${LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX2}
USE_SABI 3
WITH_SOABI
)
else()
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
#
# Define extension targets
#
define_extension_target(
_C _C
DESTINATION vllm DESTINATION vllm
LANGUAGE CXX LANGUAGE CXX
...@@ -436,6 +421,7 @@ define_extension_target( ...@@ -436,6 +421,7 @@ define_extension_target(
COMPILE_FLAGS ${CXX_COMPILE_FLAGS} COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3 USE_SABI 3
WITH_SOABI WITH_SOABI
) )
endif()
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
#include <torch/library.h> #include <torch/library.h>
// Note: overwrite the external defination for sharing same name between
// libraries use different ISAs.
#define TORCH_EXTENSION_NAME _C
std::string init_cpu_threads_env(const std::string& cpu_ids); std::string init_cpu_threads_env(const std::string& cpu_ids);
void release_dnnl_matmul_handler(int64_t handler); void release_dnnl_matmul_handler(int64_t handler);
...@@ -324,19 +328,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -324,19 +328,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"str act, str isa) -> ()"); "str act, str isa) -> ()");
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe); ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif #endif
} ops.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
ops.def(
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
// CPU utils
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
cpu_ops.def(
"mla_decode_kvcache(" "mla_decode_kvcache("
" Tensor! out, Tensor query, Tensor kv_cache," " Tensor! out, Tensor query, Tensor kv_cache,"
" float scale, Tensor block_tables, Tensor seq_lens) -> ()"); " float scale, Tensor block_tables, Tensor seq_lens) -> ()");
cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache); ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
} }
REGISTER_EXTENSION(TORCH_EXTENSION_NAME) REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
...@@ -818,7 +818,7 @@ def _is_xpu() -> bool: ...@@ -818,7 +818,7 @@ def _is_xpu() -> bool:
def _build_custom_ops() -> bool: def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu() return _is_cuda() or _is_hip()
def get_rocm_version(): def get_rocm_version():
...@@ -987,6 +987,15 @@ if _is_cuda(): ...@@ -987,6 +987,15 @@ if _is_cuda():
CMakeExtension(name="vllm._flashmla_extension_C", optional=True) CMakeExtension(name="vllm._flashmla_extension_C", optional=True)
) )
if _is_cpu():
import platform
if platform.machine() in ("x86_64", "AMD64"):
ext_modules.append(CMakeExtension(name="vllm._C"))
ext_modules.append(CMakeExtension(name="vllm._C_AVX2"))
else:
ext_modules.append(CMakeExtension(name="vllm._C"))
if _build_custom_ops(): if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C")) ext_modules.append(CMakeExtension(name="vllm._C"))
......
...@@ -178,9 +178,7 @@ def mla_decode_kvcache_cpu( ...@@ -178,9 +178,7 @@ def mla_decode_kvcache_cpu(
block_tables: torch.Tensor, block_tables: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
) -> None: ) -> None:
torch.ops._C_cpu.mla_decode_kvcache( torch.ops._C.mla_decode_kvcache(out, query, kv_cache, scale, block_tables, seq_lens)
out, query, kv_cache, scale, block_tables, seq_lens
)
# merge attn states ops # merge attn states ops
......
...@@ -483,3 +483,27 @@ class CpuPlatform(Platform): ...@@ -483,3 +483,27 @@ class CpuPlatform(Platform):
@classmethod @classmethod
def support_hybrid_kv_cache(cls) -> bool: def support_hybrid_kv_cache(cls) -> bool:
return True return True
@classmethod
def import_kernels(cls) -> None:
if Platform.get_cpu_architecture() in (CpuArchEnum.X86,):
if torch._C._cpu._is_avx512_supported():
try:
import vllm._C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C: %r", e)
else:
# Note: The lib name is _C_AVX2, but the module name is _C.
# This will cause a exception "dynamic module does define
# module export function". But the library is imported
# successfully. So ignore the exception for now, until we find
# a solution.
try:
import vllm._C_AVX2 # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C_AVX2: %r", e)
else:
try:
import vllm._C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C: %r", e)
...@@ -85,7 +85,7 @@ class CPUWorker(Worker): ...@@ -85,7 +85,7 @@ class CPUWorker(Worker):
self.local_omp_cpuid = omp_cpuids_list[self.rank] self.local_omp_cpuid = omp_cpuids_list[self.rank]
if self.local_omp_cpuid != "nobind": if self.local_omp_cpuid != "nobind":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) ret = torch.ops._C.init_cpu_threads_env(self.local_omp_cpuid)
if ret: if ret:
logger.info(ret) logger.info(ret)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment