Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e053941
Commit
9e053941
authored
Mar 19, 2025
by
zhuwenwen
Browse files
skip fp8 kernel and _rocm_C extension
parent
f850f22a
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
992 additions
and
985 deletions
+992
-985
CMakeLists.txt
CMakeLists.txt
+6
-4
cmake/utils.cmake
cmake/utils.cmake
+1
-1
csrc/attention/attention_kernels.cuh
csrc/attention/attention_kernels.cuh
+655
-655
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+1
-1
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+237
-235
csrc/ops.h
csrc/ops.h
+15
-15
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+3
-1
csrc/quantization/fused_kernels/quant_conversions.cuh
csrc/quantization/fused_kernels/quant_conversions.cuh
+1
-1
csrc/quantization/gptq/compat.cuh
csrc/quantization/gptq/compat.cuh
+12
-12
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+30
-30
setup.py
setup.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+25
-25
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+4
-3
No files found.
CMakeLists.txt
View file @
9e053941
...
@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
...
@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
#
"csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
#
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
#
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/advance_step.cu"
...
@@ -613,6 +613,7 @@ define_gpu_extension_target(
...
@@ -613,6 +613,7 @@ define_gpu_extension_target(
USE_SABI 3
USE_SABI 3
WITH_SOABI
)
WITH_SOABI
)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_GPU_LANG STREQUAL "HIP")
#
#
# _rocm_C extension
# _rocm_C extension
...
@@ -631,6 +632,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
...
@@ -631,6 +632,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3
USE_SABI 3
WITH_SOABI)
WITH_SOABI)
endif()
endif()
]]
# For CUDA we also build and ship some external projects.
# For CUDA we also build and ship some external projects.
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
...
...
cmake/utils.cmake
View file @
9e053941
...
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
...
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list
(
APPEND GPU_FLAGS
list
(
APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DUSE_ROCM"
"-DENABLE_FP8"
#
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc"
)
"-fno-gpu-rdc"
)
...
...
csrc/attention/attention_kernels.cuh
View file @
9e053941
...
@@ -17,43 +17,43 @@
...
@@ -17,43 +17,43 @@
* limitations under the License.
* limitations under the License.
*/
*/
#include <torch/all.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef
__hip_bfloat16
__nv_bfloat16
;
typedef
__hip_bfloat16
__nv_bfloat16
;
#else
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#endif
#ifndef USE_ROCM
#ifndef USE_ROCM
#define WARP_SIZE 32
#define WARP_SIZE 32
#else
#else
#define WARP_SIZE warpSize
#define WARP_SIZE warpSize
#endif
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace
vllm
{
namespace
vllm
{
// Utility function for attention softmax.
// Utility function for attention softmax.
template
<
int
NUM_WARPS
>
template
<
int
NUM_WARPS
>
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
// Decompose the thread index into warp / lane.
// Decompose the thread index into warp / lane.
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Compute the sum per warp.
// Compute the sum per warp.
#pragma unroll
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
}
}
...
@@ -72,22 +72,22 @@ inline __device__ float block_sum(float* red_smem, float sum) {
...
@@ -72,22 +72,22 @@ inline __device__ float block_sum(float* red_smem, float sum) {
}
}
// Parallel reduction inside the warp.
// Parallel reduction inside the warp.
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
}
}
// Broadcast to other threads.
// Broadcast to other threads.
return
VLLM_SHFL_SYNC
(
sum
,
0
);
return
VLLM_SHFL_SYNC
(
sum
,
0
);
}
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -178,7 +178,7 @@ __device__ void paged_attention_kernel(
...
@@ -178,7 +178,7 @@ __device__ void paged_attention_kernel(
// q is split from a qkv tensor, it may not be contiguous.
// q is split from a qkv tensor, it may not be contiguous.
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
__shared__
Q_vec
q_vecs
[
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
__shared__
Q_vec
q_vecs
[
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
i
+=
NUM_THREAD_GROUPS
)
{
i
+=
NUM_THREAD_GROUPS
)
{
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
...
@@ -268,7 +268,7 @@ __device__ void paged_attention_kernel(
...
@@ -268,7 +268,7 @@ __device__ void paged_attention_kernel(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
const
cache_t
*
k_ptr
=
const
cache_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
k_cache
+
physical_block_number
*
kv_block_stride
+
...
@@ -310,7 +310,7 @@ __device__ void paged_attention_kernel(
...
@@ -310,7 +310,7 @@ __device__ void paged_attention_kernel(
// Perform reduction across the threads in the same warp to get the
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
}
...
@@ -322,7 +322,7 @@ __device__ void paged_attention_kernel(
...
@@ -322,7 +322,7 @@ __device__ void paged_attention_kernel(
// TODO(woosuk): Refactor this part.
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
// Get the max qk value for the sequence.
qk_max
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
qk_max
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
}
...
@@ -370,7 +370,7 @@ __device__ void paged_attention_kernel(
...
@@ -370,7 +370,7 @@ __device__ void paged_attention_kernel(
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float
accs
[
NUM_ROWS_PER_THREAD
];
float
accs
[
NUM_ROWS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
accs
[
i
]
=
0.
f
;
accs
[
i
]
=
0.
f
;
}
}
...
@@ -401,7 +401,7 @@ __device__ void paged_attention_kernel(
...
@@ -401,7 +401,7 @@ __device__ void paged_attention_kernel(
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
kv_head_idx
*
kv_head_stride
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
)
{
if
(
row_idx
<
HEAD_SIZE
)
{
...
@@ -423,7 +423,7 @@ __device__ void paged_attention_kernel(
...
@@ -423,7 +423,7 @@ __device__ void paged_attention_kernel(
// contain NaNs. See
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
...
@@ -434,10 +434,10 @@ __device__ void paged_attention_kernel(
...
@@ -434,10 +434,10 @@ __device__ void paged_attention_kernel(
}
}
// Perform reduction within each warp.
// Perform reduction within each warp.
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
float
acc
=
accs
[
i
];
float
acc
=
accs
[
i
];
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
}
}
...
@@ -450,13 +450,13 @@ __device__ void paged_attention_kernel(
...
@@ -450,13 +450,13 @@ __device__ void paged_attention_kernel(
// Perform reduction across warps.
// Perform reduction across warps.
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
#pragma unroll
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
warp_idx
-
mid
)
*
HEAD_SIZE
];
float
*
dst
=
&
out_smem
[(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
...
@@ -469,7 +469,7 @@ __device__ void paged_attention_kernel(
...
@@ -469,7 +469,7 @@ __device__ void paged_attention_kernel(
// Lower warps update the output.
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[
warp_idx
*
HEAD_SIZE
];
const
float
*
src
=
&
out_smem
[
warp_idx
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
...
@@ -485,7 +485,7 @@ __device__ void paged_attention_kernel(
...
@@ -485,7 +485,7 @@ __device__ void paged_attention_kernel(
scalar_t
*
out_ptr
=
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
...
@@ -493,13 +493,13 @@ __device__ void paged_attention_kernel(
...
@@ -493,13 +493,13 @@ __device__ void paged_attention_kernel(
}
}
}
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
// Grid: (num_heads, num_seqs, 1).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
bool
IS_BLOCK_SPARSE
>
__global__
void
paged_attention_v1_kernel
(
__global__
void
paged_attention_v1_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
...
@@ -524,14 +524,14 @@ __global__ void paged_attention_v1_kernel(
...
@@ -524,14 +524,14 @@ __global__ void paged_attention_v1_kernel(
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs, max_num_partitions).
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
>
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_kernel
(
__global__
void
paged_attention_v2_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -559,12 +559,12 @@ __global__ void paged_attention_v2_kernel(
...
@@ -559,12 +559,12 @@ __global__ void paged_attention_v2_kernel(
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs).
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_reduce_kernel
(
__global__
void
paged_attention_v2_reduce_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -617,7 +617,7 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -617,7 +617,7 @@ __global__ void paged_attention_v2_reduce_kernel(
// Get the global max logit.
// Get the global max logit.
// Reduce within the warp.
// Reduce within the warp.
#pragma unroll
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
}
...
@@ -627,7 +627,7 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -627,7 +627,7 @@ __global__ void paged_attention_v2_reduce_kernel(
__syncthreads
();
__syncthreads
();
// Reduce across warps.
// Reduce across warps.
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
}
...
@@ -657,7 +657,7 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -657,7 +657,7 @@ __global__ void paged_attention_v2_reduce_kernel(
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
scalar_t
*
out_ptr
=
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
float
acc
=
0.0
f
;
float
acc
=
0.0
f
;
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
...
@@ -666,11 +666,11 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -666,11 +666,11 @@ __global__ void paged_attention_v2_reduce_kernel(
}
}
from_float
(
out_ptr
[
i
],
acc
);
from_float
(
out_ptr
[
i
],
acc
);
}
}
}
}
}
// namespace vllm
}
// namespace vllm
#undef WARP_SIZE
#undef WARP_SIZE
#undef MAX
#undef MAX
#undef MIN
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
\ No newline at end of file
csrc/cache_kernels.cu
View file @
9e053941
csrc/layernorm_quant_kernels.cu
View file @
9e053941
...
@@ -5,24 +5,26 @@
...
@@ -5,24 +5,26 @@
* Currently, only static fp8 quantization is supported.
* Currently, only static fp8 quantization is supported.
*/
*/
#include "type_convert.cuh"
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#ifndef USE_ROCM
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#endif
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/cub.cuh>
#else
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/hipcub.hpp>
#endif
#endif
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
,
typename
fp8_type
>
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
__global__
void
rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
@@ -54,15 +56,15 @@ __global__ void rms_norm_static_fp8_quant_kernel(
...
@@ -54,15 +56,15 @@ __global__ void rms_norm_static_fp8_quant_kernel(
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
}
/* Function specialization in the case of FP16/BF16 tensors.
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
packed and vectorized operations, which help with the
memory latency bottleneck. */
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fused_add_rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
...
@@ -111,20 +113,20 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -111,20 +113,20 @@ fused_add_rms_norm_static_fp8_quant_kernel(
_f16Vec
<
scalar_t
,
width
>
temp
=
residual_v
[
id
];
_f16Vec
<
scalar_t
,
width
>
temp
=
residual_v
[
id
];
temp
*=
s_variance
;
temp
*=
s_variance
;
temp
*=
weight_v
[
idx
];
temp
*=
weight_v
[
idx
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
out
[
id
*
width
+
i
]
=
out
[
id
*
width
+
i
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
}
}
}
}
}
}
/* Generic fused_add_rms_norm_kernel
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
The width field is not used here but necessary for other specializations.
*/
*/
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fused_add_rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
...
@@ -160,11 +162,11 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -160,11 +162,11 @@ fused_add_rms_norm_static_fp8_quant_kernel(
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
}
}
// namespace vllm
}
// namespace vllm
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
torch
::
Tensor
&
scale
,
// [1]
torch
::
Tensor
&
scale
,
// [1]
...
@@ -187,9 +189,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
...
@@ -187,9 +189,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
epsilon
,
num_tokens
,
hidden_size
);
epsilon
,
num_tokens
,
hidden_size
);
});
});
});
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
VLLM_DISPATCH_FP8_TYPES( \
...
@@ -203,7 +205,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
...
@@ -203,7 +205,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
epsilon, num_tokens, hidden_size); \
epsilon, num_tokens, hidden_size); \
}); \
}); \
});
});
void
fused_add_rms_norm_static_fp8_quant
(
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
...
@@ -239,4 +241,4 @@ void fused_add_rms_norm_static_fp8_quant(
...
@@ -239,4 +241,4 @@ void fused_add_rms_norm_static_fp8_quant(
}
else
{
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
}
}
\ No newline at end of file
csrc/ops.h
View file @
9e053941
...
@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
...
@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
torch
::
Tensor
&
weight
,
double
epsilon
);
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
//
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
scale
,
//
torch::Tensor& weight, torch::Tensor& scale,
double
epsilon
);
//
double epsilon);
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
//
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch
::
Tensor
&
input
,
//
torch::Tensor& input,
torch
::
Tensor
&
residual
,
//
torch::Tensor& residual,
torch
::
Tensor
&
weight
,
//
torch::Tensor& weight,
torch
::
Tensor
&
scale
,
double
epsilon
);
//
torch::Tensor& scale, double epsilon);
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input
,
...
@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
...
@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
//
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch
::
Tensor
const
&
scale
);
//
torch::Tensor const& scale);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
//
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch
::
Tensor
&
scale
);
//
torch::Tensor& scale);
void
dynamic_per_token_scaled_fp8_quant
(
//
void dynamic_per_token_scaled_fp8_quant(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
,
//
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std
::
optional
<
torch
::
Tensor
>
const
&
scale_ub
);
//
std::optional<torch::Tensor> const& scale_ub);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
9e053941
#pragma once
#pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h>
#include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
...
...
csrc/quantization/fused_kernels/quant_conversions.cuh
View file @
9e053941
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
// TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"
//
#include "quantization/fp8/common.cuh"
namespace
vllm
{
namespace
vllm
{
...
...
csrc/quantization/gptq/compat.cuh
View file @
9e053941
...
@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
...
@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
//
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
//
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half
*
address
,
half
val
)
{
//
__device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half
(
address
,
val
);
//
atomicAdd_half(address, val);
}
//
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
//
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half2
*
address
,
half2
val
)
{
//
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2
(
address
,
val
);
//
atomicAdd_half2(address, val);
}
//
}
#endif
//
#endif
#endif
//
#endif
#endif
//
#endif
}
// namespace gptq
}
// namespace gptq
}
// namespace vllm
}
// namespace vllm
...
...
csrc/torch_bindings.cpp
View file @
9e053941
...
@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
//
ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
//
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> "
//
"Tensor scale, float epsilon) -> "
"()"
);
//
"()");
ops
.
impl
(
"rms_norm_static_fp8_quant"
,
torch
::
kCUDA
,
//
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&
rms_norm_static_fp8_quant
);
//
&rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization.
// In-place fused Add and RMS Normalization.
ops
.
def
(
//
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
//
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, "
//
"Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()"
);
//
"Tensor scale, float epsilon) -> ()");
ops
.
impl
(
"fused_add_rms_norm_static_fp8_quant"
,
torch
::
kCUDA
,
//
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&
fused_add_rms_norm_static_fp8_quant
);
//
&fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels
// Fused Layernorm + Quant kernels
ops
.
def
(
ops
.
def
(
...
@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// Compute FP8 quantized tensor for given scaling factor.
// Compute FP8 quantized tensor for given scaling factor.
ops
.
def
(
//
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
//
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()"
);
//
"()");
ops
.
impl
(
"static_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
static_scaled_fp8_quant
);
//
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
//
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops
.
def
(
//
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
//
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
//
"-> "
"()"
);
//
"()");
ops
.
impl
(
"dynamic_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_fp8_quant
);
//
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops
.
def
(
//
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
//
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
//
"Tensor! scale, Tensor? scale_ub) -> "
"()"
);
//
"()");
ops
.
impl
(
"dynamic_per_token_scaled_fp8_quant"
,
torch
::
kCUDA
,
//
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&
dynamic_per_token_scaled_fp8_quant
);
//
&dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
ops
.
def
(
...
...
setup.py
View file @
9e053941
...
@@ -643,8 +643,8 @@ ext_modules = []
...
@@ -643,8 +643,8 @@ ext_modules = []
if
_is_cuda
()
or
_is_hip
():
if
_is_cuda
()
or
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
if
_is_hip
():
#
if _is_hip():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._rocm_C"
))
#
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if
_is_cuda
():
if
_is_cuda
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa2_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa2_C"
))
...
...
vllm/_custom_ops.py
View file @
9e053941
...
@@ -98,30 +98,30 @@ def paged_attention_v2(
...
@@ -98,30 +98,30 @@ def paged_attention_v2(
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_rocm
(
#
def paged_attention_rocm(
out
:
torch
.
Tensor
,
#
out: torch.Tensor,
exp_sum
:
torch
.
Tensor
,
#
exp_sum: torch.Tensor,
max_logits
:
torch
.
Tensor
,
#
max_logits: torch.Tensor,
tmp_out
:
torch
.
Tensor
,
#
tmp_out: torch.Tensor,
query
:
torch
.
Tensor
,
#
query: torch.Tensor,
key_cache
:
torch
.
Tensor
,
#
key_cache: torch.Tensor,
value_cache
:
torch
.
Tensor
,
#
value_cache: torch.Tensor,
num_kv_heads
:
int
,
#
num_kv_heads: int,
scale
:
float
,
#
scale: float,
block_tables
:
torch
.
Tensor
,
#
block_tables: torch.Tensor,
seq_lens
:
torch
.
Tensor
,
#
seq_lens: torch.Tensor,
block_size
:
int
,
#
block_size: int,
max_seq_len
:
int
,
#
max_seq_len: int,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
#
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype
:
str
,
#
kv_cache_dtype: str,
k_scale
:
torch
.
Tensor
,
#
k_scale: torch.Tensor,
v_scale
:
torch
.
Tensor
,
#
v_scale: torch.Tensor,
)
->
None
:
#
) -> None:
torch
.
ops
.
_rocm_C
.
paged_attention
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
#
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache
,
value_cache
,
num_kv_heads
,
#
key_cache, value_cache, num_kv_heads,
scale
,
block_tables
,
seq_lens
,
#
scale, block_tables, seq_lens,
block_size
,
max_seq_len
,
alibi_slopes
,
#
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype
,
k_scale
,
v_scale
)
#
kv_cache_dtype, k_scale, v_scale)
# pos encoding ops
# pos encoding ops
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
9e053941
...
@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
block_size
=
value_cache
.
shape
[
3
]
block_size
=
value_cache
.
shape
[
3
]
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
use_custom
=
_use_rocm_custom_paged_attention
(
# use_custom = _use_rocm_custom_paged_attention(
decode_query
.
dtype
,
head_size
,
block_size
,
gqa_ratio
,
# decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta
.
max_decode_seq_len
)
# decode_meta.max_decode_seq_len)
use_custom
=
False
if
use_custom
:
if
use_custom
:
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
!=
AttentionType
.
ENCODER_DECODER
else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment