Unverified commit f9c53cbb, authored by strgrb, committed by GitHub

Create col-major and tma-aligned x_scale for deep_gemm.gemm_fp8_fp8_bf16_nt (#4515)


Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
parent 90532b76
@@ -168,6 +168,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
@@ -200,11 +201,20 @@ def per_token_group_quant_fp8(
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
-        x_s = torch.empty(
-            (x.shape[-1] // group_size,) + x.shape[:-1],
-            device=x.device,
-            dtype=torch.float32,
-        ).permute(-1, -2)
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
     else:
         x_s = torch.empty(
             x.shape[:-1] + (x.shape[-1] // group_size,),
...
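The allocation in this hunk is the subtle part: the backing store pads the token dimension up to a multiple of 4, the permute makes the scales column-major, and the final slice trims back to `x.shape[-2]` rows while keeping the padded column stride. That stride is a multiple of 4 * sizeof(float) = 16 bytes, which is what TMA alignment requires. A standalone PyTorch sketch of just this trick (sizes are illustrative assumptions; runs on CPU):

```python
# Standalone sketch of the TMA-alignment trick; M, K, group_size are assumptions.
import torch

M, K, group_size = 7, 256, 128
x = torch.empty(M, K)

aligned_size = (x.shape[-2] + 3) // 4 * 4  # pad M=7 up to 8 (multiple of 4)
x_s = torch.empty(
    x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
    dtype=torch.float32,
).permute(-1, -2)[: x.shape[-2], :]

print(x_s.shape)     # torch.Size([7, 2]) -- M rows, one scale per 128-wide group
print(x_s.stride())  # (1, 8) -- column stride is 8 floats = 32 bytes, 16-byte aligned
```

Slicing a padded buffer rather than allocating exactly M rows is what preserves the aligned stride: a plain `torch.empty((K // group_size, M)).permute(-1, -2)` would give a column stride of M floats, which is only 16-byte aligned when M happens to be a multiple of 4.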
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple
 import torch

 from sglang.srt.layers.quantization.fp8_kernel import (
+    _enable_jit_deepgemm,
     per_token_group_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
@@ -129,9 +130,17 @@ def apply_w8a8_block_fp8_linear(
         )
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=False
-        )
+        if _enable_jit_deepgemm:
+            q_input, x_scale = per_token_group_quant_fp8(
+                input_2d,
+                block_size[1],
+                column_major_scales=True,
+                scale_tma_aligned=True,
+            )
+        else:
+            q_input, x_scale = per_token_group_quant_fp8(
+                input_2d, block_size[1], column_major_scales=False
+            )
         output = w8a8_block_fp8_matmul(
             q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
         )
...
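The reason for the new layout is the consumer named in the commit title: DeepGEMM's `gemm_fp8_fp8_bf16_nt` requires the LHS scaling factors in a column-major, TMA-aligned format (otherwise they must be rearranged first, e.g. with DeepGEMM's `get_col_major_tma_aligned_tensor` helper), so producing them in that layout at quantization time saves a rearrangement per GEMM. A hedged sketch of the downstream call (tensor contents and sizes are assumptions; the tuple-of-(tensor, scales) calling convention follows DeepGEMM's public API as I understand it, so verify against your deep_gemm version):

```python
# Hedged sketch of the DeepGEMM call this layout feeds; all values are assumptions.
import torch
import deep_gemm  # requires DeepGEMM and a Hopper GPU

M, N, K = 8, 512, 256  # M already a multiple of 4, so no extra padding is needed
q_input = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
weight = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
# LHS scales: one per 128-wide group, column-major and TMA-aligned as built above.
x_scale = torch.rand(K // 128, M, device="cuda", dtype=torch.float32).t()
# RHS scales: one per 128x128 block of the weight.
weight_scale = torch.rand(N // 128, K // 128, device="cuda", dtype=torch.float32)

out = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
deep_gemm.gemm_fp8_fp8_bf16_nt((q_input, x_scale), (weight, weight_scale), out)
```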
-Subproject commit bd2a77552886b98c205af12f8d7d2d61247c4b27
+Subproject commit 3b3783d06cd4d06ac4ba048633e604151d1ee535