Commit deeb9cb8 authored by zhangshao's avatar zhangshao
Browse files

pa_v1用原始代码pa_v2用新代码

parent c4b56490
This diff is collapsed.
......@@ -84,22 +84,23 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
// inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
// Dot product of two N-element arrays of Vec, combined across a thread group.
//
// Template parameters:
//   THREAD_GROUP_SIZE - starting point for the shuffle masks; the reduction
//                       visits masks THREAD_GROUP_SIZE/2, .../4, ..., 1.
//   Vec               - element type of the q and k arrays.
//   N                 - number of Vec elements in each array.
//
// Steps:
//   1. q[0] * k[0] seeds an accumulator of type FloatVec<Vec>::Type.
//   2. Every remaining pair (q[i], k[i]) is folded in with fma().
//   3. sum() collapses the accumulator into a single float.
//   4. The float is combined with the values obtained from
//      VLLM_SHFL_XOR_SYNC for each halving mask.
//
// NOTE(review): VLLM_SHFL_XOR_SYNC is defined elsewhere; presumably it
// exchanges the value with the lane whose id differs by `mask` -- confirm
// against its definition.
//
// Returns the reduced float.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_v1(const Vec (&q)[N], const Vec (&k)[N]) {
  typedef typename FloatVec<Vec>::Type AccVec;

  // Seed with the first product so the accumulator needs no separate zero
  // initialisation.
  AccVec partial = mul<AccVec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int idx = 1; idx < N; ++idx) {
    partial = fma(q[idx], k[idx], partial);
  }

  // Collapse the per-thread vector accumulator to one scalar.
  float acc = sum(partial);

  // Combine the scalar with the shuffled values, halving the mask each step.
#pragma unroll
  for (int lane_mask = THREAD_GROUP_SIZE / 2; lane_mask > 0; lane_mask >>= 1) {
    acc += VLLM_SHFL_XOR_SYNC(acc, lane_mask);
  }
  return acc;
}
template <typename T, int THREAD_GROUP_SIZE>
......@@ -108,6 +109,10 @@ struct Qk_dot {
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
// Forwards q and k to qk_dot_v1, instantiated with this struct's
// THREAD_GROUP_SIZE, and returns its float result unchanged.
template <typename Vec, int N>
static inline __device__ float dot_v1(const Vec (&q)[N], const Vec (&k)[N]) {
  const float result = qk_dot_v1<THREAD_GROUP_SIZE>(q, k);
  return result;
}
};
} // namespace vllm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment