Commit a0ac95b0 authored by wanghl6's avatar wanghl6
Browse files

per_token_group_quant_fp8 opt

parent cb68935c
......@@ -915,6 +915,37 @@ def _per_token_group_quant_fp8_colmajor(
tl.store(y_s_ptr, y_s)
def _lightop_per_token_group_quant_fp8_impl(
    x_q: torch.Tensor,
    x: torch.Tensor,
    x_s: torch.Tensor,
    group_size: int,
    eps: float,
    use_ue8m0: bool,
) -> None:
    """Forward a per-token-group FP8 quantization request to lightop.

    This is the real implementation behind the custom op registered as
    ``lightop_per_token_group_quant_fp8``. It returns nothing; the op is
    registered with ``x_q`` and ``x_s`` as mutated arguments.

    Args:
        x_q: Tensor handed to the kernel as its first argument; registered
            as mutated (presumably receives the quantized values -- confirm
            against the lightop kernel).
        x: Tensor handed to the kernel as its second argument (presumably
            the unquantized input -- confirm against the lightop kernel).
        x_s: Tensor handed to the kernel as its third argument; registered
            as mutated (presumably receives the per-group scales -- confirm
            against the lightop kernel).
        group_size: Forwarded unchanged to the lightop kernel.
        eps: Forwarded unchanged to the lightop kernel.
        use_ue8m0: Forwarded unchanged to the lightop kernel.
    """
    # Imported inside the function rather than at module level; presumably
    # so this module still imports when lightop is not installed -- confirm.
    from lightop import op as lightop_ops

    quantize_kernel = lightop_ops.per_token_group_quant_fp8
    quantize_kernel(x_q, x, x_s, group_size, eps, use_ue8m0)
def _lightop_per_token_group_quant_fp8_fake(
x_q: torch.Tensor,
x: torch.Tensor,
x_s: torch.Tensor,
group_size: int,
eps: float,
use_ue8m0: bool,
) -> None:
pass
# Register the lightop-backed wrapper as a custom op named
# "lightop_per_token_group_quant_fp8"; the quantization entry point in this
# file invokes it as torch.ops.vllm.lightop_per_token_group_quant_fp8.
direct_register_custom_op(
"lightop_per_token_group_quant_fp8",
_lightop_per_token_group_quant_fp8_impl,
# Both tensors are declared as being mutated in place by the op.
mutates_args=["x_q", "x_s"],
# No-op stand-in with the same signature; presumably used when the op is
# traced without executing the real kernel -- TODO confirm against
# direct_register_custom_op.
fake_impl=_lightop_per_token_group_quant_fp8_fake,
)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
......@@ -980,7 +1011,11 @@ def per_token_group_quant_fp8(
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
if envs.USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8 and not column_major_scales:
torch.ops.vllm.lightop_per_token_group_quant_fp8(x_q, x, x_s, group_size, eps, use_ue8m0)
return x_q, x_s
# prefer CUDA kernel if available
# TODO(bnell): this causes some fp8 moe test to fail.
if current_platform.is_cuda() and x.is_contiguous():
......@@ -1743,4 +1778,4 @@ def process_fp8_input_tensor_strategy_moe(
"for each layer."
)
return w13_input_scale.max(), w2_input_scale.max()
return w13_input_scale.max(), w2_input_scale.max()
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment