Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
deeb9cb8
Commit
deeb9cb8
authored
Jul 10, 2024
by
zhangshao
Browse files
pa_v1用原始代码pa_v2用新代码
parent
c4b56490
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
438 additions
and
20 deletions
+438
-20
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+418
-5
csrc/attention/attention_utils.cuh
csrc/attention/attention_utils.cuh
+20
-15
No files found.
csrc/attention/attention_kernels.cu
View file @
deeb9cb8
This diff is collapsed.
Click to expand it.
csrc/attention/attention_utils.cuh
View file @
deeb9cb8
...
@@ -84,22 +84,23 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
...
@@ -84,22 +84,23 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
// Q*K^T operation. //bf16
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
// inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
// Dot product of two length-N arrays of packed vectors, followed by a
// reduction of the per-thread partial result across a thread group.
//
// Template parameters:
//   THREAD_GROUP_SIZE  Starting point of the shuffle reduction: masks
//                      THREAD_GROUP_SIZE/2, THREAD_GROUP_SIZE/4, ..., 1 are
//                      applied in that order.
//                      NOTE(review): presumably a power of two - confirm.
//   Vec                Packed element type of q and k; the accumulator type
//                      is taken from FloatVec<Vec>::Type.
//   N                  Number of Vec elements in each operand array.
//
// Parameters:
//   q, k  The two operand arrays; element i of q is multiplied with
//         element i of k.
//
// Returns the thread-local sum of all products, combined with the values
// obtained through VLLM_SHFL_XOR_SYNC for every mask listed above.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_v1(const Vec (&q)[N], const Vec (&k)[N]) {
  using Acc = typename FloatVec<Vec>::Type;

  // Seed the accumulator with the first product so the loop below can be a
  // pure chain of fused multiply-adds over the remaining elements.
  Acc acc = mul<Acc, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int idx = 1; idx < N; ++idx) {
    acc = fma(q[idx], k[idx], acc);
  }

  // Collapse the accumulator vector into a single scalar for this thread.
  float partial = sum(acc);

  // Finalize the reduction across lanes: exchange-and-add with a halving mask.
#pragma unroll
  for (int lane_mask = THREAD_GROUP_SIZE / 2; lane_mask > 0; lane_mask >>= 1) {
    partial += VLLM_SHFL_XOR_SYNC(partial, lane_mask);
  }
  return partial;
}
template
<
typename
T
,
int
THREAD_GROUP_SIZE
>
template
<
typename
T
,
int
THREAD_GROUP_SIZE
>
...
@@ -108,6 +109,10 @@ struct Qk_dot {
...
@@ -108,6 +109,10 @@ struct Qk_dot {
static
inline
__device__
float
dot
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
static
inline
__device__
float
dot
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
return
qk_dot_
<
THREAD_GROUP_SIZE
>
(
q
,
k
);
return
qk_dot_
<
THREAD_GROUP_SIZE
>
(
q
,
k
);
}
}
// Counterpart of dot() that routes the computation through qk_dot_v1.
// The reduction width is the THREAD_GROUP_SIZE of the enclosing template;
// Vec and N are passed on explicitly and match what deduction from the
// arguments would select.
//
// Parameters:
//   q, k  Operand arrays of N packed vectors each.
//
// Returns whatever qk_dot_v1 produces for (q, k).
template <typename Vec, int N>
static inline __device__ float dot_v1(const Vec (&q)[N], const Vec (&k)[N]) {
  const float qk = qk_dot_v1<THREAD_GROUP_SIZE, Vec, N>(q, k);
  return qk;
}
};
};
}
// namespace vllm
}
// namespace vllm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment