Commit 1be9a629 authored by zhangshao's avatar zhangshao
Browse files

pa优化,编译选项优化

parent d4c0015a
......@@ -4,11 +4,13 @@ project(vllm_extensions LANGUAGES CXX)
option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
set(CMAKE_BUILD_TYPE "Release")
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
add_compile_options(-w)
#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
......@@ -120,10 +122,11 @@ endif()
# the supported versions for the current language.
# The final set of arches is stored in `VLLM_GPU_ARCHES`.
#
override_gpu_arches(VLLM_GPU_ARCHES
${VLLM_GPU_LANG}
"${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
#override_gpu_arches(VLLM_GPU_ARCHES
# ${VLLM_GPU_LANG}
# "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
set(VLLM_GPU_ARCHES "gfx928")
message(STATUS "${VLLM_GPU_ARCHES}")
#
# Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`.
......
......@@ -117,6 +117,10 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
"Failed to determine torch nvcc compiler flags")
list(REMOVE_ITEM GPU_FLAGS
"-DUSE_ROCM=1"
)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
# "-DENABLE_FP8"
......@@ -124,7 +128,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc"
"--gpu-max-threads-per-block=1024")
message(STATUS "${GPU_FLAGS}")
endif()
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
endfunction()
......
This diff is collapsed.
......@@ -26,19 +26,106 @@
namespace vllm {
// Q*K^T operation.
inline __device__ void v_dot2_f32_f16(float& a, const uint32_t & b,const uint32_t & c) {
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0;": "=v"(a): "v"(b), "v"(c), "0"(a));
}
inline __device__ void v_pk_fma_f16(uint32_t& a, const uint32_t & b,const uint32_t & c){
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;": "=v"(a) : "v"(b), "v"(c), "v"(a));
}
inline __device__ void ds_read_b128(uint4& a, uint32_t offset){
asm volatile("ds_read_b128 %0 %1;": "=v" (a): "v" (offset));
}
inline __device__ void ds_read_b128_sync(uint4& a, uint32_t offset){
asm volatile("ds_read_b128 %0 %1\ns_waitcnt lgkmcnt(1);": "=v" (a): "v" (offset));
}
inline __device__ void lgkmcnt0(){
asm volatile("s_waitcnt lgkmcnt(0);");
}
__device__ inline size_t __nv_cvta_generic_to_shared_impl(const void *__ptr) {
return (size_t)(void __attribute__((address_space(3))) *)__ptr;
}
inline __device__ void v_dot2_f32_f16(float& a,const uint2 & b,const uint2 & c) {
v_dot2_f32_f16(a, b.x, c.x);
v_dot2_f32_f16(a, b.y, c.y);
}
inline __device__ void v_dot2_f32_f16(float& a,const uint4 & b,const uint4 & c) {
v_dot2_f32_f16(a, b.x, c.x);
v_dot2_f32_f16(a, b.y, c.y);
v_dot2_f32_f16(a, b.z, c.z);
v_dot2_f32_f16(a, b.w, c.w);
}
inline __device__ float add_half2(uint32_t a){
union {
uint32_t u32;
half u16[2];
} tmp;
tmp.u32=a;
return static_cast<float>(tmp.u16[0]+tmp.u16[1]);
}
inline __device__ void v_pk_fma_f16x8(float& a,const uint4 & b,const uint4 & c) {
uint32_t tmp = mul<uint32_t, uint32_t, uint32_t>(b.x,c.x);
v_pk_fma_f16(tmp,b.y,c.y);
v_pk_fma_f16(tmp,b.z,c.z);
v_pk_fma_f16(tmp,b.w,c.w);
a+=add_half2(tmp);
}
// Q*K^T operation. fp16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk =0;
// uint32_t offset = __nv_cvta_generic_to_shared_impl(q);
// const uint4 *k_ptr= reinterpret_cast<const uint4 *>(k);
// // Compute the parallel products for Q*K^T (treat vector lanes separately).
// constexpr int loop=N*sizeof(Vec)/16/2;
// uint4 qt[2];
// #pragma unroll
// for (int ii = 0; ii < loop; ++ii) {
// ds_read_b128(qt[0],offset+16*ii*2);
// ds_read_b128_sync(qt[1],offset+16*(ii*2+1));
// v_dot2_f32_f16(qk,qt[0],k_ptr[ii*2]);
// // v_pk_fma_f16x8(qk,qt[0],k_ptr[ii*2]);
// lgkmcnt0();
// v_dot2_f32_f16(qk,qt[1],k_ptr[ii*2+1]);
// // v_pk_fma_f16x8(qk,qt[1],k_ptr[ii*2+1]);
// }
#pragma unroll
for (int ii = 0; ii < N; ++ii) {
v_dot2_f32_f16(qk,q[ii],k[ii]);
}
// Finalize the reduction across lanes.
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_vpack_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
// Finalize the reduction across lanes.
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
......@@ -46,12 +133,17 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk;
}
template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
// template <typename Vec, int N>
// static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) {
// return qk_dot_vpack_<THREAD_GROUP_SIZE>(q, k);
// }
};
} // namespace vllm
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define OPT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr static int opt = 1; \
return __VA_ARGS__(); \
} else { \
constexpr static int opt = 2; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} \
}()
// #define HEADSIZE_SWITCH(HEADDIM, ...) \
// [&] { \
// if (HEADDIM == 64) { \
// constexpr static int HEAD_SIZE = 64; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 80) { \
// constexpr static int HEAD_SIZE = 80; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 96) { \
// constexpr static int HEAD_SIZE = 96; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 112) { \
// constexpr static int HEAD_SIZE = 112; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 128) { \
// constexpr static int HEAD_SIZE = 128; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 256) { \
// constexpr static int HEAD_SIZE = 256; \
// return __VA_ARGS__(); \
// } \
// else { \
// TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
// } \
// }()
#define HEADSIZE_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
} \
}()
#define REUSEKV_SWITCH(num_blocks , ...) \
[&] { \
if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1(num_blocks , ...) \
[&] { \
if (num_heads > num_kv_heads && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment