Commit 572cd426 authored by zhuwenwen's avatar zhuwenwen
Browse files

restore the initial fp8 related implementation

parent c441dda9
......@@ -255,15 +255,15 @@ set(VLLM_EXT_SRC
"csrc/attention/attention_with_mask_kernels_opt.cu"
"csrc/attention/attention_with_mask_kernels_opt_tc.cu"
"csrc/opt/layernorm_kernels_opt.cu"
# "csrc/layernorm_quant_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/cuda_view.cu"
# "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
# "csrc/quantization/activation_kernels.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/custom_all_reduce.cu"
......
......@@ -123,7 +123,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
# "-DENABLE_FP8"
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-Werror=unused-variable"
......
......@@ -6,9 +6,7 @@
*/
#include "type_convert.cuh"
#ifndef USE_ROCM
#include "quantization/fp8/common.cuh"
#endif
#include "dispatch_utils.h"
#include <torch/cuda.h>
......
......@@ -221,15 +221,15 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);
// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
// torch::Tensor& input,
// torch::Tensor& residual,
// torch::Tensor& weight,
// torch::Tensor& scale, double epsilon);
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
torch::Tensor& scale, double epsilon);
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input,
......@@ -258,8 +258,8 @@ void rotary_embedding_tgi(
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
// void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
......@@ -443,15 +443,15 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);
// void dynamic_per_token_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// std::optional<torch::Tensor> const& scale_ub);
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
......
......@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
// #include "quantization/fp8/common.cuh"
#include "quantization/fp8/common.cuh"
namespace vllm {
......
......@@ -258,8 +258,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
// ops.def(
// "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
// ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
......@@ -737,25 +737,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor.
// ops.def(
// "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
// "()");
// ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
// "-> "
// "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! scale, Tensor? scale_ub) -> "
// "()");
// ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
// &dynamic_per_token_scaled_fp8_quant);
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
......
......@@ -1692,66 +1692,66 @@ def scaled_fp4_experts_quant(
# fp8
# def scaled_fp8_quant(
# input: torch.Tensor,
# scale: Optional[torch.Tensor] = None,
# num_token_padding: Optional[int] = None,
# scale_ub: Optional[torch.Tensor] = None,
# use_per_token_if_dynamic: bool = False,
# output: Optional[torch.Tensor] = None,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# """
# Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensors for downstream kernels that
# will benefit from padding.
# Args:
# input: The input tensor to be quantized to FP8
# scale: Optional scaling factor for the FP8 quantization
# scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
# num_token_padding: If specified, pad the first dimension
# of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
# Returns:
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# """
# # This code assumes batch_dim and num_tokens are flattened
# assert (input.ndim == 2)
# shape: Union[tuple[int, int], torch.Size] = input.shape
# # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# out_dtype: torch.dtype = current_platform.fp8_dtype()
# if num_token_padding:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# if output is None:
# output = torch.empty(shape, device=input.device, dtype=out_dtype)
# else:
# assert num_token_padding is None, \
# "padding not supported if output passed in"
# assert output.dtype == out_dtype
# if scale is None:
# if use_per_token_if_dynamic:
# scale = torch.empty((shape[0], 1),
# device=input.device,
# dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
# else:
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# assert scale.numel() == 1, f"{scale.shape}"
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# return output, scale
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[tuple[int, int], torch.Size] = input.shape
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None:
if use_per_token_if_dynamic:
scale = torch.empty((shape[0], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input.contiguous(), scale, scale_ub)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
# gptq allspark
......
......@@ -62,9 +62,9 @@ class FixFunctionalizationPass(VllmInductorPass):
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: 'input', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
# elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
# mutated_args = {1: 'result', 2: 'residual'}
# self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
self.defunctionalize(graph, node, mutated_args)
......@@ -83,12 +83,12 @@ class FixFunctionalizationPass(VllmInductorPass):
node,
mutated_args,
args=('result', 'input'))
# elif at_target == torch.ops._C.silu_and_mul_quant.default:
# mutated_args = {1: 'result'}
# self.defunctionalize(graph,
# node,
# mutated_args,
# args=('result', 'input', 'scale'))
elif at_target == torch.ops._C.silu_and_mul_quant.default:
mutated_args = {1: 'result'}
self.defunctionalize(graph,
node,
mutated_args,
args=('result', 'input', 'scale'))
else:
continue # skip the count
......
......@@ -82,17 +82,17 @@ class QuantKey(NamedTuple):
f"{'a' if not self.symmetric else ''}symmetric)")
# kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
# kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
# kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym:
# torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
# kFp8DynamicTensorSym:
# torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
# kFp8DynamicTokenSym:
# torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kFp8StaticTensorSym:
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym:
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym:
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
......@@ -111,14 +111,14 @@ class FusedRMSQuantKey(NamedTuple):
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
# FusedRMSQuantKey(kFp8StaticTensorSym, False):
# torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
# FusedRMSQuantKey(kFp8StaticTensorSym, True):
# torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
# FusedRMSQuantKey(kFp8DynamicTokenSym, False):
# torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
# FusedRMSQuantKey(kFp8DynamicTokenSym, True):
# torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8StaticTensorSym, False):
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8StaticTensorSym, True):
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8DynamicTokenSym, False):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
FusedRMSQuantKey(kFp8DynamicTokenSym, True):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
}
......@@ -586,23 +586,23 @@ class FusionPass(VllmInductorPass):
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
# RMSNormStaticQuantPattern(epsilon,
# FP8_DTYPE).register(self.patterns)
RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
# Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches)
# Fuse rms_norm + static fp8 quant
# FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
# self.patterns, self.record_match)
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# Fuse rms_norm + dynamic per-token fp8 quant
# RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
# self.patterns, self.record_match)
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
# FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
# self.patterns, self.record_match)
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
......
......@@ -444,16 +444,16 @@ class SequenceParallelismPass(VllmInductorPass):
for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns
# fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
# FirstAllReduceRMSNormStaticFP8Pattern(
# epsilon, self.model_dtype, self.device,
# fp8_quant_op).register(self.patterns)
# MiddleAllReduceRMSNormStaticFP8Pattern(
# epsilon, self.model_dtype, self.device,
# fp8_quant_op).register(self.patterns)
# LastAllReduceRMSNormStaticFP8Pattern(
# epsilon, self.model_dtype, self.device,
# fp8_quant_op).register(self.patterns)
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
FirstAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
MiddleAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
LastAllReduceRMSNormStaticFP8Pattern(
epsilon, self.model_dtype, self.device,
fp8_quant_op).register(self.patterns)
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment