Commit 753b29c0, authored by zhaosong, committed by zhangzbb
Browse files

[bugfix][fp8] Enable torch.compile for quant_fp8.

parent 2444e959
......@@ -1900,6 +1900,28 @@ def scaled_fp4_experts_quant(
output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales
def _lightop_per_token_quant_fp8_impl(
    out: torch.Tensor,
    input: torch.Tensor,
    scales: torch.Tensor,
) -> None:
    """Run the lightop per-token FP8 quantization kernel.

    Thin wrapper that forwards its three tensors, in order, to
    ``lightop.op.per_token_quant_fp8``. It returns nothing; this function is
    registered as a custom op with ``out`` and ``scales`` declared as the
    mutated arguments, i.e. the kernel is expected to write into them.

    Args:
        out: Destination tensor handed to the kernel.
        input: Source tensor handed to the kernel.
        scales: Scale tensor handed to the kernel.
    """
    # Function-local import (presumably so that importing this module does
    # not require lightop to be installed -- TODO confirm).
    from lightop import op as lightop_ops

    # NOTE(review): ``input`` is forwarded exactly as received. Confirm
    # whether the kernel requires a contiguous tensor.
    lightop_ops.per_token_quant_fp8(out, input, scales)
    return None
def _lightop_per_token_quant_fp8_fake(
    out: torch.Tensor,
    input: torch.Tensor,
    scales: torch.Tensor,
) -> None:
    """Fake (no-op) counterpart of the lightop per-token FP8 quant op.

    Accepts the same arguments as the real implementation but performs no
    computation and leaves every tensor untouched. It is passed as
    ``fake_impl`` when the custom op is registered.

    Args:
        out: Destination tensor (not modified here).
        input: Source tensor (not read here).
        scales: Scale tensor (not modified here).
    """
    # Intentionally nothing to do: the op has no return value to fabricate.
    return None
# Register the lightop wrapper as a custom op. It is invoked elsewhere in this
# file as ``torch.ops.vllm.lightop_per_token_quant_fp8``.
direct_register_custom_op(
"lightop_per_token_quant_fp8",
_lightop_per_token_quant_fp8_impl,
# The op returns None; its results are written into these two arguments.
mutates_args=["out", "scales"],
# No-op stand-in with the same signature (presumably what torch.compile
# tracing uses in place of the real kernel -- TODO confirm).
fake_impl=_lightop_per_token_quant_fp8_fake,
)
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
......@@ -1952,7 +1974,13 @@ def scaled_fp8_quant(
dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
output, scale = per_token_quant_fp8(input.contiguous())
# output, scale = per_token_quant_fp8(input.contiguous())
output = torch.empty_like(input, device=input.device, dtype=torch.float8_e4m3fn)
scale = torch.empty(shape[:-1] + (1, ),
device=input.device,
dtype=torch.float32)
torch.ops.vllm.lightop_per_token_quant_fp8(output, input, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
......
......@@ -52,7 +52,7 @@ class QuantFP8(CustomOp):
column major format
:param compile_native: Manually compile forward_native if compile mode > None
"""
super().__init__(compile_native=compile_native)
super().__init__(compile_native=compile_native, enforce_enable=True)
self.static = static
self.group_shape = group_shape
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment