Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e053941
Commit
9e053941
authored
Mar 19, 2025
by
zhuwenwen
Browse files
skip fp8 kernel and _rocm_C extension
parent
f850f22a
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
992 additions
and
985 deletions
+992
-985
CMakeLists.txt
CMakeLists.txt
+6
-4
cmake/utils.cmake
cmake/utils.cmake
+1
-1
csrc/attention/attention_kernels.cuh
csrc/attention/attention_kernels.cuh
+655
-655
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+1
-1
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+237
-235
csrc/ops.h
csrc/ops.h
+15
-15
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+3
-1
csrc/quantization/fused_kernels/quant_conversions.cuh
csrc/quantization/fused_kernels/quant_conversions.cuh
+1
-1
csrc/quantization/gptq/compat.cuh
csrc/quantization/gptq/compat.cuh
+12
-12
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+30
-30
setup.py
setup.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+25
-25
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+4
-3
No files found.
CMakeLists.txt
View file @
9e053941
...
@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
...
@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
#
"csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
#
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
#
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/advance_step.cu"
...
@@ -613,6 +613,7 @@ define_gpu_extension_target(
...
@@ -613,6 +613,7 @@ define_gpu_extension_target(
USE_SABI 3
USE_SABI 3
WITH_SOABI
)
WITH_SOABI
)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_GPU_LANG STREQUAL "HIP")
#
#
# _rocm_C extension
# _rocm_C extension
...
@@ -631,9 +632,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
...
@@ -631,9 +632,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3
USE_SABI 3
WITH_SOABI)
WITH_SOABI)
endif()
endif()
]]
# For CUDA we also build and ship some external projects.
# For CUDA we also build and ship some external projects.
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
include
(
cmake/external_projects/flashmla.cmake
)
include
(
cmake/external_projects/flashmla.cmake
)
include
(
cmake/external_projects/vllm_flash_attn.cmake
)
include
(
cmake/external_projects/vllm_flash_attn.cmake
)
endif
()
endif
()
\ No newline at end of file
cmake/utils.cmake
View file @
9e053941
...
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
...
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list
(
APPEND GPU_FLAGS
list
(
APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DUSE_ROCM"
"-DENABLE_FP8"
#
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc"
)
"-fno-gpu-rdc"
)
...
...
csrc/attention/attention_kernels.cuh
View file @
9e053941
This diff is collapsed.
Click to expand it.
csrc/cache_kernels.cu
View file @
9e053941
...
@@ -728,4 +728,4 @@ void gather_cache(
...
@@ -728,4 +728,4 @@ void gather_cache(
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type width: "
,
dtype_bits
);
TORCH_CHECK
(
false
,
"Unsupported data type width: "
,
dtype_bits
);
}
}
}
}
\ No newline at end of file
csrc/layernorm_quant_kernels.cu
View file @
9e053941
This diff is collapsed.
Click to expand it.
csrc/ops.h
View file @
9e053941
...
@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
...
@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
torch
::
Tensor
&
weight
,
double
epsilon
);
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
//
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
scale
,
//
torch::Tensor& weight, torch::Tensor& scale,
double
epsilon
);
//
double epsilon);
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
//
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch
::
Tensor
&
input
,
//
torch::Tensor& input,
torch
::
Tensor
&
residual
,
//
torch::Tensor& residual,
torch
::
Tensor
&
weight
,
//
torch::Tensor& weight,
torch
::
Tensor
&
scale
,
double
epsilon
);
//
torch::Tensor& scale, double epsilon);
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input
,
...
@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
...
@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
//
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch
::
Tensor
const
&
scale
);
//
torch::Tensor const& scale);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
//
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch
::
Tensor
&
scale
);
//
torch::Tensor& scale);
void
dynamic_per_token_scaled_fp8_quant
(
//
void dynamic_per_token_scaled_fp8_quant(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
,
//
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std
::
optional
<
torch
::
Tensor
>
const
&
scale_ub
);
//
std::optional<torch::Tensor> const& scale_ub);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
9e053941
#pragma once
#pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h>
#include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
...
@@ -670,4 +672,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -670,4 +672,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
// namespace fp8
}
// namespace fp8
#endif // USE_ROCM
#endif // USE_ROCM
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/quantization/fused_kernels/quant_conversions.cuh
View file @
9e053941
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
// TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"
//
#include "quantization/fp8/common.cuh"
namespace
vllm
{
namespace
vllm
{
...
...
csrc/quantization/gptq/compat.cuh
View file @
9e053941
...
@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
...
@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
//
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
//
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half
*
address
,
half
val
)
{
//
__device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half
(
address
,
val
);
//
atomicAdd_half(address, val);
}
//
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
//
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half2
*
address
,
half2
val
)
{
//
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2
(
address
,
val
);
//
atomicAdd_half2(address, val);
}
//
}
#endif
//
#endif
#endif
//
#endif
#endif
//
#endif
}
// namespace gptq
}
// namespace gptq
}
// namespace vllm
}
// namespace vllm
...
...
csrc/torch_bindings.cpp
View file @
9e053941
...
@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
//
ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
//
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> "
//
"Tensor scale, float epsilon) -> "
"()"
);
//
"()");
ops
.
impl
(
"rms_norm_static_fp8_quant"
,
torch
::
kCUDA
,
//
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&
rms_norm_static_fp8_quant
);
//
&rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization.
// In-place fused Add and RMS Normalization.
ops
.
def
(
//
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
//
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, "
//
"Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()"
);
//
"Tensor scale, float epsilon) -> ()");
ops
.
impl
(
"fused_add_rms_norm_static_fp8_quant"
,
torch
::
kCUDA
,
//
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&
fused_add_rms_norm_static_fp8_quant
);
//
&fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels
// Fused Layernorm + Quant kernels
ops
.
def
(
ops
.
def
(
...
@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// Compute FP8 quantized tensor for given scaling factor.
// Compute FP8 quantized tensor for given scaling factor.
ops
.
def
(
//
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
//
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()"
);
//
"()");
ops
.
impl
(
"static_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
static_scaled_fp8_quant
);
//
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
//
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops
.
def
(
//
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
//
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
//
"-> "
"()"
);
//
"()");
ops
.
impl
(
"dynamic_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_fp8_quant
);
//
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops
.
def
(
//
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
//
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
//
"Tensor! scale, Tensor? scale_ub) -> "
"()"
);
//
"()");
ops
.
impl
(
"dynamic_per_token_scaled_fp8_quant"
,
torch
::
kCUDA
,
//
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&
dynamic_per_token_scaled_fp8_quant
);
//
&dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
ops
.
def
(
...
@@ -602,4 +602,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
...
@@ -602,4 +602,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
}
}
#endif
#endif
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
\ No newline at end of file
setup.py
View file @
9e053941
...
@@ -643,8 +643,8 @@ ext_modules = []
...
@@ -643,8 +643,8 @@ ext_modules = []
if
_is_cuda
()
or
_is_hip
():
if
_is_cuda
()
or
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
if
_is_hip
():
#
if _is_hip():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._rocm_C"
))
#
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if
_is_cuda
():
if
_is_cuda
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa2_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa2_C"
))
...
...
vllm/_custom_ops.py
View file @
9e053941
...
@@ -98,30 +98,30 @@ def paged_attention_v2(
...
@@ -98,30 +98,30 @@ def paged_attention_v2(
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_rocm
(
#
def paged_attention_rocm(
out
:
torch
.
Tensor
,
#
out: torch.Tensor,
exp_sum
:
torch
.
Tensor
,
#
exp_sum: torch.Tensor,
max_logits
:
torch
.
Tensor
,
#
max_logits: torch.Tensor,
tmp_out
:
torch
.
Tensor
,
#
tmp_out: torch.Tensor,
query
:
torch
.
Tensor
,
#
query: torch.Tensor,
key_cache
:
torch
.
Tensor
,
#
key_cache: torch.Tensor,
value_cache
:
torch
.
Tensor
,
#
value_cache: torch.Tensor,
num_kv_heads
:
int
,
#
num_kv_heads: int,
scale
:
float
,
#
scale: float,
block_tables
:
torch
.
Tensor
,
#
block_tables: torch.Tensor,
seq_lens
:
torch
.
Tensor
,
#
seq_lens: torch.Tensor,
block_size
:
int
,
#
block_size: int,
max_seq_len
:
int
,
#
max_seq_len: int,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
#
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype
:
str
,
#
kv_cache_dtype: str,
k_scale
:
torch
.
Tensor
,
#
k_scale: torch.Tensor,
v_scale
:
torch
.
Tensor
,
#
v_scale: torch.Tensor,
)
->
None
:
#
) -> None:
torch
.
ops
.
_rocm_C
.
paged_attention
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
#
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache
,
value_cache
,
num_kv_heads
,
#
key_cache, value_cache, num_kv_heads,
scale
,
block_tables
,
seq_lens
,
#
scale, block_tables, seq_lens,
block_size
,
max_seq_len
,
alibi_slopes
,
#
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype
,
k_scale
,
v_scale
)
#
kv_cache_dtype, k_scale, v_scale)
# pos encoding ops
# pos encoding ops
...
@@ -1365,4 +1365,4 @@ def flash_mla_with_kvcache(
...
@@ -1365,4 +1365,4 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata
,
tile_scheduler_metadata
,
num_splits
,
num_splits
,
)
)
return
out
,
softmax_lse
return
out
,
softmax_lse
\ No newline at end of file
vllm/attention/backends/rocm_flash_attn.py
View file @
9e053941
...
@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
block_size
=
value_cache
.
shape
[
3
]
block_size
=
value_cache
.
shape
[
3
]
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
use_custom
=
_use_rocm_custom_paged_attention
(
# use_custom = _use_rocm_custom_paged_attention(
decode_query
.
dtype
,
head_size
,
block_size
,
gqa_ratio
,
# decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta
.
max_decode_seq_len
)
# decode_meta.max_decode_seq_len)
use_custom
=
False
if
use_custom
:
if
use_custom
:
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
!=
AttentionType
.
ENCODER_DECODER
else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment