"vllm/vscode:/vscode.git/clone" did not exist on "4e51fa8cbaba2c6fd516b4615a533b0a94796516"
Commit fe79f042 authored by zhuwenwen
Browse files

support rms_norm_dynamic_per_token_quant

parent b1085044
......@@ -248,7 +248,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu"
# "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
# "csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
......
......@@ -107,13 +107,13 @@ void apply_repetition_penalties_(torch::Tensor& logits,
// torch::Tensor& weight,
// torch::Tensor& scale, double epsilon);
// void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
// torch::Tensor const& input,
// torch::Tensor const& weight,
// torch::Tensor& scales,
// double const epsilon,
// std::optional<torch::Tensor> scale_ub,
// std::optional<torch::Tensor> residual);
// Declaration of the fused RMS-norm + dynamic per-token quantization op.
//
// Parameters (roles inferred from names and const-ness; confirm against the
// kernel implementation):
//   out       - non-const reference; presumably receives the quantized result.
//   input     - read-only input tensor.
//   weight    - read-only weight tensor (presumably the norm weight).
//   scales    - non-const reference; presumably receives the computed
//               per-token scales.
//   epsilon   - scalar passed by value (presumably the RMS-norm stabilizer).
//   scale_ub  - optional tensor; may be std::nullopt.
//   residual  - optional tensor; may be std::nullopt.
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales,
double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
......
......@@ -6,6 +6,20 @@
#include "layernorm_utils.cuh"
#include "quant_conversions.cuh"
// Reports whether the OCP flavour of FP8 is the preferred FP8 type on the
// current platform.
//   - Non-ROCm builds (USE_ROCM undefined): always true.
//   - ROCm builds: reads the current device's gcnArchName and answers false
//     only when that name contains the substring "gfx94".
static bool is_fp8_ocp() {
#ifdef USE_ROCM
  const std::string arch_name =
      at::cuda::getCurrentDeviceProperties()->gcnArchName;
  // Any architecture name without "gfx94" is treated as OCP-preferring.
  return arch_name.find("gfx94") == std::string::npos;
#else
  return true;
#endif
}
namespace vllm {
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
......
......@@ -195,12 +195,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels
// ops.def(
// "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
// "Tensor weight, Tensor! scale, float epsilon, "
// "Tensor? scale_ub, Tensor!? residual) -> ()");
// ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
// &rms_norm_dynamic_per_token_quant);
// Register the operator schema. In the schema string, "Tensor!" marks an
// argument the op is allowed to mutate in place (result, scale), "Tensor?"
// marks an optional argument (scale_ub), and "Tensor!?" marks an optional
// argument that may be mutated when present (residual). The op returns
// nothing.
ops.def(
"rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
"Tensor weight, Tensor! scale, float epsilon, "
"Tensor? scale_ub, Tensor!? residual) -> ()");
// Bind the CUDA dispatch key for this op to the C++ entry point.
ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
&rms_norm_dynamic_per_token_quant);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
......
......@@ -62,15 +62,15 @@ class FixFunctionalizationPass(VllmInductorPass):
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: 'input', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
# elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
# mutated_args = {1: 'result', 2: 'residual'}
# self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default,
# torch.ops._C.rms_norm_static_fp8_quant.default,
]:
mutated_args = {1: 'result'}
self.defunctionalize(graph, node, mutated_args)
......@@ -83,12 +83,12 @@ class FixFunctionalizationPass(VllmInductorPass):
node,
mutated_args,
args=('result', 'input'))
elif at_target == torch.ops._C.silu_and_mul_quant.default:
mutated_args = {1: 'result'}
self.defunctionalize(graph,
node,
mutated_args,
args=('result', 'input', 'scale'))
# elif at_target == torch.ops._C.silu_and_mul_quant.default:
# mutated_args = {1: 'result'}
# self.defunctionalize(graph,
# node,
# mutated_args,
# args=('result', 'input', 'scale'))
else:
continue # skip the count
......
......@@ -82,9 +82,9 @@ class QuantKey(NamedTuple):
f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
# kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
# kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
# kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment