Commit fe79f042 authored by zhuwenwen
Browse files

support rms_norm_dynamic_per_token_quant

parent b1085044
...@@ -248,7 +248,7 @@ set(VLLM_EXT_SRC ...@@ -248,7 +248,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu" # "csrc/quantization/fp8/common.cu"
# "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
# "csrc/quantization/activation_kernels.cu" # "csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
......
...@@ -107,13 +107,13 @@ void apply_repetition_penalties_(torch::Tensor& logits, ...@@ -107,13 +107,13 @@ void apply_repetition_penalties_(torch::Tensor& logits,
// torch::Tensor& weight, // torch::Tensor& weight,
// torch::Tensor& scale, double epsilon); // torch::Tensor& scale, double epsilon);
// void rms_norm_dynamic_per_token_quant(torch::Tensor& out, void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
// torch::Tensor const& input, torch::Tensor const& input,
// torch::Tensor const& weight, torch::Tensor const& weight,
// torch::Tensor& scales, torch::Tensor& scales,
// double const epsilon, double const epsilon,
// std::optional<torch::Tensor> scale_ub, std::optional<torch::Tensor> scale_ub,
// std::optional<torch::Tensor> residual); std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
......
...@@ -6,6 +6,20 @@ ...@@ -6,6 +6,20 @@
#include "layernorm_utils.cuh" #include "layernorm_utils.cuh"
#include "quant_conversions.cuh" #include "quant_conversions.cuh"
// Reports which FP8 type is preferred on the current platform.
// Builds without USE_ROCM always answer true. ROCm builds read the current
// device's gcnArchName and answer true only when it does not contain "gfx94".
static bool is_fp8_ocp() {
#ifdef USE_ROCM
  const std::string arch_name =
      at::cuda::getCurrentDeviceProperties()->gcnArchName;
  return arch_name.find("gfx94") == std::string::npos;
#else
  return true;
#endif
}
namespace vllm { namespace vllm {
template <typename scalar_t, typename scalar_out_t, bool has_residual = false> template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
......
...@@ -195,12 +195,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -195,12 +195,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// &fused_add_rms_norm_static_fp8_quant); // &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels // Fused Layernorm + Quant kernels
// ops.def( ops.def(
// "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, " "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
// "Tensor weight, Tensor! scale, float epsilon, " "Tensor weight, Tensor! scale, float epsilon, "
// "Tensor? scale_ub, Tensor!? residual) -> ()"); "Tensor? scale_ub, Tensor!? residual) -> ()");
// ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
// &rms_norm_dynamic_per_token_quant); &rms_norm_dynamic_per_token_quant);
// Rotary embedding // Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
......
...@@ -62,15 +62,15 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -62,15 +62,15 @@ class FixFunctionalizationPass(VllmInductorPass):
elif at_target == torch.ops._C.fused_add_rms_norm.default: elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: 'input', 2: 'residual'} mutated_args = {1: 'input', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args) self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 # elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'residual'} # mutated_args = {1: 'result', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args) # self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
self.defunctionalize(graph, node, mutated_args) self.defunctionalize(graph, node, mutated_args)
elif at_target in [ elif at_target in [
torch.ops._C.rms_norm.default, torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default, # torch.ops._C.rms_norm_static_fp8_quant.default,
]: ]:
mutated_args = {1: 'result'} mutated_args = {1: 'result'}
self.defunctionalize(graph, node, mutated_args) self.defunctionalize(graph, node, mutated_args)
...@@ -83,12 +83,12 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -83,12 +83,12 @@ class FixFunctionalizationPass(VllmInductorPass):
node, node,
mutated_args, mutated_args,
args=('result', 'input')) args=('result', 'input'))
elif at_target == torch.ops._C.silu_and_mul_quant.default: # elif at_target == torch.ops._C.silu_and_mul_quant.default:
mutated_args = {1: 'result'} # mutated_args = {1: 'result'}
self.defunctionalize(graph, # self.defunctionalize(graph,
node, # node,
mutated_args, # mutated_args,
args=('result', 'input', 'scale')) # args=('result', 'input', 'scale'))
else: else:
continue # skip the count continue # skip the count
......
...@@ -82,9 +82,9 @@ class QuantKey(NamedTuple): ...@@ -82,9 +82,9 @@ class QuantKey(NamedTuple):
f"{'a' if not self.symmetric else ''}symmetric)") f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True) # kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True) # kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True) # kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = { QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym: # kFp8StaticTensorSym:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment