Commit 9e053941 authored by zhuwenwen
Browse files

skip fp8 kernel and _rocm_C extension

parent f850f22a
...@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC ...@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu" "csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu" "csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu" # "csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu" # "csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" # "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu" "csrc/prepare_inputs/advance_step.cu"
...@@ -613,6 +613,7 @@ define_gpu_extension_target( ...@@ -613,6 +613,7 @@ define_gpu_extension_target(
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "HIP")
# #
# _rocm_C extension # _rocm_C extension
...@@ -631,9 +632,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP") ...@@ -631,9 +632,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
endif() endif()
]]
# For CUDA we also build and ship some external projects. # For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA") if (VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/flashmla.cmake) include(cmake/external_projects/flashmla.cmake)
include(cmake/external_projects/vllm_flash_attn.cmake) include(cmake/external_projects/vllm_flash_attn.cmake)
endif () endif ()
\ No newline at end of file
...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
"-DENABLE_FP8" #"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc") "-fno-gpu-rdc")
......
This diff is collapsed.
...@@ -728,4 +728,4 @@ void gather_cache( ...@@ -728,4 +728,4 @@ void gather_cache(
} else { } else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
} }
} }
\ No newline at end of file
This diff is collapsed.
...@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, ...@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon); torch::Tensor& weight, double epsilon);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale, // torch::Tensor& weight, torch::Tensor& scale,
double epsilon); // double epsilon);
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, // void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& input, // torch::Tensor& input,
torch::Tensor& residual, // torch::Tensor& residual,
torch::Tensor& weight, // torch::Tensor& weight,
torch::Tensor& scale, double epsilon); // torch::Tensor& scale, double epsilon);
void rms_norm_dynamic_per_token_quant(torch::Tensor& out, void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input, torch::Tensor const& input,
...@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale); // torch::Tensor const& scale);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, // void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale); // torch::Tensor& scale);
void dynamic_per_token_scaled_fp8_quant( // void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, // torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std::optional<torch::Tensor> const& scale_ub); // std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& A, const torch::Tensor& B,
......
#pragma once #pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h> #include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
...@@ -670,4 +672,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -670,4 +672,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} // namespace fp8 } // namespace fp8
#endif // USE_ROCM #endif // USE_ROCM
} // namespace vllm } // namespace vllm
\ No newline at end of file
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh" #include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead // TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh" // #include "quantization/fp8/common.cuh"
namespace vllm { namespace vllm {
......
...@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { ...@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
// //
#if defined(__CUDA_ARCH__) || defined(USE_ROCM) // #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) // #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { // __device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val); // atomicAdd_half(address, val);
} // }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) // #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { // __device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val); // atomicAdd_half2(address, val);
} // }
#endif // #endif
#endif // #endif
#endif // #endif
} // namespace gptq } // namespace gptq
} // namespace vllm } // namespace vllm
......
...@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant // Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def( // ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " // "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> " // "Tensor scale, float epsilon) -> "
"()"); // "()");
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, // ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&rms_norm_static_fp8_quant); // &rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization. // In-place fused Add and RMS Normalization.
ops.def( // ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " // "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, " // "Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()"); // "Tensor scale, float epsilon) -> ()");
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, // ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&fused_add_rms_norm_static_fp8_quant); // &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels // Fused Layernorm + Quant kernels
ops.def( ops.def(
...@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor. // Compute FP8 quantized tensor for given scaling factor.
ops.def( // ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> " // "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()"); // "()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. // // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def( // ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " // "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> " // "-> "
"()"); // "()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); // ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def( // ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " // "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> " // "Tensor! scale, Tensor? scale_ub) -> "
"()"); // "()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, // ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant); // &dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor. // Compute int8 quantized tensor for given scaling factor.
ops.def( ops.def(
...@@ -602,4 +602,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { ...@@ -602,4 +602,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
} }
#endif #endif
REGISTER_EXTENSION(TORCH_EXTENSION_NAME) REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
\ No newline at end of file
...@@ -643,8 +643,8 @@ ext_modules = [] ...@@ -643,8 +643,8 @@ ext_modules = []
if _is_cuda() or _is_hip(): if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C")) ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if _is_hip(): # if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C")) # ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda(): if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
......
...@@ -98,30 +98,30 @@ def paged_attention_v2( ...@@ -98,30 +98,30 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_rocm( # def paged_attention_rocm(
out: torch.Tensor, # out: torch.Tensor,
exp_sum: torch.Tensor, # exp_sum: torch.Tensor,
max_logits: torch.Tensor, # max_logits: torch.Tensor,
tmp_out: torch.Tensor, # tmp_out: torch.Tensor,
query: torch.Tensor, # query: torch.Tensor,
key_cache: torch.Tensor, # key_cache: torch.Tensor,
value_cache: torch.Tensor, # value_cache: torch.Tensor,
num_kv_heads: int, # num_kv_heads: int,
scale: float, # scale: float,
block_tables: torch.Tensor, # block_tables: torch.Tensor,
seq_lens: torch.Tensor, # seq_lens: torch.Tensor,
block_size: int, # block_size: int,
max_seq_len: int, # max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], # alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, # kv_cache_dtype: str,
k_scale: torch.Tensor, # k_scale: torch.Tensor,
v_scale: torch.Tensor, # v_scale: torch.Tensor,
) -> None: # ) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, # torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, # key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens, # scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes, # block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale) # kv_cache_dtype, k_scale, v_scale)
# pos encoding ops # pos encoding ops
...@@ -1365,4 +1365,4 @@ def flash_mla_with_kvcache( ...@@ -1365,4 +1365,4 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
) )
return out, softmax_lse return out, softmax_lse
\ No newline at end of file
...@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention( # use_custom = _use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio, # decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len) # decode_meta.max_decode_seq_len)
use_custom = False
if use_custom: if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else != AttentionType.ENCODER_DECODER else
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment