Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bd93e661
"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "50ed6d0a839e357548fb593762730daab7f1dd30"
Commit
bd93e661
authored
Aug 21, 2024
by
zhuwenwen
Browse files
Update refactoring operation
parent
4405f82c
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
44 additions
and
102 deletions
+44
-102
CMakeLists.txt
CMakeLists.txt
+2
-2
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+12
-12
csrc/attention/static_switch.h
csrc/attention/static_switch.h
+0
-0
csrc/ops.h
csrc/ops.h
+3
-3
csrc/opt/activation_kernels_opt.cu
csrc/opt/activation_kernels_opt.cu
+8
-69
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+14
-14
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+4
-1
No files found.
CMakeLists.txt
View file @
bd93e661
...
@@ -157,10 +157,10 @@ set(VLLM_EXT_SRC
...
@@ -157,10 +157,10 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu"
"csrc/opt/activation_kernels_opt.cu"
"csrc/
opt
/attention_kernels_opt.cu"
"csrc/
attention
/attention_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
#
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu"
# "csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/cuda_utils_kernels.cu"
...
...
csrc/
opt
/attention_kernels_opt.cu
→
csrc/
attention
/attention_kernels_opt.cu
View file @
bd93e661
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <algorithm>
#include "
../
attention/attention_dtypes.h"
#include "attention/attention_dtypes.h"
#include "
../
attention/attention_utils.cuh"
#include "attention/attention_utils.cuh"
#ifdef USE_ROCM
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
...
@@ -70,7 +70,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
...
@@ -70,7 +70,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
=
1
,
bool
odd_nheads
=
false
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
=
1
,
bool
odd_nheads
=
false
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
__device__
void
paged_attention_kernel
_opt
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -590,7 +590,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
...
@@ -590,7 +590,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int
REUSE_KV_TIMES
=
1
,
int
REUSE_KV_TIMES
=
1
,
bool
IS_BLOCK_SPARSE
,
bool
IS_BLOCK_SPARSE
,
bool
odd_nheads
=
false
>
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v1_kernel
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v1_kernel
_opt
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
...
@@ -608,7 +608,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
...
@@ -608,7 +608,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
...
@@ -625,7 +625,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
...
@@ -625,7 +625,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int
REUSE_KV_TIMES
,
int
REUSE_KV_TIMES
,
int
PARTITION_SIZE
,
int
PARTITION_SIZE
,
bool
odd_nheads
=
false
>
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel
_opt
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -647,7 +647,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
...
@@ -647,7 +647,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
PARTITION_SIZE
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
...
@@ -659,7 +659,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
...
@@ -659,7 +659,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
// Grid: (num_heads, num_seqs).
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
int
PARTITION_SIZE
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel
_opt
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -767,11 +767,11 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
...
@@ -767,11 +767,11 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
((void*)vllm::paged_attention_v1_kernel
_opt
<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \
shared_mem_size); \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel
_opt
<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
, dim3(grid), dim3(block), shared_mem_size, stream, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
...
@@ -918,7 +918,7 @@ void paged_attention_v1_opt(
...
@@ -918,7 +918,7 @@ void paged_attention_v1_opt(
}
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel
_opt
<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
, dim3(grid), dim3(block), shared_mem_size, stream, \
...
@@ -928,7 +928,7 @@ void paged_attention_v1_opt(
...
@@ -928,7 +928,7 @@ void paged_attention_v1_opt(
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel
_opt
<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
...
...
csrc/
opt
/static_switch.h
→
csrc/
attention
/static_switch.h
View file @
bd93e661
File moved
csrc/ops.h
View file @
bd93e661
...
@@ -83,11 +83,11 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);
...
@@ -83,11 +83,11 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh
_and_mul_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu
_and_mul_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_
new
_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_
and_mul
_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_
fast
_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_
tanh_and_mul
_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
trans_w16_gemm
(
torch
::
Tensor
dst
,
torch
::
Tensor
src
,
int64_t
row
,
int64_t
col
);
void
trans_w16_gemm
(
torch
::
Tensor
dst
,
torch
::
Tensor
src
,
int64_t
row
,
int64_t
col
);
...
...
csrc/opt/activation_kernels_opt.cu
View file @
bd93e661
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <cmath>
#include <cmath>
#include "cuda_compat.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "
../
dispatch_utils.h"
namespace
vllm
{
namespace
vllm
{
...
@@ -25,7 +25,7 @@ __global__ void act_and_mul_kernel(
...
@@ -25,7 +25,7 @@ __global__ void act_and_mul_kernel(
}
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
__global__
void
act_and_mul_kernel_
vectorize
1
(
__global__
void
act_and_mul_kernel_
opt
1
(
scalar_t
*
__restrict__
out
,
// [..., d]
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
const
int
d
)
{
...
@@ -52,7 +52,7 @@ __global__ void act_and_mul_kernel_vectorize1(
...
@@ -52,7 +52,7 @@ __global__ void act_and_mul_kernel_vectorize1(
}
}
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
),
int
VEC
>
__global__
void
act_and_mul_kernel_
vectorize
2
(
__global__
void
act_and_mul_kernel_
opt
2
(
scalar_t
*
__restrict__
out
,
// [..., d]
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
scalar_t
*
__restrict__
input
,
// [..., 2, d]
const
int
d
)
{
const
int
d
)
{
...
@@ -120,23 +120,23 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
...
@@ -120,23 +120,23 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
input.scalar_type(), "act_and_mul_kernel", [&] { \
input.scalar_type(), "act_and_mul_kernel", [&] { \
if (0 == d % 8 && d <= 16384) { \
if (0 == d % 8 && d <= 16384) { \
if (d <= 512) { \
if (d <= 512) { \
vllm::act_and_mul_kernel_
vectorize
1<scalar_t, KERNEL<scalar_t>, 2> \
vllm::act_and_mul_kernel_
opt
1<scalar_t, KERNEL<scalar_t>, 2> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_
vectorize
1<scalar_t, KERNEL<scalar_t>, 8> \
vllm::act_and_mul_kernel_
opt
1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_
vectorize
1<scalar_t, KERNEL<scalar_t>, 8> \
vllm::act_and_mul_kernel_
opt
1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_
vectorize
1<scalar_t, KERNEL<scalar_t>, 8> \
vllm::act_and_mul_kernel_
opt
1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
input.data_ptr<scalar_t>(), d); \
} else { \
} else { \
vllm::act_and_mul_kernel_
vectorize
2<scalar_t, KERNEL<scalar_t>, 8> \
vllm::act_and_mul_kernel_
opt
2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
input.data_ptr<scalar_t>(), d); \
} \
} \
...
@@ -165,64 +165,3 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
...
@@ -165,64 +165,3 @@ void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
);
LAUNCH_ACTIVATION_GATE_KERNEL
(
vllm
::
gelu_tanh_kernel
);
}
}
namespace
vllm
{
// Element-wise activation kernel template.
template
<
typename
scalar_t
,
scalar_t
(
*
ACT_FN
)(
const
scalar_t
&
)>
__global__
void
activation_kernel
(
scalar_t
*
__restrict__
out
,
// [..., d]
const
scalar_t
*
__restrict__
input
,
// [..., d]
const
int
d
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
for
(
int64_t
idx
=
threadIdx
.
x
;
idx
<
d
;
idx
+=
blockDim
.
x
)
{
const
scalar_t
x
=
VLLM_LDG
(
&
input
[
token_idx
*
d
+
idx
]);
out
[
token_idx
*
d
+
idx
]
=
ACT_FN
(
x
);
}
}
}
// namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
namespace
vllm
{
template
<
typename
T
>
__device__
__forceinline__
T
gelu_new_kernel
(
const
T
&
x
)
{
const
float
x3
=
(
float
)(
x
*
x
*
x
);
const
T
t
=
(
T
)
tanhf
((
T
)(
0.79788456
f
*
(
float
)(
x
+
(
T
)(
0.044715
f
*
x3
))));
return
((
T
)
0.5
)
*
x
*
(((
T
)
1.0
)
+
t
);
}
template
<
typename
T
>
__device__
__forceinline__
T
gelu_fast_kernel
(
const
T
&
x
)
{
const
float
f
=
(
float
)
x
;
const
T
t
=
(
T
)
tanhf
(((
T
)(
f
*
0.79788456
f
))
*
(((
T
)
1.0
)
+
(
T
)(
0.044715
f
*
f
)
*
x
));
return
((
T
)
0.5
)
*
x
*
(((
T
)
1.0
)
+
t
);
}
}
// namespace vllm
void
gelu_new
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., d]
{
LAUNCH_ACTIVATION_KERNEL
(
vllm
::
gelu_new_kernel
);
}
void
gelu_fast
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
)
// [..., d]
{
LAUNCH_ACTIVATION_KERNEL
(
vllm
::
gelu_fast_kernel
);
}
csrc/opt/layernorm_kernels_opt.cu
View file @
bd93e661
...
@@ -323,7 +323,7 @@ __inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
...
@@ -323,7 +323,7 @@ __inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_add_rms_kernel_
eval
(
scalar_t
*
input
,
scalar_t
*
residual
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
__global__
void
fused_add_rms_kernel_
opt
(
scalar_t
*
input
,
scalar_t
*
residual
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
val_shared
[
share_size
];
...
@@ -363,7 +363,7 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
...
@@ -363,7 +363,7 @@ __global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,sca
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_rms_kernel_
eval
(
scalar_t
*
input
,
scalar_t
*
output
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
__global__
void
fused_rms_kernel_
opt
(
scalar_t
*
input
,
scalar_t
*
output
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
val_shared
[
share_size
];
...
@@ -422,24 +422,24 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
...
@@ -422,24 +422,24 @@ void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
scalar_t
*
out_data
=
out
.
data_ptr
<
scalar_t
>
();
scalar_t
*
out_data
=
out
.
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
if
(
hidden_size
<=
1024
){
if
(
hidden_size
<=
1024
){
fused_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
fused_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
2048
){
else
if
(
hidden_size
<=
2048
){
fused_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
fused_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
4096
){
else
if
(
hidden_size
<=
4096
){
if
(
num_tokens
>
1200
){
if
(
num_tokens
>
1200
){
fused_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
fused_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
{
else
{
fused_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
fused_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
}
}
}
else
if
(
hidden_size
<=
8192
){
else
if
(
hidden_size
<=
8192
){
fused_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
fused_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
{
else
{
fused_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
fused_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
}
});
});
}
}
...
@@ -492,24 +492,24 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
...
@@ -492,24 +492,24 @@ void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
scalar_t
*
other_data
=
residual
.
data_ptr
<
scalar_t
>
();
scalar_t
*
other_data
=
residual
.
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
if
(
hidden_size
<=
1024
){
if
(
hidden_size
<=
1024
){
fused_add_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
fused_add_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
2048
){
else
if
(
hidden_size
<=
2048
){
fused_add_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
fused_add_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
4096
){
else
if
(
hidden_size
<=
4096
){
if
(
num_tokens
>
1200
){
if
(
num_tokens
>
1200
){
fused_add_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
fused_add_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
{
else
{
fused_add_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
fused_add_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
}
}
else
if
(
hidden_size
<=
8192
){
else
if
(
hidden_size
<=
8192
){
fused_add_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
fused_add_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
{
else
{
fused_add_rms_kernel_
eval
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
fused_add_rms_kernel_
opt
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
});
});
}
}
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
bd93e661
...
@@ -340,7 +340,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -340,7 +340,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# prompt, and they have the same length.
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
:
if
self
.
use_triton_flash_attn
:
if
self
.
use_flash_attn_auto
:
if
self
.
use_flash_attn_auto
:
if
prefill_meta
.
max_prefill_seq_len
>=
4096
:
if
prefill_meta
.
max_prefill_seq_len
>=
8000
:
out
=
self
.
attn_func_triton
(
out
=
self
.
attn_func_triton
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
...
...
vllm/worker/model_runner.py
View file @
bd93e661
...
@@ -808,7 +808,10 @@ class ModelRunner:
...
@@ -808,7 +808,10 @@ class ModelRunner:
import
vllm.envs
as
envs
import
vllm.envs
as
envs
if
envs
.
VLLM_USE_FLASH_ATTN_AUTO
:
if
envs
.
VLLM_USE_FLASH_ATTN_AUTO
:
for
group_id
in
range
(
1
):
for
group_id
in
range
(
1
):
seq_len
=
8000
if
max_num_batched_tokens
>=
8000
:
seq_len
=
8000
else
:
seq_len
=
max_num_batched_tokens
if
vlm_config
is
None
:
if
vlm_config
is
None
:
seq_data
=
SequenceData
([
0
]
*
seq_len
)
seq_data
=
SequenceData
([
0
]
*
seq_len
)
dummy_multi_modal_data
=
None
dummy_multi_modal_data
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment