Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1be9a629
Commit
1be9a629
authored
Jul 22, 2024
by
zhangshao
Browse files
pa优化,编译选项优化
parent
d4c0015a
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
623 additions
and
313 deletions
+623
-313
CMakeLists.txt
CMakeLists.txt
+8
-5
cmake/utils.cmake
cmake/utils.cmake
+5
-1
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+418
-301
csrc/attention/attention_utils.cuh
csrc/attention/attention_utils.cuh
+98
-6
csrc/attention/static_switch.h
csrc/attention/static_switch.h
+94
-0
No files found.
CMakeLists.txt
View file @
1be9a629
...
...
@@ -4,11 +4,13 @@ project(vllm_extensions LANGUAGES CXX)
# Target backend selector; the default value supplied here is "cuda".
option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")

# Set the build type to Release unconditionally for this project scope.
set(CMAKE_BUILD_TYPE "Release")

# Echo the effective configuration at configure time.
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")

# Pull in the project's CMake helpers from cmake/utils.cmake.
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

# Pass -w to every compile command added after this point
# (-w asks GCC/Clang-style compilers to suppress all warnings).
add_compile_options(-w)
#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
...
...
@@ -120,10 +122,11 @@ endif()
# the supported versions for the current language.
# The final set of arches is stored in `VLLM_GPU_ARCHES`.
#
# Compute VLLM_GPU_ARCHES from the per-language list of supported archs.
# NOTE(review): this listing comes from a rendered diff; the live call below may
# be the pre-change line that the commented-out copy replaces — confirm against
# the checked-out file.
override_gpu_arches(VLLM_GPU_ARCHES
  ${VLLM_GPU_LANG}
  "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
#override_gpu_arches(VLLM_GPU_ARCHES
#  ${VLLM_GPU_LANG}
#  "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
# Pin the architecture list to gfx928; this assignment replaces any value
# computed earlier.
set(VLLM_GPU_ARCHES "gfx928")
# Print the final architecture list at configure time.
message(STATUS "${VLLM_GPU_ARCHES}")
#
# Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`.
...
...
cmake/utils.cmake
View file @
1be9a629
...
...
@@ -117,6 +117,10 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
"Failed to determine torch nvcc compiler flags"
)
list
(
REMOVE_ITEM GPU_FLAGS
"-DUSE_ROCM=1"
)
list
(
APPEND GPU_FLAGS
"-DUSE_ROCM"
# "-DENABLE_FP8"
...
...
@@ -124,7 +128,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc"
"--gpu-max-threads-per-block=1024"
)
message
(
STATUS
"
${
GPU_FLAGS
}
"
)
endif
()
set
(
${
OUT_GPU_FLAGS
}
${
GPU_FLAGS
}
PARENT_SCOPE
)
endfunction
()
...
...
csrc/attention/attention_kernels.cu
View file @
1be9a629
This diff is collapsed.
Click to expand it.
csrc/attention/attention_utils.cuh
View file @
1be9a629
...
...
@@ -26,19 +26,106 @@
namespace
vllm
{
// Q*K^T operation.
// Thin wrapper that emits the `v_dot2_f32_f16` instruction via inline assembly.
//   a:    float accumulator. It is both the destination operand and the last
//         source operand of the instruction, so it is updated in place.
//   b, c: 32-bit operands handed to the instruction unchanged
//         (presumably two packed fp16 values each — TODO confirm against the ISA guide).
inline __device__ void v_dot2_f32_f16(float &a, const uint32_t &b, const uint32_t &c)
{
    // "=v"(a) is the output; the matching constraint "0"(a) feeds the incoming
    // value of `a` into the same register as operand %0.
    asm volatile("v_dot2_f32_f16 %0, %1, %2, %0;" : "=v"(a) : "v"(b), "v"(c), "0"(a));
}
// Thin wrapper that emits the `v_pk_fma_f16` instruction via inline assembly.
//   a:    32-bit word used as the destination (%0) and also passed as the
//         third source operand (%3), so it is updated in place.
//   b, c: 32-bit words passed as the first and second source operands.
// NOTE(review): presumably each word holds two packed fp16 lanes and the
// instruction computes a = b * c + a per lane — confirm against the ISA guide.
inline __device__ void v_pk_fma_f16(uint32_t &a, const uint32_t &b, const uint32_t &c){
    asm volatile("v_pk_fma_f16 %0, %1, %2, %3;" : "=v"(a) : "v"(b), "v"(c), "v"(a));
}
// Emits a `ds_read_b128` instruction via inline assembly.
//   a:      16-byte destination, declared as the asm output operand.
//   offset: 32-bit value passed as the instruction's address operand
//           (presumably an LDS byte address — TODO confirm).
// NOTE(review): no `s_waitcnt` is emitted here, so callers presumably have to
// wait (e.g. with lgkmcnt0()) before consuming `a` — confirm.
inline __device__ void ds_read_b128(uint4 &a, uint32_t offset){
    asm volatile("ds_read_b128 %0 %1;" : "=v"(a) : "v"(offset));
}
// Emits a `ds_read_b128` instruction immediately followed by
// `s_waitcnt lgkmcnt(1)` in the same inline-assembly statement.
//   a:      16-byte destination, declared as the asm output operand.
//   offset: 32-bit value passed as the instruction's address operand
//           (presumably an LDS byte address — TODO confirm).
// NOTE(review): a count of 1 presumably still allows one outstanding operation,
// so this read itself may not have completed on return; the disabled code in
// qk_dot_ follows it with lgkmcnt0() — confirm the intended pairing.
inline __device__ void ds_read_b128_sync(uint4 &a, uint32_t offset){
    asm volatile("ds_read_b128 %0 %1\ns_waitcnt lgkmcnt(1);" : "=v"(a) : "v"(offset));
}
// Emits `s_waitcnt lgkmcnt(0)` via inline assembly.
// NOTE(review): presumably used to wait until all outstanding ds_read_b128
// results have landed before they are consumed — confirm.
inline __device__ void lgkmcnt0(){
    asm volatile("s_waitcnt lgkmcnt(0);");
}
// Converts a generic pointer into an integer by first casting it to a pointer
// in address space 3 and then casting that pointer to size_t.
//   __ptr:   generic pointer to convert.
//   returns: the address-space-3 pointer value as a size_t.
// NOTE(review): address space 3 is presumably LDS/shared memory on this target — confirm.
__device__ inline size_t __nv_cvta_generic_to_shared_impl(const void *__ptr)
{
    // Step 1: address-space cast (also drops the const qualifier, as the original did).
    void __attribute__((address_space(3))) *as3_ptr =
        (void __attribute__((address_space(3))) *)__ptr;
    // Step 2: expose the pointer bits as an integer.
    return (size_t)as3_ptr;
}
// uint2 overload: applies the 32-bit v_dot2_f32_f16 wrapper to the .x
// components and then to the .y components, accumulating into `a` in that order.
//   a:    float accumulator, updated in place.
//   b, c: two-word operands; matching components are paired.
inline __device__ void v_dot2_f32_f16(float &a, const uint2 &b, const uint2 &c)
{
    v_dot2_f32_f16(a, b.x, c.x);
    v_dot2_f32_f16(a, b.y, c.y);
}
// uint4 overload: applies the 32-bit v_dot2_f32_f16 wrapper to the .x, .y, .z
// and .w components in that order, accumulating into `a`.
//   a:    float accumulator, updated in place.
//   b, c: four-word operands; matching components are paired.
inline __device__ void v_dot2_f32_f16(float &a, const uint4 &b, const uint4 &c)
{
    v_dot2_f32_f16(a, b.x, c.x);
    v_dot2_f32_f16(a, b.y, c.y);
    v_dot2_f32_f16(a, b.z, c.z);
    v_dot2_f32_f16(a, b.w, c.w);
}
// Horizontal add of the two fp16 lanes packed into one 32-bit word.
//   a:       32-bit word reinterpreted as two `half` values.
//   returns: lane0 + lane1 as a float.
// Each lane is widened to float before the addition. Adding in fp16 first
// would round the sum to half precision and overflow to inf when the two lanes
// together exceed the fp16 range, even though the function returns a float.
inline __device__ float add_half2(uint32_t a){
    // View the same 32 bits either as one word or as two fp16 lanes.
    union {
        uint32_t u32;
        half u16[2];
    } tmp;
    tmp.u32 = a;
    return static_cast<float>(tmp.u16[0]) + static_cast<float>(tmp.u16[1]);
}
// Accumulates the products of four pairs of 32-bit words from `b` and `c`
// into the float `a`.
//   a:    float accumulator, increased by the reduced result.
//   b, c: four-word operands; matching components are paired.
// NOTE(review): each word presumably carries two packed fp16 lanes, making this
// an 8-element fp16 dot product — confirm against mul<> and v_pk_fma_f16.
inline __device__ void v_pk_fma_f16x8(float &a, const uint4 &b, const uint4 &c)
{
    // Seed the packed accumulator with the product of the .x words.
    uint32_t packed_acc = mul<uint32_t, uint32_t, uint32_t>(b.x, c.x);

    // Fold in the remaining components, in y, z, w order.
    v_pk_fma_f16(packed_acc, b.y, c.y);
    v_pk_fma_f16(packed_acc, b.z, c.z);
    v_pk_fma_f16(packed_acc, b.w, c.w);

    // Collapse the two lanes of the packed accumulator and add them to `a`.
    const float lane_total = add_half2(packed_acc);
    a += lane_total;
}
// Q*K^T operation. fp16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
// Dot product of q and k for one thread group.
//   THREAD_GROUP_SIZE: number of lanes whose partial sums are combined.
//   q, k:              N-element arrays of Vec; element i of q is paired with
//                      element i of k through v_dot2_f32_f16.
//   returns:           the per-thread partial sum after the XOR-shuffle
//                      reduction over masks THREAD_GROUP_SIZE/2, ..., 1.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N])
{
    float acc = 0;

    // Disabled experimental variant, kept for reference (it names the
    // accumulator `qk`):
    // uint32_t offset = __nv_cvta_generic_to_shared_impl(q);
    // const uint4 *k_ptr= reinterpret_cast<const uint4 *>(k);
    // // Compute the parallel products for Q*K^T (treat vector lanes separately).
    // constexpr int loop=N*sizeof(Vec)/16/2;
    // uint4 qt[2];
    // #pragma unroll
    // for (int ii = 0; ii < loop; ++ii) {
    //   ds_read_b128(qt[0],offset+16*ii*2);
    //   ds_read_b128_sync(qt[1],offset+16*(ii*2+1));
    //   v_dot2_f32_f16(qk,qt[0],k_ptr[ii*2]);
    //   // v_pk_fma_f16x8(qk,qt[0],k_ptr[ii*2]);
    //   lgkmcnt0();
    //   v_dot2_f32_f16(qk,qt[1],k_ptr[ii*2+1]);
    //   // v_pk_fma_f16x8(qk,qt[1],k_ptr[ii*2+1]);
    // }

    // Per-thread accumulation over the N vector elements, in index order.
#pragma unroll
    for (int elem = 0; elem < N; ++elem) {
        v_dot2_f32_f16(acc, q[elem], k[elem]);
    }

    // Combine the partial sums across the lanes of the thread group.
#pragma unroll
    for (int xor_mask = THREAD_GROUP_SIZE >> 1; xor_mask > 0; xor_mask >>= 1) {
        acc += VLLM_SHFL_XOR_SYNC(acc, xor_mask);
    }

    return acc;
}
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
template
<
int
THREAD_GROUP_SIZE
,
typename
Vec
,
int
N
>
inline
__device__
float
qk_dot_vpack_
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
using
A_vec
=
typename
FloatVec
<
Vec
>::
Type
;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec
qk_vec
=
mul
<
A_vec
,
Vec
,
Vec
>
(
q
[
0
],
k
[
0
]);
#pragma unroll
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
qk_vec
=
fma
(
q
[
ii
],
k
[
ii
],
qk_vec
);
}
// Finalize the reduction across lanes.
float
qk
=
sum
(
qk_vec
);
// Finalize the reduction across lanes.
#pragma unroll
for
(
int
mask
=
THREAD_GROUP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
VLLM_SHFL_XOR_SYNC
(
qk
,
mask
);
...
...
@@ -46,12 +133,17 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return
qk
;
}
// Static dispatcher for the Q*K^T dot product.
//   T:                 element type tag; not used by the body shown here.
//   THREAD_GROUP_SIZE: forwarded to qk_dot_ as its lane-group size.
template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
    // Forwards q and k to qk_dot_ and returns its result unchanged.
    template <typename Vec, int N>
    static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N])
    {
        const float score = qk_dot_<THREAD_GROUP_SIZE, Vec, N>(q, k);
        return score;
    }

    // Disabled entry point for the qk_dot_vpack_ path, kept for reference:
    // template <typename Vec, int N>
    // static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) {
    //   return qk_dot_vpack_<THREAD_GROUP_SIZE>(q, k);
    // }
};
}
// namespace vllm
\ No newline at end of file
}
// namespace vllm
csrc/attention/static_switch.h
0 → 100644
View file @
1be9a629
// Turns a runtime boolean into a compile-time constant.
//   COND:       runtime condition, evaluated once.
//   CONST_NAME: name of the `constexpr static bool` visible to the callable.
//   ...:        a callable expression (typically a lambda); it is invoked with
//               no arguments and its result is returned by the macro.
// Both branches instantiate the callable, once with CONST_NAME == true and
// once with CONST_NAME == false.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
  [&] {                                               \
    if (!(COND)) {                                    \
      static constexpr bool CONST_NAME = false;       \
      return __VA_ARGS__();                           \
    }                                                 \
    static constexpr bool CONST_NAME = true;          \
    return __VA_ARGS__();                             \
  }()
// Maps a runtime condition onto the compile-time constant `opt`.
//   COND: runtime condition; true selects opt == 1, false selects opt == 2.
//   ...:  a callable expression (typically a lambda) that can read `opt`; it
//         is invoked with no arguments and its result is returned by the macro.
#define OPT_SWITCH(COND, ...)                         \
  [&] {                                               \
    if (!(COND)) {                                    \
      static constexpr int opt = 2;                   \
      return __VA_ARGS__();                           \
    }                                                 \
    static constexpr int opt = 1;                     \
    return __VA_ARGS__();                             \
  }()
// Maps a runtime thread count onto the compile-time constant NUM_THREADS.
//   NUM_THREAD: runtime thread count. 256 selects NUM_THREADS == 256; every
//               other value selects NUM_THREADS == 128.
//   ...:        a callable expression (typically a lambda) that can read
//               NUM_THREADS; it is invoked with no arguments and its result is
//               returned by the macro.
// The argument is parenthesized so that expressions with lower precedence than
// `==` (for example `flag ? 128 : 256`) are compared as a whole.
#define NUM_THREADS_SWITCH(NUM_THREAD, ...)           \
  [&] {                                               \
    if ((NUM_THREAD) == 256) {                        \
      constexpr static int NUM_THREADS = 256;         \
      return __VA_ARGS__();                           \
    } else {                                          \
      constexpr static int NUM_THREADS = 128;         \
      return __VA_ARGS__();                           \
    }                                                 \
  }()
// #define HEADSIZE_SWITCH(HEADDIM, ...) \
// [&] { \
// if (HEADDIM == 64) { \
// constexpr static int HEAD_SIZE = 64; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 80) { \
// constexpr static int HEAD_SIZE = 80; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 96) { \
// constexpr static int HEAD_SIZE = 96; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 112) { \
// constexpr static int HEAD_SIZE = 112; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 128) { \
// constexpr static int HEAD_SIZE = 128; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 256) { \
// constexpr static int HEAD_SIZE = 256; \
// return __VA_ARGS__(); \
// } \
// else { \
// TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
// } \
// }()
// Maps a runtime head size onto the compile-time constant HEAD_SIZE.
//   HEADDIM: runtime head size. Only 128 is accepted; any other value goes to
//            TORCH_CHECK(false, ...) with an "Unsupported head size" message.
//   ...:     a callable expression (typically a lambda) that can read
//            HEAD_SIZE; it is invoked with no arguments and its result is
//            returned by the macro.
// The argument is parenthesized in the comparison so that expressions with
// lower precedence than `==` are compared as a whole.
// NOTE(review): the else branch has no return statement, so the callable
// presumably returns void — confirm at the call sites.
#define HEADSIZE_SWITCH(HEADDIM, ...)                           \
  [&] {                                                         \
    if ((HEADDIM) == 128) {                                     \
      constexpr static int HEAD_SIZE = 128;                     \
      return __VA_ARGS__();                                     \
    } else {                                                    \
      TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);   \
    }                                                           \
  }()
// Picks the compile-time constant REUSE_KV_TIMES (4, 2 or 1) at runtime.
//   num_blocks: runtime block count, compared against the threshold 1200.
//   ...:        a callable expression (typically a lambda) that can read
//               REUSE_KV_TIMES; it is invoked with no arguments and its result
//               is returned by the macro.
// The macro also reads `num_heads` and `num_kv_heads` from the scope of the
// call site; they are not macro parameters.
// Selection:
//   4 when num_heads is even, num_heads / num_kv_heads >= 4 and num_blocks >= 1200;
//   2 when num_heads / num_kv_heads >= 2 and num_blocks >= 1200;
//   1 otherwise.
// `num_blocks` is parenthesized so that an argument such as `c ? a : b` is not
// re-associated with the surrounding `&&` chain.
#define REUSEKV_SWITCH(num_blocks , ...)                                                     \
  [&] {                                                                                      \
    if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && (num_blocks) >= 1200){        \
      constexpr static int REUSE_KV_TIMES = 4;                                               \
      return __VA_ARGS__();                                                                  \
    } else if (num_heads / num_kv_heads >= 2 && (num_blocks) >= 1200){                       \
      constexpr static int REUSE_KV_TIMES = 2;                                               \
      return __VA_ARGS__();                                                                  \
    } else {                                                                                 \
      constexpr static int REUSE_KV_TIMES = 1;                                               \
      return __VA_ARGS__();                                                                  \
    }                                                                                        \
  }()
// Picks the compile-time constant REUSE_KV_TIMES (2 or 1) at runtime.
//   num_blocks: runtime block count, compared against the threshold 1200.
//   ...:        a callable expression (typically a lambda) that can read
//               REUSE_KV_TIMES; it is invoked with no arguments and its result
//               is returned by the macro.
// The macro also reads `num_heads` and `num_kv_heads` from the scope of the
// call site; they are not macro parameters.
// Selection: 2 when num_heads > num_kv_heads and num_blocks >= 1200, else 1.
// `num_blocks` is parenthesized so that an argument such as `c ? a : b` is not
// re-associated with the surrounding `&&`.
#define REUSEKV_SWITCH_V1(num_blocks , ...)                          \
  [&] {                                                              \
    if (num_heads > num_kv_heads && (num_blocks) >= 1200){           \
      constexpr static int REUSE_KV_TIMES = 2;                       \
      return __VA_ARGS__();                                          \
    } else {                                                         \
      constexpr static int REUSE_KV_TIMES = 1;                       \
      return __VA_ARGS__();                                          \
    }                                                                \
  }()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment