Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
bf87484e
Unverified
Commit
bf87484e
authored
Sep 04, 2023
by
Woosuk Kwon
Committed by
GitHub
Sep 04, 2023
Browse files
[BugFix] Fix NaN errors in paged attention kernel (#936)
parent
8ce9c50d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
5 deletions
+32
-5
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+12
-0
csrc/attention/dtype_bfloat16.cuh
csrc/attention/dtype_bfloat16.cuh
+10
-0
csrc/attention/dtype_float16.cuh
csrc/attention/dtype_float16.cuh
+5
-5
csrc/attention/dtype_float32.cuh
csrc/attention/dtype_float32.cuh
+5
-0
No files found.
csrc/attention/attention_kernels.cu
View file @
bf87484e
...
...
@@ -246,6 +246,8 @@ __global__ void single_query_cached_kv_attention_kernel(
accs
[
i
]
=
0.
f
;
}
scalar_t
zero_value
;
zero
(
zero_value
);
for
(
int
block_idx
=
warp_idx
;
block_idx
<
num_blocks
;
block_idx
+=
NUM_WARPS
)
{
const
int
physical_block_number
=
block_table
[
block_idx
];
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
...
...
@@ -261,6 +263,16 @@ __global__ void single_query_cached_kv_attention_kernel(
if
(
row_idx
<
HEAD_SIZE
)
{
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
V_vec
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
if
(
block_idx
==
num_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<=
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
context_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
}
}
...
...
csrc/attention/dtype_bfloat16.cuh
View file @
bf87484e
...
...
@@ -420,4 +420,14 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#endif
}
// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  // bf16 support requires compute capability >= 8.0; trap if reached earlier.
  assert(false);
#else
  // All-zero bit pattern; same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
  dst = __ushort_as_bfloat16(static_cast<unsigned short>(0x0000U));
#endif
}
}
// namespace vllm
csrc/attention/dtype_float16.cuh
View file @
bf87484e
...
...
@@ -390,11 +390,6 @@ inline __device__ float sum(uint4 v) {
return
sum
(
c
);
}
// Zero-out a vector.
inline __device__ void zero(uint16_t& dst) {
  // A half-precision zero is an all-zero bit pattern, so an
  // integral zero in the raw-bits representation suffices.
  dst = static_cast<uint16_t>(0);
}
// From float32 to float16.
inline
__device__
void
from_float
(
uint16_t
&
dst
,
float
src
)
{
dst
=
float_to_half
(
src
);
...
...
@@ -441,4 +436,9 @@ inline __device__ Float8_ to_float(uint4 u) {
return
tmp
;
}
// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) {
  // fp16 values are stored here as raw uint16_t bits; zero bits == fp16 zero.
  dst = uint16_t{0};
}
}
// namespace vllm
csrc/attention/dtype_float32.cuh
View file @
bf87484e
...
...
@@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
return
u
;
}
// Zero-out a variable.
inline __device__ void zero(float& dst) {
  // Plain assignment; no intrinsic needed for float32.
  dst = 0.0f;
}
}
// namespace vllm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment