Unverified Commit 6bae64f6 authored by Gongen-Ali, committed by GitHub

[Enhancement] Add support for k_pack in gemm_mfma (#1344)

* add support for k_pack

* support benchmark on ROCm

* fix format
parent 4f844000
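
The k_pack knob controls how many MFMA k-slices are fused into a single fragment load: with k_pack=2, each ldmatrix fetches twice as much data along the reduction axis and each outer k iteration feeds two mfma micro-steps, improving shared-memory vectorization on AMD Matrix Cores. Below is a minimal sketch of how the parameter flows into T.gemm; the shapes, tile sizes, and float16 dtype are illustrative rather than the benchmark's tuned values, and it assumes the tilelang API used in this diff:

import tilelang
import tilelang.language as T

def matmul_kernel(M=1024, N=1024, K=1024, block_M=128, block_N=128,
                  block_K=64, k_pack=2, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                # k_pack > 1 packs several MFMA k-steps into one fragment load.
                T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main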
 import argparse
 import itertools
+import torch
 import logging
 import tilelang
 import tilelang.language as T
@@ -99,6 +100,7 @@ def get_configs(args, kwargs):
         block_K=[64, 128],
         num_stages=[0, 1, 2, 3],
         thread_num=[128, 256],
+        k_pack=[1, 2],
         policy=[T.GemmWarpPolicy.Square],
         enable_rasteration=[True, False],
     )
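
With the new axis in place, the tuning space enumerated by get_configs doubles. A toy check of the combinatorics, counting only the knobs visible in this hunk (not the script's full space):

import itertools

space = dict(block_K=[64, 128], num_stages=[0, 1, 2, 3],
             thread_num=[128, 256], k_pack=[1, 2])
# 2 * 4 * 2 * 2 = 32 combinations; without k_pack there would be 16.
assert len(list(itertools.product(*space.values()))) == 32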
@@ -125,6 +127,7 @@ def matmul(
     block_K=None,
     num_stages=None,
     thread_num=None,
+    k_pack=None,
     policy=None,
     enable_rasteration=None,
 ):
@@ -156,7 +159,7 @@ def matmul(
     # Use low-precision (FP8) input data to reduce memory bandwidth,
     # accumulate in float for better numerical accuracy
-    dtype = "float8_e4m3"
+    dtype = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3"
     accum_dtype = "float"

     @T.prim_func
@@ -210,6 +213,7 @@ def matmul(
                 C_local,
                 transpose_B=True,
                 policy=policy,
+                k_pack=k_pack,
             )
             # Write back the results from C_local to the global memory C
             T.copy(C_local, C_shared)
...
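
The dtype switch above is needed because the two backends expose different fp8 flavors: ROCm (CDNA3-class hardware) uses e4m3fnuz, which shifts the exponent bias and has no negative zero or infinities, while NVIDIA hardware uses e4m3fn. A hedged sketch of picking the matching torch dtype for reference inputs; the tensor shape and the .to() conversion path here are illustrative:

import torch

# torch.version.hip is None on CUDA builds of PyTorch.
if torch.version.hip is not None:
    fp8_dtype = torch.float8_e4m3fnuz  # ROCm Matrix Core fp8 flavor
else:
    fp8_dtype = torch.float8_e4m3fn    # NVIDIA fp8 flavor

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16).to(fp8_dtype)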
@@ -127,3 +127,41 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
   res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
   return res;
 }
+
+__device__ fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                         fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
+                                         fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0,
+                                         fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
+                                         fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6,
+                                         fp8_e4_t y7) {
+  // Reinterpret each fp8 value as a raw byte. Unsigned bytes keep the
+  // OR-packing below from sign-extending a byte whose high bit is set.
+  unsigned char x0_char = *reinterpret_cast<unsigned char *>(&x0);
+  unsigned char x1_char = *reinterpret_cast<unsigned char *>(&x1);
+  unsigned char x2_char = *reinterpret_cast<unsigned char *>(&x2);
+  unsigned char x3_char = *reinterpret_cast<unsigned char *>(&x3);
+  unsigned char x4_char = *reinterpret_cast<unsigned char *>(&x4);
+  unsigned char x5_char = *reinterpret_cast<unsigned char *>(&x5);
+  unsigned char x6_char = *reinterpret_cast<unsigned char *>(&x6);
+  unsigned char x7_char = *reinterpret_cast<unsigned char *>(&x7);
+  unsigned char y0_char = *reinterpret_cast<unsigned char *>(&y0);
+  unsigned char y1_char = *reinterpret_cast<unsigned char *>(&y1);
+  unsigned char y2_char = *reinterpret_cast<unsigned char *>(&y2);
+  unsigned char y3_char = *reinterpret_cast<unsigned char *>(&y3);
+  unsigned char y4_char = *reinterpret_cast<unsigned char *>(&y4);
+  unsigned char y5_char = *reinterpret_cast<unsigned char *>(&y5);
+  unsigned char y6_char = *reinterpret_cast<unsigned char *>(&y6);
+  unsigned char y7_char = *reinterpret_cast<unsigned char *>(&y7);
+  // Pack four bytes per 32-bit word, little-endian: x0 in the low byte.
+  unsigned int a =
+      (static_cast<unsigned int>(x3_char) << 24) | (x2_char << 16) | (x1_char << 8) | x0_char;
+  unsigned int b =
+      (static_cast<unsigned int>(x7_char) << 24) | (x6_char << 16) | (x5_char << 8) | x4_char;
+  unsigned int c =
+      (static_cast<unsigned int>(y3_char) << 24) | (y2_char << 16) | (y1_char << 8) | y0_char;
+  unsigned int d =
+      (static_cast<unsigned int>(y7_char) << 24) | (y6_char << 16) | (y5_char << 8) | y4_char;
+  // Assemble the sixteen bytes as two fp8_e4_8_t halves.
+  fp8_e4_8_t res_x;
+  res_x.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
+  res_x.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
+  fp8_e4_8_t res_y;
+  res_y.x = *reinterpret_cast<fp8_e4_4_t *>(&c);
+  res_y.y = *reinterpret_cast<fp8_e4_4_t *>(&d);
+  fp8_e4_16_t res;
+  res.x = res_x;
+  res.y = res_y;
+  return res;
+}
\ No newline at end of file
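
make_fp8_e4_16_t simply concatenates sixteen fp8 bytes into four little-endian 32-bit words, giving the wide register footprint that packed fragments need. A small host-side illustration of the byte layout (pack4 is a hypothetical helper, not part of the codebase); the & 0xFF masks play the same role as the unsigned bytes in the device code, keeping one byte's high bit from bleeding into its neighbors:

import struct

def pack4(b0, b1, b2, b3):
    # Mirrors (b3 << 24) | (b2 << 16) | (b1 << 8) | b0, with each byte masked
    # so a value with its high bit set cannot sign-extend over the others.
    return ((b3 & 0xFF) << 24) | ((b2 & 0xFF) << 16) | ((b1 & 0xFF) << 8) | (b0 & 0xFF)

# b0 lands in the lowest-addressed byte, matching little-endian memory order.
assert pack4(0x01, 0x02, 0x03, 0x84) == 0x84030201
assert struct.unpack("<I", bytes([0x01, 0x02, 0x03, 0x84]))[0] == 0x84030201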
@@ -372,8 +372,8 @@ class MatrixCoreIntrinEmitter:
         a_is_fragment = is_fragment(A_local_buf)
         b_is_fragment = is_fragment(B_local_buf)

-        a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
-        b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
+        a_local_stride: PrimExpr = k_inner * warp_rows * k_pack * local_size_a if a_is_fragment else 0
+        b_local_stride: PrimExpr = k_inner * warp_cols * k_pack * local_size_b if b_is_fragment else 0

         @T.macro
         def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
@@ -543,7 +543,8 @@ class MatrixCoreIntrinEmitter:
             return local_id

         base_fragment = T.Fragment(
-            [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
+            [micro_size_s, micro_size_r *
+             self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s],
             forward_thread_fn=forward_thread,
             forward_index_fn=forward_index,
         )
@@ -552,7 +553,7 @@ class MatrixCoreIntrinEmitter:
         chunk = self.chunk

         warp_s = warp_rows if matrix_is_a else warp_cols
-        warp_r = chunk // micro_size_r
+        warp_r = chunk // (micro_size_r * self.k_pack)
         block_s = block_row_warps if matrix_is_a else block_col_warps
         replicate = block_col_warps if matrix_is_a else block_row_warps
...
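
Every change in MatrixCoreIntrinEmitter follows from one fact: with packing, each fragment along the reduction axis holds k_pack micro-tiles, so fragments are k_pack times larger and a chunk contains k_pack times fewer of them. A toy check of that arithmetic (the names mirror the diff; the concrete values are made up):

warp_rows, local_size_a, k_pack = 4, 8, 2
chunk, micro_size_r = 64, 16

# One k_inner step now advances past k_pack packed fragments per warp row...
a_local_stride = lambda k_inner: k_inner * warp_rows * k_pack * local_size_a
assert a_local_stride(1) == 64

# ...and the reduction dimension of a chunk holds k_pack times fewer fragments.
warp_r = chunk // (micro_size_r * k_pack)
assert warp_r == 2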
@@ -28,6 +28,7 @@ class GemmMFMA(GemmBase):
             warp_row_tiles=warp_row_tiles,
             warp_col_tiles=warp_col_tiles,
             chunk=self.chunk,
+            k_pack=self.k_pack,
         )

         if self.is_gemm_ss():
@@ -75,6 +76,7 @@ class GemmMFMA(GemmBase):
             warp_col_tiles=warp_col_tiles,
             chunk=self.chunk,
             thread_var=thread_var,
+            k_pack=self.k_pack,
         )

         in_dtype = self.in_dtype
@@ -110,11 +112,11 @@ class GemmMFMA(GemmBase):
             B_shared into local fragments, then issues Matrix Core mfma ops,
             accumulating into C_local.
             """
-            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
-            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+            A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
+            B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
             if clear_accum:
                 T.clear(C_buf)
-            for ki in T.serial(0, (block_K // micro_size_k)):
+            for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
                 # Load A into fragment
                 mfma_emitter.ldmatrix_a(
                     A_local,
@@ -145,12 +147,12 @@ class GemmMFMA(GemmBase):
             B_shared into local fragments, then issues Matrix Core mfma ops,
             accumulating into C_local.
             """
-            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+            A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
             if clear_accum:
                 T.clear(C_buf)
-            for ki in T.serial(0, (block_K // micro_size_k)):
+            for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
                 # Load A into fragment
                 mfma_emitter.ldmatrix_a(
@@ -177,10 +179,10 @@ class GemmMFMA(GemmBase):
             B_shared into local fragments, then issues Matrix Core mfma ops,
             accumulating into C_local.
             """
-            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+            B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
             if clear_accum:
                 T.clear(C_buf)
-            for ki in T.serial(0, (block_K // micro_size_k)):
+            for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
                 # Load B into fragment
                 mfma_emitter.ldmatrix_b(
@@ -207,7 +209,7 @@ class GemmMFMA(GemmBase):
             accumulating into C_local.
             """
-            for ki in T.serial(0, (block_K // micro_size_k)):
+            for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
                 # Perform Matrix Multiplication
                 mfma_emitter.mfma(A_buf, B_buf, C_buf, ki)
...
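
The GemmMFMA rewrites are the mirror image of the emitter changes: the local fragments grow by k_pack while the outer k loop shrinks by the same factor, so each ldmatrix/mfma step consumes k_pack micro-k slices and the total work over block_K is unchanged. With illustrative sizes:

block_K, micro_size_k = 128, 32
warp_rows, local_size_a = 4, 8

for k_pack in (1, 2):
    iters = block_K // (micro_size_k * k_pack)
    frag = warp_rows * local_size_a * k_pack
    # Elements of A touched per warp across the k loop stay constant.
    assert iters * frag == (block_K // micro_size_k) * warp_rows * local_size_a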