Commit 112cb3c9 authored by chenxl

[feature] support Python 3.10 and multiple instruction sets

parent 25620829
-__version__ = "0.1.0"
+__version__ = "0.1.1"
\ No newline at end of file
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.16)
project(cpuinfer_ext VERSION 0.1.0)
set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -ffast-math")
set(CMAKE_BUILD_TYPE "Release")
include(CheckCXXCompilerFlag)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
@@ -10,6 +10,27 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
option(LLAMA_NATIVE "llama: enable -march=native flag" ON)
# instruction set specific
if (LLAMA_NATIVE)
set(INS_ENB OFF)
else()
set(INS_ENB ON)
endif()
option(LLAMA_AVX "llama: enable AVX" OFF)
option(LLAMA_AVX2 "llama: enable AVX2" OFF)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_FMA "llama: enable FMA" OFF)
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF)
endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
@@ -102,6 +123,20 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
if (LLAMA_AVX512_FANCY_SIMD)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VL__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VL__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BW__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BW__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512DQ__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512DQ__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
if (LLAMA_AVX512_BF16)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
endif()
elseif (LLAMA_AVX2)
list(APPEND ARCH_FLAGS /arch:AVX2)
elseif (LLAMA_AVX)
@@ -133,6 +168,15 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
if (LLAMA_AVX512_VNNI)
list(APPEND ARCH_FLAGS -mavx512vnni)
endif()
if (LLAMA_AVX512_FANCY_SIMD)
list(APPEND ARCH_FLAGS -mavx512vl)
list(APPEND ARCH_FLAGS -mavx512bw)
list(APPEND ARCH_FLAGS -mavx512dq)
list(APPEND ARCH_FLAGS -mavx512vnni)
endif()
if (LLAMA_AVX512_BF16)
list(APPEND ARCH_FLAGS -mavx512bf16)
endif()
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
message(STATUS "PowerPC detected")
......
include(CheckCSourceRuns)
set(AVX_CODE "
#include <immintrin.h>
int main()
{
__m256 a;
a = _mm256_set1_ps(0);
return 0;
}
")
set(AVX512_CODE "
#include <immintrin.h>
int main()
{
__m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0);
__m512i b = a;
__mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
return 0;
}
")
set(AVX2_CODE "
#include <immintrin.h>
int main()
{
__m256i a = {0};
a = _mm256_abs_epi16(a);
__m256i x;
_mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
return 0;
}
")
set(FMA_CODE "
#include <immintrin.h>
int main()
{
__m256 acc = _mm256_setzero_ps();
const __m256 d = _mm256_setzero_ps();
const __m256 p = _mm256_setzero_ps();
acc = _mm256_fmadd_ps( d, p, acc );
return 0;
}
")
macro(check_sse type flags)
set(__FLAG_I 1)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
foreach (__FLAG ${flags})
if (NOT ${type}_FOUND)
set(CMAKE_REQUIRED_FLAGS ${__FLAG})
check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
if (HAS_${type}_${__FLAG_I})
set(${type}_FOUND TRUE CACHE BOOL "${type} support")
set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
endif()
math(EXPR __FLAG_I "${__FLAG_I}+1")
endif()
endforeach()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
if (NOT ${type}_FOUND)
set(${type}_FOUND FALSE CACHE BOOL "${type} support")
set(${type}_FLAGS "" CACHE STRING "${type} flags")
endif()
mark_as_advanced(${type}_FOUND ${type}_FLAGS)
endmacro()
# flags are for MSVC only!
check_sse("AVX" " ;/arch:AVX")
if (NOT ${AVX_FOUND})
set(LLAMA_AVX OFF)
else()
set(LLAMA_AVX ON)
endif()
check_sse("AVX2" " ;/arch:AVX2")
check_sse("FMA" " ;/arch:AVX2")
if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
set(LLAMA_AVX2 OFF)
else()
set(LLAMA_AVX2 ON)
endif()
check_sse("AVX512" " ;/arch:AVX512")
if (NOT ${AVX512_FOUND})
set(LLAMA_AVX512 OFF)
else()
set(LLAMA_AVX512 ON)
endif()
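For context, the `check_sse` macro above works by compiling and running each `*_CODE` snippet with every candidate flag until one succeeds, then caches `<type>_FOUND` and `<type>_FLAGS`; as the comment notes, the flag lists here are MSVC `/arch:` flags. A minimal standalone sketch of the same probe idea, written in Python purely for illustration (not part of this commit) and assuming a gcc-compatible compiler on PATH with `-m` style flags:

```python
# Hypothetical sketch (not in this commit): mimic check_sse by compiling and
# running a tiny intrinsics program; assumes a gcc-compatible compiler on PATH.
import os
import subprocess
import tempfile

AVX2_CODE = """
#include <immintrin.h>
int main() {
    __m256i a = _mm256_set1_epi16(-1);
    a = _mm256_abs_epi16(a);   /* requires AVX2 */
    return 0;
}
"""

def probe(code: str, flags: list[str], cc: str = "gcc") -> bool:
    """Return True if the snippet both compiles and runs with the given flags."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "probe.c")
        exe = os.path.join(tmp, "probe")
        with open(src, "w", encoding="utf-8") as f:
            f.write(code)
        # Compile step: fails if the toolchain rejects the intrinsics/flags.
        if subprocess.run([cc, src, *flags, "-o", exe], capture_output=True).returncode != 0:
            return False
        # Run step: fails (e.g. SIGILL) if the CPU lacks the instruction set.
        return subprocess.run([exe], capture_output=True).returncode == 0

if __name__ == "__main__":
    print("AVX2:", probe(AVX2_CODE, ["-mavx2"]))
```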
[build-system]
requires = [
"setuptools",
-"torch == 2.3.1",
+"torch >= 2.3.0",
"ninja",
"packaging"
]
@@ -29,7 +29,7 @@ dependencies = [
"fire"
]
-requires-python = ">=3.11"
+requires-python = ">=3.10"
authors = [
{name = "KVCache.AI", email = "zhang.mingxing@outlook.com"}
@@ -50,6 +50,7 @@ keywords = ["ktransformers", "llm"]
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12"
]
......
@@ -6,7 +6,7 @@ Author : chenxl
Date : 2024-07-27 16:15:27
Version : 1.0.0
LastEditors : chenxl
-LastEditTime : 2024-07-29 09:40:24
+LastEditTime : 2024-07-31 09:44:46
Adapted from:
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
Copyright (c) 2023, Tri Dao.
@@ -19,6 +19,7 @@ import re
import ast
import subprocess
import platform
import http.client
import urllib.request
import urllib.error
from pathlib import Path
@@ -28,7 +29,16 @@ from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from setuptools import setup, Extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
class CpuInstructInfo:
    CPU_INSTRUCT = os.getenv("CPU_INSTRUCT", "NATIVE")
    FANCY = "FANCY"
    AVX512 = "AVX512"
    AVX2 = "AVX2"
    CMAKE_NATIVE = "-DLLAMA_NATIVE=ON"
    CMAKE_FANCY = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON"
    CMAKE_AVX512 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON"
    CMAKE_AVX2 = "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON"

class VersionInfo:
    THIS_DIR = os.path.dirname(os.path.abspath(__file__))
    PACKAGE_NAME = "ktransformers"
@@ -61,12 +71,24 @@ class VersionInfo:
        raise ValueError("Unsupported platform: {}".format(sys.platform))
    def get_cpu_instruct(self,):
        if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
            return "fancy"
        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
            return "avx512"
        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:
            return "avx2"
        else:
            print("Using native cpu instruct")
        if sys.platform.startswith("linux"):
            with open('/proc/cpuinfo', 'r', encoding="utf-8") as cpu_f:
                cpuinfo = cpu_f.read()
            flags_line = [line for line in cpuinfo.split(
                '\n') if line.startswith('flags')][0]
            flags = flags_line.split(':')[1].strip().split(' ')
            # fancy with AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI
            for flag in flags:
                if 'avx512bw' in flag:
                    return 'fancy'
            for flag in flags:
                if 'avx512' in flag:
                    return 'avx512'
@@ -116,6 +138,7 @@ class BuildWheelsCommand(_bdist_wheel):
    def run(self):
        if VersionInfo.FORCE_BUILD:
            super().run()
            return
        wheel_filename, wheel_url = self.get_wheel_name()
        print("Guessing wheel URL: ", wheel_url)
        try:
@@ -132,7 +155,7 @@ class BuildWheelsCommand(_bdist_wheel):
            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
-        except (urllib.error.HTTPError, urllib.error.URLError):
+        except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
@@ -186,7 +209,19 @@ class CMakeBuild(BuildExtension):
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [
                item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
        if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:
            cpu_args = CpuInstructInfo.CMAKE_FANCY
        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:
            cpu_args = CpuInstructInfo.CMAKE_AVX512
        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:
            cpu_args = CpuInstructInfo.CMAKE_AVX2
        else:
            cpu_args = CpuInstructInfo.CMAKE_NATIVE
        cmake_args += [
            item for item in cpu_args.split(" ") if item
        ]
        # In this example, we pass in the version to C++. You might not need to.
        cmake_args += [
            f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"]
......
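Taken together, the build now honors a `CPU_INSTRUCT` environment variable: `FANCY`, `AVX512`, or `AVX2` force the corresponding CMake preset defined in `CpuInstructInfo`, while the default `NATIVE` keeps `-DLLAMA_NATIVE=ON` and falls back to `/proc/cpuinfo` detection for the wheel name. A minimal sketch of that selection logic, condensed here for illustration only (the helper name `resolve_cmake_args` is hypothetical, not part of the commit):

```python
# Illustrative sketch of the CPU_INSTRUCT -> CMake flag selection performed by
# CpuInstructInfo and CMakeBuild in setup.py; the helper name is hypothetical.
import os

CMAKE_PRESETS = {
    "FANCY":  "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON "
              "-DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON",
    "AVX512": "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON "
              "-DLLAMA_AVX2=ON -DLLAMA_AVX512=ON",
    "AVX2":   "-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON "
              "-DLLAMA_AVX2=ON",
    "NATIVE": "-DLLAMA_NATIVE=ON",
}

def resolve_cmake_args() -> list[str]:
    """Pick the instruction-set preset requested via CPU_INSTRUCT (default NATIVE)."""
    preset = os.getenv("CPU_INSTRUCT", "NATIVE")
    flags = CMAKE_PRESETS.get(preset, CMAKE_PRESETS["NATIVE"])
    return [item for item in flags.split(" ") if item]

if __name__ == "__main__":
    # e.g. run with CPU_INSTRUCT=FANCY in the environment
    print(resolve_cmake_args())
```

So a user who wants the AVX512-VL/BW/DQ/VNNI path on a machine whose compiler supports it would set `CPU_INSTRUCT=FANCY` in the environment before building; leaving the variable unset keeps the previous `-march=native` behaviour.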