Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e053941
Commit
9e053941
authored
Mar 19, 2025
by
zhuwenwen
Browse files
skip fp8 kernel and _rocm_C extension
parent
f850f22a
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
992 additions
and
985 deletions
+992
-985
CMakeLists.txt
CMakeLists.txt
+6
-4
cmake/utils.cmake
cmake/utils.cmake
+1
-1
csrc/attention/attention_kernels.cuh
csrc/attention/attention_kernels.cuh
+655
-655
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+1
-1
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+237
-235
csrc/ops.h
csrc/ops.h
+15
-15
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+3
-1
csrc/quantization/fused_kernels/quant_conversions.cuh
csrc/quantization/fused_kernels/quant_conversions.cuh
+1
-1
csrc/quantization/gptq/compat.cuh
csrc/quantization/gptq/compat.cuh
+12
-12
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+30
-30
setup.py
setup.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+25
-25
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+4
-3
No files found.
CMakeLists.txt
View file @
9e053941
...
...
@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
#
"csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
#
"csrc/quantization/fp8/common.cu"
#
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
...
...
@@ -613,6 +613,7 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI
)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP")
#
# _rocm_C extension
...
...
@@ -631,6 +632,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3
WITH_SOABI)
endif()
]]
# For CUDA we also build and ship some external projects.
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
...
...
cmake/utils.cmake
View file @
9e053941
...
...
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list
(
APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8"
#
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc"
)
...
...
csrc/attention/attention_kernels.cuh
View file @
9e053941
...
...
@@ -17,43 +17,43 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef
__hip_bfloat16
__nv_bfloat16
;
#else
typedef
__hip_bfloat16
__nv_bfloat16
;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#endif
#ifndef USE_ROCM
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#else
#define WARP_SIZE warpSize
#endif
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace
vllm
{
namespace
vllm
{
// Utility function for attention softmax.
template
<
int
NUM_WARPS
>
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
// Utility function for attention softmax.
template
<
int
NUM_WARPS
>
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
// Decompose the thread index into warp / lane.
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Compute the sum per warp.
#pragma unroll
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
}
...
...
@@ -72,22 +72,22 @@ inline __device__ float block_sum(float* red_smem, float sum) {
}
// Parallel reduction inside the warp.
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
}
// Broadcast to other threads.
return
VLLM_SHFL_SYNC
(
sum
,
0
);
}
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -178,7 +178,7 @@ __device__ void paged_attention_kernel(
// q is split from a qkv tensor, it may not be contiguous.
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
__shared__
Q_vec
q_vecs
[
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
i
+=
NUM_THREAD_GROUPS
)
{
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
...
...
@@ -268,7 +268,7 @@ __device__ void paged_attention_kernel(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
const
cache_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
...
...
@@ -310,7 +310,7 @@ __device__ void paged_attention_kernel(
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
...
...
@@ -322,7 +322,7 @@ __device__ void paged_attention_kernel(
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
...
...
@@ -370,7 +370,7 @@ __device__ void paged_attention_kernel(
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float
accs
[
NUM_ROWS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
accs
[
i
]
=
0.
f
;
}
...
...
@@ -401,7 +401,7 @@ __device__ void paged_attention_kernel(
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
)
{
...
...
@@ -423,7 +423,7 @@ __device__ void paged_attention_kernel(
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
...
...
@@ -434,10 +434,10 @@ __device__ void paged_attention_kernel(
}
// Perform reduction within each warp.
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
float
acc
=
accs
[
i
];
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
}
...
...
@@ -450,13 +450,13 @@ __device__ void paged_attention_kernel(
// Perform reduction across warps.
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
#pragma unroll
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
...
...
@@ -469,7 +469,7 @@ __device__ void paged_attention_kernel(
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[
warp_idx
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
...
...
@@ -485,7 +485,7 @@ __device__ void paged_attention_kernel(
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
...
...
@@ -493,13 +493,13 @@ __device__ void paged_attention_kernel(
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
// Grid: (num_heads, num_seqs, 1).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
__global__
void
paged_attention_v1_kernel
(
__global__
void
paged_attention_v1_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
...
...
@@ -524,14 +524,14 @@ __global__ void paged_attention_v1_kernel(
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_kernel
(
__global__
void
paged_attention_v2_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -559,12 +559,12 @@ __global__ void paged_attention_v2_kernel(
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_reduce_kernel
(
__global__
void
paged_attention_v2_reduce_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -617,7 +617,7 @@ __global__ void paged_attention_v2_reduce_kernel(
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
...
...
@@ -627,7 +627,7 @@ __global__ void paged_attention_v2_reduce_kernel(
__syncthreads
();
// Reduce across warps.
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
...
...
@@ -657,7 +657,7 @@ __global__ void paged_attention_v2_reduce_kernel(
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
float
acc
=
0.0
f
;
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
...
...
@@ -666,11 +666,11 @@ __global__ void paged_attention_v2_reduce_kernel(
}
from_float
(
out_ptr
[
i
],
acc
);
}
}
}
}
// namespace vllm
}
// namespace vllm
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
csrc/cache_kernels.cu
View file @
9e053941
csrc/layernorm_quant_kernels.cu
View file @
9e053941
...
...
@@ -5,24 +5,26 @@
* Currently, only static fp8 quantization is supported.
*/
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef USE_ROCM
#include "quantization/fp8/common.cuh"
#endif
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#else
#include <hipcub/hipcub.hpp>
#endif
#endif
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
...
...
@@ -54,15 +56,15 @@ __global__ void rms_norm_static_fp8_quant_kernel(
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
/* Function specialization in the case of FP16/BF16 tensors.
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
...
...
@@ -111,20 +113,20 @@ fused_add_rms_norm_static_fp8_quant_kernel(
_f16Vec
<
scalar_t
,
width
>
temp
=
residual_v
[
id
];
temp
*=
s_variance
;
temp
*=
weight_v
[
idx
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
out
[
id
*
width
+
i
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
}
}
}
}
/* Generic fused_add_rms_norm_kernel
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
...
...
@@ -160,11 +162,11 @@ fused_add_rms_norm_static_fp8_quant_kernel(
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
}
// namespace vllm
}
// namespace vllm
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
torch
::
Tensor
&
scale
,
// [1]
...
...
@@ -187,9 +189,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
epsilon
,
num_tokens
,
hidden_size
);
});
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
...
...
@@ -203,7 +205,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
epsilon, num_tokens, hidden_size); \
}); \
});
void
fused_add_rms_norm_static_fp8_quant
(
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
...
...
@@ -239,4 +241,4 @@ void fused_add_rms_norm_static_fp8_quant(
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
}
\ No newline at end of file
csrc/ops.h
View file @
9e053941
...
...
@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
scale
,
double
epsilon
);
//
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
//
torch::Tensor& weight, torch::Tensor& scale,
//
double epsilon);
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
scale
,
double
epsilon
);
//
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
//
torch::Tensor& input,
//
torch::Tensor& residual,
//
torch::Tensor& weight,
//
torch::Tensor& scale, double epsilon);
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
);
//
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//
torch::Tensor const& scale);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
);
//
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//
torch::Tensor& scale);
void
dynamic_per_token_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
,
std
::
optional
<
torch
::
Tensor
>
const
&
scale_ub
);
//
void dynamic_per_token_scaled_fp8_quant(
//
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
//
std::optional<torch::Tensor> const& scale_ub);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
9e053941
#pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
...
...
csrc/quantization/fused_kernels/quant_conversions.cuh
View file @
9e053941
...
...
@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"
//
#include "quantization/fp8/common.cuh"
namespace
vllm
{
...
...
csrc/quantization/gptq/compat.cuh
View file @
9e053941
...
...
@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
//
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half
*
address
,
half
val
)
{
atomicAdd_half
(
address
,
val
);
}
//
__device__ __forceinline__ void atomicAdd(half* address, half val) {
//
atomicAdd_half(address, val);
//
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half2
*
address
,
half2
val
)
{
atomicAdd_half2
(
address
,
val
);
}
#endif
//
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
//
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
//
atomicAdd_half2(address, val);
//
}
//
#endif
#endif
#endif
//
#endif
//
#endif
}
// namespace gptq
}
// namespace vllm
...
...
csrc/torch_bindings.cpp
View file @
9e053941
...
...
@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> "
"()"
);
ops
.
impl
(
"rms_norm_static_fp8_quant"
,
torch
::
kCUDA
,
&
rms_norm_static_fp8_quant
);
//
ops.def(
//
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
//
"Tensor scale, float epsilon) -> "
//
"()");
//
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
//
&rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization.
ops
.
def
(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm_static_fp8_quant"
,
torch
::
kCUDA
,
&
fused_add_rms_norm_static_fp8_quant
);
//
ops.def(
//
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
//
"Tensor! residual, Tensor weight, "
//
"Tensor scale, float epsilon) -> ()");
//
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
//
&fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels
ops
.
def
(
...
...
@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// Compute FP8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()"
);
ops
.
impl
(
"static_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
static_scaled_fp8_quant
);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops
.
def
(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()"
);
ops
.
impl
(
"dynamic_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_fp8_quant
);
//
ops.def(
//
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
//
"()");
//
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
//
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
//
ops.def(
//
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
//
"-> "
//
"()");
//
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops
.
def
(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()"
);
ops
.
impl
(
"dynamic_per_token_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_per_token_scaled_fp8_quant
);
//
ops.def(
//
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
//
"Tensor! scale, Tensor? scale_ub) -> "
//
"()");
//
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
//
&dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
...
...
setup.py
View file @
9e053941
...
...
@@ -643,8 +643,8 @@ ext_modules = []
if
_is_cuda
()
or
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
if
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._rocm_C"
))
#
if _is_hip():
#
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if
_is_cuda
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa2_C"
))
...
...
vllm/_custom_ops.py
View file @
9e053941
...
...
@@ -98,30 +98,30 @@ def paged_attention_v2(
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_rocm
(
out
:
torch
.
Tensor
,
exp_sum
:
torch
.
Tensor
,
max_logits
:
torch
.
Tensor
,
tmp_out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
_rocm_C
.
paged_attention
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
)
#
def paged_attention_rocm(
#
out: torch.Tensor,
#
exp_sum: torch.Tensor,
#
max_logits: torch.Tensor,
#
tmp_out: torch.Tensor,
#
query: torch.Tensor,
#
key_cache: torch.Tensor,
#
value_cache: torch.Tensor,
#
num_kv_heads: int,
#
scale: float,
#
block_tables: torch.Tensor,
#
seq_lens: torch.Tensor,
#
block_size: int,
#
max_seq_len: int,
#
alibi_slopes: Optional[torch.Tensor],
#
kv_cache_dtype: str,
#
k_scale: torch.Tensor,
#
v_scale: torch.Tensor,
#
) -> None:
#
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
#
key_cache, value_cache, num_kv_heads,
#
scale, block_tables, seq_lens,
#
block_size, max_seq_len, alibi_slopes,
#
kv_cache_dtype, k_scale, v_scale)
# pos encoding ops
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
9e053941
...
...
@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
block_size
=
value_cache
.
shape
[
3
]
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
use_custom
=
_use_rocm_custom_paged_attention
(
decode_query
.
dtype
,
head_size
,
block_size
,
gqa_ratio
,
decode_meta
.
max_decode_seq_len
)
# use_custom = _use_rocm_custom_paged_attention(
# decode_query.dtype, head_size, block_size, gqa_ratio,
# decode_meta.max_decode_seq_len)
use_custom
=
False
if
use_custom
:
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment