vllm · commit 436e523b (unverified)

Refactor attention kernels (#53)

Authored May 03, 2023 by Woosuk Kwon; committed by GitHub on May 03, 2023.
Parent commit: 27f1410d

Showing 13 changed files with 1253 additions and 1673 deletions.
Changed files:

  csrc/attention/attention_dtypes.cuh    +5    -0
  csrc/attention/attention_generic.cuh   +47   -0
  csrc/attention/attention_kernels.cu    +451  -0
  csrc/attention/attention_utils.cuh     +38   -0
  csrc/attention/dtype_float16.cuh       +426  -0
  csrc/attention/dtype_float32.cuh       +250  -0
  csrc/attention_utils.h                 +0    -165
  csrc/cuda_primitives.h                 +0    -1340
  csrc/layernorm_kernels.cu              +1    -1
  csrc/reduction_utils.cuh               +34   -0
  csrc/reduction_utils.h                 +0    -76
  setup.py                               +1    -1
  tests/kernels/attention.py             +0    -90
csrc/attention/attention_dtypes.cuh (new file, 0 → 100644)

#pragma once

#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
csrc/attention/attention_generic.cuh (new file, 0 → 100644)

#pragma once

#include <stdint.h>

namespace cacheflow {

// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators.
template<typename T>
struct FloatVec {};

// Template vector operations.
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template<typename T>
inline __device__ float sum(T v);

template<typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<T, T, T>(a, b));
}

template<typename A, typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<A, T, T>(a, b));
}

template<typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;
#pragma unroll
  for (int ii = 0; ii < WORDS; ++ii) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}

} // namespace cacheflow
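Editorial note: Vec and FloatVec are deliberately empty primary templates. The dtype headers specialize them per element type and width, so a kernel can pick its load width and the matching FP32 accumulator entirely at compile time. Below is a minimal sketch of how a kernel consumes these traits; the kernel is illustrative only (not part of this commit) and assumes the headers above are on the include path and that k is 16-byte aligned.

#include "attention_dtypes.cuh"

namespace cacheflow {

// Sketch: each thread loads 8 FP16 values in one 16-byte access and
// accumulates their squares in FP32.
__global__ void vec_traits_demo(const uint16_t* __restrict__ k, float* out) {
  // Vec<uint16_t, 8>::Type is uint4: 8 packed halves per load.
  using K_vec = typename Vec<uint16_t, 8>::Type;
  // FloatVec<uint4>::Type is Float8_: the FP32 accumulator of the same width.
  using K_acc = typename FloatVec<K_vec>::Type;

  K_vec kv = *reinterpret_cast<const K_vec*>(k + 8 * threadIdx.x);
  K_acc acc = mul<K_acc, K_vec, K_vec>(kv, kv);  // lane-wise products in FP32
  out[threadIdx.x] = sum(acc);                   // horizontal sum of 8 lanes
}

} // namespace cacheflow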
csrc/attention_kernels.cu → csrc/attention/attention_kernels.cu (renamed; diff collapsed, not shown)
csrc/attention/attention_utils.cuh (new file, 0 → 100644)

#pragma once

#include "attention_dtypes.cuh"

#include <float.h>
#include <type_traits>

namespace cacheflow {

// Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  using A_vec = typename FloatVec<Vec>::Type;
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
  }
  return qk;
}

template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
  template<typename Vec, int N>
  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
  }
};

} // namespace cacheflow
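Editorial note: qk_dot_ reduces in two stages. The fma loop accumulates the slice of Q and K held in each thread's registers, and the XOR-shuffle butterfly then folds the partial sums across the THREAD_GROUP_SIZE threads that cooperate on one key, so every thread in the group ends up holding the full dot product. A hedged sketch of a caller follows; the group size and register counts are illustrative, not the kernel's actual configuration.

#include "attention_utils.cuh"

namespace cacheflow {

// Sketch: 4 threads cooperate per key; each thread holds 2 packed-FP16
// registers (4 half values), so the group computes a 16-element dot product.
__device__ float qk_example(const uint32_t (&q)[2], const uint32_t (&k)[2]) {
  // All 4 threads of the group return the same scalar.
  return Qk_dot<uint16_t, 4>::dot(q, k);
}

} // namespace cacheflow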
csrc/attention/dtype_float16.cuh (new file, 0 → 100644)

#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <stdint.h>

namespace cacheflow {

// FP16 vector types for Q, K, V.
template<>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<uint16_t> {
  using Type = float;
};
template<>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template<>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
}

inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
  return tmp.u16[0];
}

inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
  return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
  return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
  return c;
}

template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template<>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template<>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template<>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// Zero-out a vector.
inline __device__ void zero(uint16_t& dst) {
  dst = uint16_t(0);
}

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
inline __device__ float to_float(uint16_t u) {
  return half_to_float(u);
}

inline __device__ float2 to_float(uint32_t u) {
  return half2_to_float2(u);
}

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}

} // namespace cacheflow
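Editorial note: this header keeps FP16 values as raw uint16_t/uint32_t/uint2/uint4 payloads and performs all arithmetic in PTX, avoiding a dependency on the CUDA half types. The h0_h0 helper broadcasts one half into both 16-bit lanes of a register so that scalar-times-vector cases can reuse the packed f16x2 instructions, which is exactly what the mixed-type mul/fma overloads above do. A small illustrative sketch (not part of the commit):

#include "attention_dtypes.cuh"

namespace cacheflow {

// Sketch: scale a packed pair of halves by a scalar half with one mul.f16x2.
// This mirrors the mul(uint16_t, uint32_t) specialization above.
__device__ uint32_t scale_half2(uint16_t s, uint32_t v2) {
  // h0_h0(s) duplicates s into both lanes: s -> {s, s}.
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(s), v2);
}

} // namespace cacheflow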
csrc/attention/dtype_float32.cuh (new file, 0 → 100644)

#pragma once

#include "attention_generic.cuh"

#include <stdint.h>

namespace cacheflow {

// Define FP32 vector data types.
struct Float4_ {
  float2 x;
  float2 y;
};

struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

// FP32 vector types for Q, K, V.
template<>
struct Vec<float, 1> {
  using Type = float;
};
template<>
struct Vec<float, 2> {
  using Type = float2;
};
template<>
struct Vec<float, 4> {
  using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<float> {
  using Type = float;
};
template<>
struct FloatVec<float2> {
  using Type = float2;
};
template<>
struct FloatVec<float4> {
  using Type = float4;
};

// Vector addition.
inline __device__ float add(float a, float b) {
  return a + b;
}

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template<>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template<>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template<>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template<>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}

// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
  return a * b + c;
}

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

// Vector sum.
template<>
inline __device__ float sum(float v) {
  return v;
}

template<>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template<>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template<>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template<>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
inline __device__ float dot(float a, float b) {
  return a * b;
}

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}

// From float32 to float32 (identity).
inline __device__ void from_float(float& dst, float src) {
  dst = src;
}

inline __device__ void from_float(float2& dst, float2 src) {
  dst = src;
}

inline __device__ void from_float(float4& dst, float4 src) {
  dst = src;
}

// From float32 to float32 (identity).
inline __device__ float to_float(float u) {
  return u;
}

inline __device__ float2 to_float(float2 u) {
  return u;
}

inline __device__ float4 to_float(float4 u) {
  return u;
}

inline __device__ Float4_ to_float(Float4_ u) {
  return u;
}

inline __device__ Float8_ to_float(Float8_ u) {
  return u;
}

} // namespace cacheflow
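Editorial note: the generic dot in attention_generic.cuh simply composes these overloads. For float4 operands it expands to sum(mul<float4, float4, float4>(a, b)), i.e. four lane products followed by a horizontal add. A minimal illustrative kernel (not part of the commit):

#include "attention_dtypes.cuh"

namespace cacheflow {

__global__ void dot4_demo(const float4* a, const float4* b, float* out) {
  // Expands to a.x*b.x + a.y*b.y + a.z*b.z + a.w*b.w via the
  // mul<float4, float4, float4> and sum(float4) specializations above.
  out[threadIdx.x] = dot(a[threadIdx.x], b[threadIdx.x]);
}

} // namespace cacheflow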
csrc/attention_utils.h (deleted, 100644 → 0)

#pragma once

#include "cuda_primitives.h"

#include <float.h>
#include <type_traits>

#define MMHA_USE_FP32_ACUM_FOR_FMA
#define MMHA_USE_FP32_ACUM_FOR_OUT

namespace cacheflow {

// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};
template<>
struct Vec<float, 1> {
  using Type = float;
};
template<>
struct Vec<float, 2> {
  using Type = float2;
};
template<>
struct Vec<float, 4> {
  using Type = float4;
};
template<>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

template<typename T>
struct FloatVec {};
template<>
struct FloatVec<float> {
  using Type = float;
};
template<>
struct FloatVec<float2> {
  using Type = float2;
};
template<>
struct FloatVec<float4> {
  using Type = float4;
};
template<>
struct FloatVec<uint16_t> {
  using Type = float;
};
template<>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template<>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
  using Type = Float8_;
};

template<int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) {
  using K_vec_acum = typename FloatVec<K_vec>::Type;
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
  }
  return qk;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int THREADS_PER_KEY>
struct Qk_dot {
  template<typename K_vec, int N>
  static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) {
    return qk_dot_<THREADS_PER_KEY>(q, k);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) {
  float4 c;
  float zero = 0.f;
  asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
      "    {%0, %1, %2, %3}, \n"
      "    {%4, %5}, \n"
      "    {%6}, \n"
      "    {%7, %7, %7, %7}; \n"
      : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
      : "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
  return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  using K_vec_acum = typename FloatVec<uint32_t>::Type;
  K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
  uint32_t qk_vec_ = float2_to_half2(qk_vec);
  return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
  return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
  return 0.f;
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Qk_dot<uint16_t, 4> {
  template<int N>
  static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) {
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
    return qk_hmma_dot_(q, k);
#else
    return qk_dot_<4>(q, k);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
  }
};

} // namespace cacheflow

#undef MMHA_USE_FP32_ACUM_FOR_FMA
#undef MMHA_USE_FP32_ACUM_FOR_OUT
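An editorial aside on the deleted qk_hmma_dot_ above: 0x3c00 is the IEEE-754 binary16 bit pattern for 1.0, so 0x3c003c00u is the packed half2 (1.0, 1.0). The m16n8k8 MMA therefore multiplies the accumulated lane products by a vector of ones, turning the tensor-core matrix multiply into the final horizontal reduction; the scalar lands in the .x component of hmma_fp32's float4 result. A tiny host-side check of the constant (illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t ones = 0x3c003c00u;
  // Each 16-bit lane is 0x3c00: sign 0, exponent 01111 (bias 15, so 2^0),
  // mantissa 0 -- i.e. the value 1.0 in binary16.
  printf("lo lane: 0x%04x, hi lane: 0x%04x\n", ones & 0xffffu, ones >> 16);
  return 0;
}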
csrc/cuda_primitives.h (deleted, 100644 → 0; diff collapsed, not shown)
csrc/layernorm_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>

-#include "reduction_utils.h"
+#include "reduction_utils.cuh"

 namespace cacheflow {
 ...
csrc/reduction_utils.cuh (new file, 0 → 100644)

#pragma once

namespace cacheflow {

template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
  return val;
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0)
    shared[wid] = val;

  __syncthreads();

  // Use blockDim.x / 32.f (rather than blockDim.x >> 5) so the guard also
  // covers the case where blockDim.x is not a multiple of 32.
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);
  return val;
}

} // namespace cacheflow
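Editorial note: blockReduceSum composes two warp-level butterflies through a 32-slot shared buffer: each warp reduces privately, lane 0 of each warp parks its partial sum in shared[wid], and the first warp then reduces the parked partials. After the second pass, only the first warp is guaranteed to hold the block total. A minimal sketch of a caller (illustrative kernel, assuming one block per row and blockDim.x <= 1024):

#include "reduction_utils.cuh"

namespace cacheflow {

// Sketch: sum each length-n row of a matrix, one block per row.
__global__ void row_sum(const float* __restrict__ in, float* __restrict__ out,
                        int n) {
  float local = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    local += in[blockIdx.x * n + i];
  }
  float total = blockReduceSum<float>(local);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = total;  // thread 0 (warp 0) holds the final sum
  }
}

} // namespace cacheflow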
csrc/reduction_utils.h (deleted, 100644 → 0)

#pragma once

namespace cacheflow {

template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < WARPS_PER_BLOCK) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Broadcast to other threads.
  return __shfl_sync(uint32_t(-1), sum, 0);
}

#define FINAL_MASK 0xffffffff

template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
  return val;
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0)
    shared[wid] = val;

  __syncthreads();

  // Use blockDim.x / 32.f (rather than blockDim.x >> 5) so the guard also
  // covers the case where blockDim.x is not a multiple of 32.
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);
  return val;
}

} // namespace cacheflow
setup.py

@@ -18,7 +18,7 @@ ext_modules.append(cache_extension)
 # Attention kernels.
 attention_extension = cpp_extension.CUDAExtension(
     name='cacheflow.attention_ops',
-    sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'],
+    sources=['csrc/attention.cpp', 'csrc/attention/attention_kernels.cu'],
     extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
 )
 ext_modules.append(attention_extension)
 ...
tests/kernels/attention.py

@@ -271,78 +271,6 @@ def test_multi_query_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


-def test_multi_query_cached_kv_attention(
-    num_queries: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
-) -> None:
-    query_lens = random.sample(range(1, MAX_SEQ_LEN), num_queries)
-    cu_query_lens = [0]
-    for query_len in query_lens:
-        cu_query_lens.append(cu_query_lens[-1] + query_len)
-    num_total_tokens = cu_query_lens[-1]
-
-    qkv = torch.randn(
-        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
-    query, _, _ = qkv.unbind(dim=1)
-
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(
-        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
-    value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.randn(
-        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
-
-    cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
-    context_lens = [
-        query_len + random.randint(0, MAX_SEQ_LEN - query_len)
-        for query_len in query_lens
-    ]
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
-
-    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
-    block_tables = []
-    for _ in range(num_queries):
-        block_table = [
-            random.randint(0, num_blocks - 1)
-            for _ in range(max_num_blocks_per_seq)
-        ]
-        block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
-
-    scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty(
-        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
-    attention_ops.multi_query_cached_kv_attention(
-        cu_query_lens,
-        output,
-        query,
-        key_cache,
-        value_cache,
-        scale,
-        block_tables,
-        context_lens,
-        block_size,
-        max_context_len,
-    )
-
-    ref_output = ref_multi_query_cached_kv_attention(
-        cu_query_lens,
-        query,
-        key_cache,
-        value_cache,
-        block_tables,
-        context_lens,
-        dtype,
-    )
-    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
-
-
 @torch.inference_mode()
 def test_attention(seed: int) -> None:
     # NOTE(woosuk): Even when the seed is fixed, there is a chance that
 ...

@@ -364,24 +292,6 @@ def test_attention(seed: int) -> None:
             dtype=dtype,
         )

-    # NOTE(siyuan): Same as above. Re-run the test if it fails. Also note
-    # that this test is more likely to fail: the much larger number of
-    # tokens in the input increases the variance.
-    for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16, 32]:
-            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
-                print(f'Testing multi_query_cached_kv_attention with '
-                      f'dtype={dtype}, block_size={block_size}, '
-                      f'head_size={head_size}')
-                test_multi_query_cached_kv_attention(
-                    num_queries=11,
-                    num_heads=3,
-                    head_size=head_size,
-                    block_size=block_size,
-                    num_blocks=1024,
-                    dtype=dtype,
-                )
-
     # NOTE(woosuk): FlashAttention does not support FP32.
     for dtype in [torch.half]:
         # NOTE(woosuk): FlashAttention does not support head_size > 128.
 ...