Commit deeb9cb8 authored by zhangshao's avatar zhangshao
Browse files

pa_v1用原始代码，pa_v2用新代码

parent c4b56490
This diff is collapsed.
...@@ -84,22 +84,23 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { ...@@ -84,22 +84,23 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
// Dot product of a thread's query and key vector fragments, combined across
// the lanes of its thread group.
//
// Template parameters:
//   THREAD_GROUP_SIZE - number of lanes whose partial sums are combined. The
//                       shuffle loop starts at THREAD_GROUP_SIZE / 2 and halves
//                       the mask down to 1, so this is presumably a power of
//                       two (TODO confirm against callers).
//   Vec               - packed vector type of each fragment.
//   N                 - number of Vec fragments held by each thread.
//
// Parameters:
//   q - this thread's N query fragments.
//   k - this thread's N key fragments.
//
// Returns the float obtained by summing q[i] * k[i] over this thread's
// fragments and then adding the values exchanged through
// VLLM_SHFL_XOR_SYNC for every mask in the reduction loop.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_v1(const Vec (&q)[N], const Vec (&k)[N]) {
  // Accumulate in the type that FloatVec associates with Vec.
  using A_vec = typename FloatVec<Vec>::Type;
  // Seed the accumulator with the first product, then fold in the rest.
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }
  // Collapse the vector accumulator into one scalar for this thread.
  float qk = sum(qk_vec);
  // Finalize the reduction across lanes.
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
  }
  return qk;
}
template <typename T, int THREAD_GROUP_SIZE> template <typename T, int THREAD_GROUP_SIZE>
...@@ -108,6 +109,10 @@ struct Qk_dot { ...@@ -108,6 +109,10 @@ struct Qk_dot {
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k); return qk_dot_<THREAD_GROUP_SIZE>(q, k);
} }
// Counterpart of dot() that goes through qk_dot_v1 rather than qk_dot_.
// Supplies THREAD_GROUP_SIZE as the thread-group size and hands back the
// float that qk_dot_v1 computes for the query fragments q and key fragments k.
template <typename Vec, int N>
static inline __device__ float dot_v1(const Vec (&q)[N], const Vec (&k)[N]) {
  const float qk = qk_dot_v1<THREAD_GROUP_SIZE>(q, k);
  return qk;
}
}; };
} // namespace vllm } // namespace vllm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment