"vscode:/vscode.git/clone" did not exist on "b7955ef17b8d899327b25564f20665ec3ffa71cb"
Unverified Commit d08b356e authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

[Perf] Create TMA-aligned input scale tensor for DeepGemm on Hopper (#32619)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
parent f7448101
...@@ -14,7 +14,6 @@ from vllm.triton_utils import triton ...@@ -14,7 +14,6 @@ from vllm.triton_utils import triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
calc_diff, calc_diff,
fp8_gemm_nt, fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
per_block_cast_to_fp8, per_block_cast_to_fp8,
) )
...@@ -48,8 +47,9 @@ def benchmark_shape( ...@@ -48,8 +47,9 @@ def benchmark_shape(
block_size = [128, 128] block_size = [128, 128]
# Pre-quantize A for all implementations # Pre-quantize A for all implementations
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1]) A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) A, block_size[1], column_major_scales=True, tma_aligned_scales=True
)
C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
......
...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast, group_broadcast,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import _ceil_to_ue8m0, is_deep_gemm_e8m0_used
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -170,6 +171,8 @@ def native_per_token_group_quant_fp8( ...@@ -170,6 +171,8 @@ def native_per_token_group_quant_fp8(
x_ = x.reshape(x.numel() // group_size, group_size) x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max x_s = amax / fp8_max
if is_deep_gemm_e8m0_used():
x_s = _ceil_to_ue8m0(x_s)
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape) x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
......
...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
fp8_gemm_nt, fp8_gemm_nt,
get_col_major_tma_aligned_tensor, get_tma_aligned_size,
per_block_cast_to_fp8, per_block_cast_to_fp8,
should_use_deepgemm_for_fp8_linear, should_use_deepgemm_for_fp8_linear,
) )
...@@ -40,6 +40,8 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] ...@@ -40,6 +40,8 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 2050] NUM_TOKENS = [7, 2050]
D = [512, 4096, 5120, 13824] D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 512] GROUP_SIZE = [64, 128, 512]
COLUMN_MAJOR_SCALES = [True, False]
TMA_ALIGNED_SCALES = [True, False]
M = [1, 7, 8, 83, 84, 4096] M = [1, 7, 8, 83, 84, 4096]
N = [128, 512, 7168, 7748, 13824] N = [128, 512, 7168, 7748, 13824]
K = [256, 3884, 4096, 13824, 16384] K = [256, 3884, 4096, 13824, 16384]
...@@ -63,20 +65,40 @@ def setup_cuda(): ...@@ -63,20 +65,40 @@ def setup_cuda():
reason="This platform supports e4m3fnuz, not e4m3fn.", reason="This platform supports e4m3fnuz, not e4m3fn.",
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed", "num_tokens,d,dtype,group_size,column_major_scales,tma_aligned_scales,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS), itertools.product(
NUM_TOKENS,
D,
DTYPES,
GROUP_SIZE,
COLUMN_MAJOR_SCALES,
TMA_ALIGNED_SCALES,
SEEDS,
),
) )
@torch.inference_mode() @torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): def test_per_token_group_quant_fp8(
num_tokens, d, dtype, group_size, column_major_scales, tma_aligned_scales, seed
):
torch.manual_seed(seed) torch.manual_seed(seed)
x = torch.rand(num_tokens, d, dtype=dtype) x = torch.rand(num_tokens, d, dtype=dtype)
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
out, scale = per_token_group_quant_fp8(x, group_size) out, scale = per_token_group_quant_fp8(
x,
group_size,
column_major_scales=column_major_scales,
tma_aligned_scales=tma_aligned_scales,
)
assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15) assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
assert torch.allclose(scale, ref_scale) assert torch.allclose(scale, ref_scale)
if column_major_scales:
assert scale.stride()[-2] == 1
if tma_aligned_scales:
assert scale.stride()[-1] == get_tma_aligned_size(num_tokens, 4)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed", "M,N,K,block_size,out_dtype,seed",
...@@ -186,7 +208,9 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -186,7 +208,9 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
): ):
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1]) A_fp8, As_fp8 = per_token_group_quant_fp8(
A_fp32, block_size[1], column_major_scales=True, tma_aligned_scales=True
)
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size) B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
As = As_fp8.to(torch.float32) As = As_fp8.to(torch.float32)
...@@ -194,9 +218,6 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -194,9 +218,6 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device="cuda", dtype=out_dtype) out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) // 128), ( assert As_fp8.shape == (M, (K + 127) // 128), (
......
...@@ -8,13 +8,16 @@ import torch ...@@ -8,13 +8,16 @@ import torch
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)]) @pytest.mark.parametrize(
"shape", [(31, 128), (32, 128), (63, 256), (64, 256), (16, 512)]
)
@pytest.mark.parametrize("column_major", [False, True]) @pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("tma_aligned", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True]) @pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8( def test_per_token_group_quant_fp8(
shape, column_major: bool, scale_ue8m0: bool, group_size: int shape, column_major: bool, tma_aligned: bool, scale_ue8m0: bool, group_size: int
): ):
device = "cuda" device = "cuda"
...@@ -28,6 +31,7 @@ def test_per_token_group_quant_fp8( ...@@ -28,6 +31,7 @@ def test_per_token_group_quant_fp8(
x, x,
group_size, group_size,
column_major_scales=column_major, column_major_scales=column_major,
tma_aligned_scales=tma_aligned,
use_ue8m0=scale_ue8m0, use_ue8m0=scale_ue8m0,
) )
......
...@@ -36,6 +36,7 @@ class QuantFP8(CustomOp): ...@@ -36,6 +36,7 @@ class QuantFP8(CustomOp):
group_shape: GroupShape, group_shape: GroupShape,
num_token_padding: int | None = None, num_token_padding: int | None = None,
column_major_scales: bool = False, column_major_scales: bool = False,
tma_aligned_scales: bool = False,
use_ue8m0: bool | None = None, # for Torch compile use_ue8m0: bool | None = None, # for Torch compile
): ):
""" """
...@@ -44,6 +45,8 @@ class QuantFP8(CustomOp): ...@@ -44,6 +45,8 @@ class QuantFP8(CustomOp):
PER_CHANNEL, or arbitrary block size) PER_CHANNEL, or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this :param num_token_padding: Pad the token dimension of output to this
size size
:param tma_aligned_scales: For group quantization, output scales in
TMA-aligned layout
:param column_major_scales: For group quantization, output scales in :param column_major_scales: For group quantization, output scales in
column major format column major format
""" """
...@@ -53,6 +56,7 @@ class QuantFP8(CustomOp): ...@@ -53,6 +56,7 @@ class QuantFP8(CustomOp):
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
self.num_token_padding = num_token_padding self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales self.column_major_scales = column_major_scales
self.tma_aligned_scales = tma_aligned_scales
self.use_ue8m0 = use_ue8m0 self.use_ue8m0 = use_ue8m0
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled() self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
...@@ -82,6 +86,7 @@ class QuantFP8(CustomOp): ...@@ -82,6 +86,7 @@ class QuantFP8(CustomOp):
x, x,
group_size=self.group_size, group_size=self.group_size,
column_major_scales=self.column_major_scales, column_major_scales=self.column_major_scales,
tma_aligned_scales=self.tma_aligned_scales,
dtype=_FP8_DTYPE, dtype=_FP8_DTYPE,
use_ue8m0=self.use_ue8m0, use_ue8m0=self.use_ue8m0,
) )
......
...@@ -35,6 +35,7 @@ from vllm.triton_utils import tl, triton ...@@ -35,6 +35,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT, DeepGemmQuantScaleFMT,
fp8_gemm_nt, fp8_gemm_nt,
get_tma_aligned_size,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported, is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear, should_use_deepgemm_for_fp8_linear,
...@@ -378,6 +379,7 @@ class W8A8BlockFp8LinearOp: ...@@ -378,6 +379,7 @@ class W8A8BlockFp8LinearOp:
False, False,
self.act_quant_group_shape, self.act_quant_group_shape,
column_major_scales=True, column_major_scales=True,
tma_aligned_scales=True,
use_ue8m0=self.use_deep_gemm_e8m0, use_ue8m0=self.use_deep_gemm_e8m0,
) )
if self.is_deep_gemm_supported if self.is_deep_gemm_supported
...@@ -868,6 +870,7 @@ def per_token_group_quant_fp8( ...@@ -868,6 +870,7 @@ def per_token_group_quant_fp8(
eps: float = 1e-10, eps: float = 1e-10,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
column_major_scales: bool = False, column_major_scales: bool = False,
tma_aligned_scales: bool = False,
out_q: torch.Tensor | None = None, out_q: torch.Tensor | None = None,
use_ue8m0: bool | None = None, use_ue8m0: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
...@@ -878,9 +881,10 @@ def per_token_group_quant_fp8( ...@@ -878,9 +881,10 @@ def per_token_group_quant_fp8(
x: The input tensor with ndim >= 2. x: The input tensor with ndim >= 2.
group_size: The group size used for quantization. group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero. eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` dtype: The dtype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now. is supported for now.
column_major_scales: Outputs scales in column major. column_major_scales: Outputs scales in column major.
tma_aligned_scales: Outputs scales in TMA-aligned layout.
out_q: Optional output tensor. If not provided, function will create. out_q: Optional output tensor. If not provided, function will create.
Returns: Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
...@@ -904,8 +908,24 @@ def per_token_group_quant_fp8( ...@@ -904,8 +908,24 @@ def per_token_group_quant_fp8(
# Allocate the scale tensor in either row- or column-major format. # Allocate the scale tensor in either row- or column-major format.
if column_major_scales: if column_major_scales:
shape = (x.shape[-1] // group_size,) + x.shape[:-1] if tma_aligned_scales:
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) m = x.shape[-2]
sf_k = x.shape[-1] // group_size
tma_aligned_m = get_tma_aligned_size(m, 4)
shape = x.shape[:-2] + (m, sf_k)
stride = (
(1, tma_aligned_m)
if x.dim() == 2
else (tma_aligned_m * sf_k, 1, tma_aligned_m)
)
x_s = torch.empty_strided(
shape, stride, device=x.device, dtype=torch.float32
)
else:
shape = x.shape[:-2] + (x.shape[-1] // group_size, x.shape[-2])
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(
-1, -2
)
else: else:
shape = x.shape[:-1] + (x.shape[-1] // group_size,) shape = x.shape[:-1] + (x.shape[-1] // group_size,)
x_s = torch.empty(shape, device=x.device, dtype=torch.float32) x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
......
...@@ -340,6 +340,11 @@ def _align(x: int, y: int) -> int: ...@@ -340,6 +340,11 @@ def _align(x: int, y: int) -> int:
return cdiv(x, y) * y return cdiv(x, y) * y
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19
def get_tma_aligned_size(x: int, element_size: int):
return _align(x, 16 // element_size)
DEFAULT_BLOCK_SIZE = [128, 128] DEFAULT_BLOCK_SIZE = [128, 128]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment