Commit 91feb245 authored by zhuwenwen
Browse files

support rms_norm_dynamic_per_token_quant

parent 18f030d9
...@@ -261,7 +261,7 @@ set(VLLM_EXT_SRC ...@@ -261,7 +261,7 @@ set(VLLM_EXT_SRC
# "csrc/quantization/gptq/q_gemm.cu" # "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu" # "csrc/quantization/fp8/common.cu"
# "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
# "csrc/quantization/activation_kernels.cu" # "csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
......
...@@ -233,13 +233,13 @@ void apply_repetition_penalties_(torch::Tensor& logits, ...@@ -233,13 +233,13 @@ void apply_repetition_penalties_(torch::Tensor& logits,
// torch::Tensor& weight, // torch::Tensor& weight,
// torch::Tensor& scale, double epsilon); // torch::Tensor& scale, double epsilon);
// void rms_norm_dynamic_per_token_quant(torch::Tensor& out, void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
// torch::Tensor const& input, torch::Tensor const& input,
// torch::Tensor const& weight, torch::Tensor const& weight,
// torch::Tensor& scales, torch::Tensor& scales,
// double const epsilon, double const epsilon,
// std::optional<torch::Tensor> scale_ub, std::optional<torch::Tensor> scale_ub,
// std::optional<torch::Tensor> residual); std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
......
...@@ -6,6 +6,20 @@ ...@@ -6,6 +6,20 @@
#include "layernorm_utils.cuh" #include "layernorm_utils.cuh"
#include "quant_conversions.cuh" #include "quant_conversions.cuh"
// Selects the preferred FP8 type for the platform this code runs on.
// Builds without USE_ROCM always answer true. Builds with USE_ROCM read the
// current device's gcnArchName and answer true only when that name does not
// contain "gfx94".
static bool is_fp8_ocp() {
#ifdef USE_ROCM
  const std::string arch_name =
      at::cuda::getCurrentDeviceProperties()->gcnArchName;
  // A name that mentions "gfx94" anywhere yields false.
  return arch_name.find("gfx94") == std::string::npos;
#else
  return true;
#endif
}
namespace vllm { namespace vllm {
template <typename scalar_t, typename scalar_out_t, bool has_residual = false> template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
......
...@@ -372,12 +372,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -372,12 +372,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// &fused_add_rms_norm_static_fp8_quant); // &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels // Fused Layernorm + Quant kernels
// ops.def( ops.def(
// "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, " "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
// "Tensor weight, Tensor! scale, float epsilon, " "Tensor weight, Tensor! scale, float epsilon, "
// "Tensor? scale_ub, Tensor!? residual) -> ()"); "Tensor? scale_ub, Tensor!? residual) -> ()");
// ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
// &rms_norm_dynamic_per_token_quant); &rms_norm_dynamic_per_token_quant);
// Rotary embedding // Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment