Commit 526de822, authored by rasmith, committed by GitHub

[Kernel][Triton][AMD] Use block size heuristic for avg 2.8x speedup for int8 models (#11698)


Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 56fe4c29
@@ -128,7 +128,8 @@ def triton_scaled_mm(input: torch.Tensor,
                      bias: Optional[torch.Tensor] = None,
                      block_size_m: int = 32,
                      block_size_n: int = 32,
-                     block_size_k: int = 32) -> torch.Tensor:
+                     block_size_k: int = 32,
+                     use_heuristic=True) -> torch.Tensor:
     M, K = input.shape
     N = weight.shape[1]
@@ -152,6 +153,20 @@ def triton_scaled_mm(input: torch.Tensor,
     has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
+    if use_heuristic:
+        is_small_N = N < 8192
+        next_power_of_2_M = max(32, triton.next_power_of_2(M))
+        if next_power_of_2_M <= 32:
+            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+        elif next_power_of_2_M <= 64:
+            tile_shape = (64, 64, 256)
+        elif next_power_of_2_M <= 128:
+            tile_shape = (64, 128, 128)
+        else:
+            tile_shape = (128, 128, 128)
+        block_size_m, block_size_n, block_size_k = tile_shape
     block_size_sa = 1 if has_scalar(scale_a) else block_size_m
     block_size_sb = 1 if has_scalar(scale_b) else block_size_n
......
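To make the tile selection easier to inspect without a GPU, the heuristic from the hunk above can be sketched as a standalone function. This is only an illustration: `select_tile_shape` is a hypothetical name that does not appear in the commit, and `next_power_of_2` here is a pure-Python stand-in for `triton.next_power_of_2`, assumed to behave the same for M >= 1.

```python
from typing import Tuple


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1.

    Pure-Python stand-in for triton.next_power_of_2 (an assumption made
    so this sketch runs without Triton installed).
    """
    return 1 << (n - 1).bit_length()


def select_tile_shape(M: int, N: int) -> Tuple[int, int, int]:
    """Return (block_size_m, block_size_n, block_size_k) for an M x N output.

    Hypothetical helper mirroring the branch structure of the diff; the
    thresholds and tile shapes are copied from the commit unchanged.
    """
    is_small_N = N < 8192
    # M is rounded up to a power of two and clamped to at least 32.
    next_power_of_2_M = max(32, next_power_of_2(M))
    if next_power_of_2_M <= 32:
        return (64, 64, 256) if is_small_N else (64, 128, 256)
    elif next_power_of_2_M <= 64:
        return (64, 64, 256)
    elif next_power_of_2_M <= 128:
        return (64, 128, 128)
    return (128, 128, 128)


if __name__ == "__main__":
    for M, N in [(1, 4096), (16, 16384), (48, 4096), (100, 4096), (512, 4096)]:
        print(f"M={M:4d} N={N:6d} -> {select_tile_shape(M, N)}")
```

Note that N only influences the result in the smallest bucket (M rounded up to 32); for larger M the tile shape depends on M alone. Passing `use_heuristic=False` to `triton_scaled_mm` skips this selection and keeps the caller-supplied `block_size_m`, `block_size_n`, and `block_size_k`.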