"vscode:/vscode.git/clone" did not exist on "18016a5e627d2a4b69af599272a5aa8ce71b98c8"
Commit 9e053941 authored by zhuwenwen
Browse files

Skip building the FP8 kernels and the _rocm_C extension (ROCm build)

parent f850f22a
......@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
# "csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
# "csrc/quantization/fp8/common.cu"
# "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
......@@ -613,6 +613,7 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP")
#
# _rocm_C extension
......@@ -631,9 +632,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3
WITH_SOABI)
endif()
]]
# For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/flashmla.cmake)
include(cmake/external_projects/vllm_flash_attn.cmake)
endif ()
endif ()
\ No newline at end of file
......@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8"
#"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")
......
This diff is collapsed.
......@@ -728,4 +728,4 @@ void gather_cache(
} else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
}
\ No newline at end of file
This diff is collapsed.
......@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
torch::Tensor& scale, double epsilon);
// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
// torch::Tensor& input,
// torch::Tensor& residual,
// torch::Tensor& weight,
// torch::Tensor& scale, double epsilon);
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input,
......@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale);
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std::optional<torch::Tensor> const& scale_ub);
// void dynamic_per_token_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
......
#pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
......@@ -670,4 +672,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} // namespace fp8
#endif // USE_ROCM
} // namespace vllm
} // namespace vllm
\ No newline at end of file
......@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"
// #include "quantization/fp8/common.cuh"
namespace vllm {
......
......@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
// #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
// #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val);
}
// __device__ __forceinline__ void atomicAdd(half* address, half val) {
// atomicAdd_half(address, val);
// }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val);
}
#endif
// #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
// __device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
// atomicAdd_half2(address, val);
// }
// #endif
#endif
#endif
// #endif
// #endif
} // namespace gptq
} // namespace vllm
......
......@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> "
"()");
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&rms_norm_static_fp8_quant);
// ops.def(
// "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
// "Tensor scale, float epsilon) -> "
// "()");
// ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
// &rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()");
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&fused_add_rms_norm_static_fp8_quant);
// ops.def(
// "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! residual, Tensor weight, "
// "Tensor scale, float epsilon) -> ()");
// ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
// &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels
ops.def(
......@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// ops.def(
// "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
// "()");
// ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
// "-> "
// "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);
// ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! scale, Tensor? scale_ub) -> "
// "()");
// ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
// &dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
......@@ -602,4 +602,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
}
#endif
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
\ No newline at end of file
......@@ -643,8 +643,8 @@ ext_modules = []
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
# if _is_hip():
# ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
......
......@@ -98,30 +98,30 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_rocm(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale)
# def paged_attention_rocm(
# out: torch.Tensor,
# exp_sum: torch.Tensor,
# max_logits: torch.Tensor,
# tmp_out: torch.Tensor,
# query: torch.Tensor,
# key_cache: torch.Tensor,
# value_cache: torch.Tensor,
# num_kv_heads: int,
# scale: float,
# block_tables: torch.Tensor,
# seq_lens: torch.Tensor,
# block_size: int,
# max_seq_len: int,
# alibi_slopes: Optional[torch.Tensor],
# kv_cache_dtype: str,
# k_scale: torch.Tensor,
# v_scale: torch.Tensor,
# ) -> None:
# torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
# key_cache, value_cache, num_kv_heads,
# scale, block_tables, seq_lens,
# block_size, max_seq_len, alibi_slopes,
# kv_cache_dtype, k_scale, v_scale)
# pos encoding ops
......@@ -1365,4 +1365,4 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata,
num_splits,
)
return out, softmax_lse
return out, softmax_lse
\ No newline at end of file
......@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len)
# use_custom = _use_rocm_custom_paged_attention(
# decode_query.dtype, head_size, block_size, gqa_ratio,
# decode_meta.max_decode_seq_len)
use_custom = False
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment