Unverified Commit 6bae64f6 authored by Gongen-Ali's avatar Gongen-Ali Committed by GitHub
Browse files

[Enhancement] Add support for k_pack in gemm_mfma (#1344)

* add support for k_pack

* support benchmark on ROCm

* fix format
parent 4f844000
import argparse
import itertools
import torch
import logging
import tilelang
import tilelang.language as T
......@@ -99,6 +100,7 @@ def get_configs(args, kwargs):
block_K=[64, 128],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
k_pack=[1, 2],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
......@@ -125,6 +127,7 @@ def matmul(
block_K=None,
num_stages=None,
thread_num=None,
k_pack=None,
policy=None,
enable_rasteration=None,
):
......@@ -156,7 +159,7 @@ def matmul(
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float8_e4m3"
dtype = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3"
accum_dtype = "float"
@T.prim_func
......@@ -210,6 +213,7 @@ def matmul(
C_local,
transpose_B=True,
policy=policy,
k_pack=k_pack,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
......
......@@ -127,3 +127,41 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
return res;
}
__device__ fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0,
fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6,
fp8_e4_t y7) {
signed char x0_char = *reinterpret_cast<signed char *>(&x0);
signed char x1_char = *reinterpret_cast<signed char *>(&x1);
signed char x2_char = *reinterpret_cast<signed char *>(&x2);
signed char x3_char = *reinterpret_cast<signed char *>(&x3);
signed char x4_char = *reinterpret_cast<signed char *>(&x4);
signed char x5_char = *reinterpret_cast<signed char *>(&x5);
signed char x6_char = *reinterpret_cast<signed char *>(&x6);
signed char x7_char = *reinterpret_cast<signed char *>(&x7);
signed char y0_char = *reinterpret_cast<signed char *>(&y0);
signed char y1_char = *reinterpret_cast<signed char *>(&y1);
signed char y2_char = *reinterpret_cast<signed char *>(&y2);
signed char y3_char = *reinterpret_cast<signed char *>(&y3);
signed char y4_char = *reinterpret_cast<signed char *>(&y4);
signed char y5_char = *reinterpret_cast<signed char *>(&y5);
signed char y6_char = *reinterpret_cast<signed char *>(&y6);
signed char y7_char = *reinterpret_cast<signed char *>(&y7);
int a = (x3_char << 24) | (x2_char << 16) | (x1_char << 8) | x0_char;
int b = (x7_char << 24) | (x6_char << 16) | (x5_char << 8) | x4_char;
int c = (y3_char << 24) | (y2_char << 16) | (y1_char << 8) | y0_char;
int d = (y7_char << 24) | (y6_char << 16) | (y5_char << 8) | y4_char;
fp8_e4_8_t res_x;
res_x.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
res_x.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
fp8_e4_8_t res_y;
res_y.x = *reinterpret_cast<fp8_e4_4_t *>(&c);
res_y.y = *reinterpret_cast<fp8_e4_4_t *>(&d);
fp8_e4_16_t res;
res.x = res_x;
res.y = res_y;
return res;
}
\ No newline at end of file
......@@ -372,8 +372,8 @@ class MatrixCoreIntrinEmitter:
a_is_fragment = is_fragment(A_local_buf)
b_is_fragment = is_fragment(B_local_buf)
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * local_size_b if b_is_fragment else 0
a_local_stride: PrimExpr = k_inner * warp_rows * k_pack * local_size_a if a_is_fragment else 0
b_local_stride: PrimExpr = k_inner * warp_cols * k_pack * local_size_b if b_is_fragment else 0
@T.macro
def _warp_mfma(A_local_buf, B_local_buf, C_local_buf):
......@@ -543,7 +543,8 @@ class MatrixCoreIntrinEmitter:
return local_id
base_fragment = T.Fragment(
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
[micro_size_s, micro_size_r *
self.k_pack] if is_sr_axis_order else [micro_size_r * self.k_pack, micro_size_s],
forward_thread_fn=forward_thread,
forward_index_fn=forward_index,
)
......@@ -552,7 +553,7 @@ class MatrixCoreIntrinEmitter:
chunk = self.chunk
warp_s = warp_rows if matrix_is_a else warp_cols
warp_r = chunk // micro_size_r
warp_r = chunk // (micro_size_r * self.k_pack)
block_s = block_row_warps if matrix_is_a else block_col_warps
replicate = block_col_warps if matrix_is_a else block_row_warps
......
......@@ -28,6 +28,7 @@ class GemmMFMA(GemmBase):
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
k_pack=self.k_pack,
)
if self.is_gemm_ss():
......@@ -75,6 +76,7 @@ class GemmMFMA(GemmBase):
warp_col_tiles=warp_col_tiles,
chunk=self.chunk,
thread_var=thread_var,
k_pack=self.k_pack,
)
in_dtype = self.in_dtype
......@@ -110,11 +112,11 @@ class GemmMFMA(GemmBase):
B_shared into local fragments, then issues Matrix Core mfma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
......@@ -145,12 +147,12 @@ class GemmMFMA(GemmBase):
B_shared into local fragments, then issues Matrix Core mfma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Load A into fragment
mfma_emitter.ldmatrix_a(
......@@ -177,10 +179,10 @@ class GemmMFMA(GemmBase):
B_shared into local fragments, then issues Matrix Core mfma ops,
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Load B into fragment
mfma_emitter.ldmatrix_b(
......@@ -207,7 +209,7 @@ class GemmMFMA(GemmBase):
accumulating into C_local.
"""
for ki in T.serial(0, (block_K // micro_size_k)):
for ki in T.serial(0, (block_K // (micro_size_k * self.k_pack))):
# Perform Matrix Multiplication
mfma_emitter.mfma(A_buf, B_buf, C_buf, ki)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment