Commit 8f4628e0 authored by qisan

[Bugfix] Use a new data layout; NN GEMM performance now exceeds rocBLAS

parent 9a640856
......@@ -10,6 +10,7 @@ from tilelang import disable_cache
disable_cache()
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
......@@ -186,5 +187,3 @@ def main():
if __name__ == "__main__":
main()
......@@ -169,7 +169,7 @@ Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
auto warp_layout =
base_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
auto block_layout =
warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, true);
warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
return block_layout;
}
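For context on what makeGemmFragmentCDCU builds, here is a minimal sketch of how a 16x16 MFMA accumulator tile is replicated across warps and across the block. The lane-to-coordinate mapping assumes the common CDNA convention for mfma_f32_16x16x16 (64-lane wavefront, 4 floats per lane); the function and parameter names are illustrative, not tilelang's Repeat API.

```cpp
#include <cstdio>

struct Coord { int row, col; };

// Assumed CDNA mapping for mfma_f32_16x16x16: lane l, accumulator element i
// owns row (l/16)*4 + i and column l%16 of the 16x16 base tile.
Coord c_fragment_coord(int lane, int elem) {
  return { (lane / 16) * 4 + elem, lane % 16 };
}

// The base tile repeats warp_m/16 x warp_n/16 times inside a warp, and the
// warp tile repeats block_m/warp_m x block_n/warp_n times across warps; the
// boolean flags passed to Repeat select the traversal order of those repeats.
Coord c_block_coord(int warp_id, int lane, int elem, int rep_m, int rep_n,
                    int warp_m, int warp_n, int warps_n) {
  Coord base = c_fragment_coord(lane, elem);
  int wm = warp_id / warps_n, wn = warp_id % warps_n;  // row-major warp grid
  return { wm * warp_m + rep_m * 16 + base.row,
           wn * warp_n + rep_n * 16 + base.col };
}

int main() {
  Coord c = c_block_coord(/*warp_id=*/1, /*lane=*/17, /*elem=*/2,
                          /*rep_m=*/0, /*rep_n=*/0,
                          /*warp_m=*/32, /*warp_n=*/32, /*warps_n=*/2);
  printf("row=%d col=%d\n", c.row, c.col);  // row=6 col=33
}
```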
......@@ -747,7 +747,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
if (!k_inner && element_size == 8) // int8 KxN
return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 8) == 0)
// return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
// return makeHalfBankSwizzleLayout(mat_stride, mat_continuous,
// element_size);
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
......
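The branch above now prefers the full-bank swizzle whenever the contiguous extent covers vector_size * 8 vectors. As a hedged illustration of the underlying idea (not the exact index function makeFullBankSwizzleLayout emits): XOR the row into the vector index so that a column walk touches distinct LDS banks instead of hammering one.

```cpp
#include <cstdio>

// Generic XOR swizzle over vectorized shared-memory rows: a sketch of the
// bank-conflict-avoidance idea, not tilelang's actual layout function.
int swizzled_vector(int row, int vec_col, int vecs_per_row) {
  return vec_col ^ (row % vecs_per_row);  // permutes vectors within each row
}

int main() {
  const int vecs_per_row = 8;  // e.g. a 64-wide fp16 row in 8-element vectors
  for (int row = 0; row < 4; ++row) {
    for (int v = 0; v < vecs_per_row; ++v)
      printf("%d ", swizzled_vector(row, v, vecs_per_row));
    printf("\n");  // each column of this table cycles through distinct slots
  }
}
```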
......@@ -4,8 +4,8 @@
*/
#include "gemm.h"
#include <fstream>
#include "builtin.h"
#include <fstream>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
......@@ -828,8 +828,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
ICHECK(C.scope() == "local.fragment")
<< "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
<< C.scope();
if (TargetIsDCU(T.target))
{
if (TargetIsDCU(T.target)) {
auto fragment =
makeGemmFragmentCDCU(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
......
......@@ -249,7 +249,6 @@ TVM_REGISTER_OP("tir.hip.__shfl")
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_REGISTER_OP("tir.hip.__shfl_sync")
.set_num_inputs(4)
.add_argument("mask", "Expr", "The thread mask.")
......
......@@ -5,7 +5,6 @@
#include "utils.h"
namespace tvm {
namespace tl {
......
#pragma once
#include "core.hpp"
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
......@@ -106,22 +105,19 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
return (v1 << 16) | v0;
}
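__pack_bfloat162 above is plain bit packing; a self-contained sketch of the same operation on raw 16-bit patterns:

```cpp
#include <cstdint>
#include <cstdio>

// Two 16-bit payloads packed little-end-first into one 32-bit word, matching
// the (v1 << 16) | v0 expression above.
uint32_t pack2x16(uint16_t v0, uint16_t v1) {
  return (uint32_t(v1) << 16) | v0;
}

int main() {
  printf("0x%08x\n", pack2x16(0x3f80, 0x4000));  // prints 0x40003f80
}
```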
template <typename T>
struct is_half_type : std::false_type {};
template <typename T> struct is_half_type : std::false_type {};
template <>
struct is_half_type<__half> : std::true_type {};
template <> struct is_half_type<__half> : std::true_type {};
template <>
struct is_half_type<half_t> : std::true_type {};
template <> struct is_half_type<half_t> : std::true_type {};
template <typename T>
inline constexpr bool is_half_v = is_half_type<std::decay_t<T>>::value;
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1* address, T2 val) {
TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
if constexpr (is_half_v<T1>) {
__half* addr = reinterpret_cast<__half*>(address);
__half *addr = reinterpret_cast<__half *>(address);
__half hval = __float2half(static_cast<float>(val));
atomicAdd(addr, hval);
} else {
......@@ -129,16 +125,14 @@ TL_DEVICE void AtomicAdd(T1* address, T2 val) {
}
}
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 &ref, T2 val) {
template <typename T1, typename T2> TL_DEVICE void AtomicAdd(T1 &ref, T2 val) {
AtomicAdd(&ref, val);
}
template <typename T1, typename T2> TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) {
return atomicAdd(&ref, static_cast<T1>(val));
}
template <typename T>
TL_DEVICE void AtomicAddx4(T* ref, const T val[4]) {
template <typename T> TL_DEVICE void AtomicAddx4(T *ref, const T val[4]) {
atomicAdd(&ref[0], val[0]);
atomicAdd(&ref[1], val[1]);
atomicAdd(&ref[2], val[2]);
......
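On the half path of AtomicAdd above, val is funneled through float into __half before the hardware atomic is issued. A hypothetical standalone equivalent, assuming HIP's __half atomicAdd overload is available on the target arch:

```cpp
#include <hip/hip_fp16.h>

// Sketch only: convert an arbitrary scalar to __half, then issue the 16-bit
// hardware atomic, as the wrapper above does after reinterpreting half_t.
__device__ void atomic_add_half(__half *addr, float val) {
  atomicAdd(addr, __float2half(val));
}
```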
......@@ -108,4 +108,3 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
}
} // namespace tl
......@@ -25,82 +25,53 @@
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
#endif
namespace ck_tile{
namespace ck_tile {
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x)
{
return x;
}
template <typename T> CK_TILE_HOST_DEVICE constexpr T max(T x) { return x; }
template <typename T>
CK_TILE_HOST constexpr T max(T x, T y)
{
template <typename T> CK_TILE_HOST constexpr T max(T x, T y) {
return x > y ? x : y;
}
template <typename T>
CK_TILE_DEVICE constexpr T max(T x, T y)
{
template <typename T> CK_TILE_DEVICE constexpr T max(T x, T y) {
return x > y ? x : y;
}
template <>
CK_TILE_DEVICE float max(float x, float y)
{
template <> CK_TILE_DEVICE float max(float x, float y) {
return __builtin_fmaxf(x, y); // can result in v_max3_f32
}
template <>
CK_TILE_DEVICE double max(double x, double y)
{
template <> CK_TILE_DEVICE double max(double x, double y) {
return __builtin_fmax(x, y); // maybe still v_max3_f32
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys)
{
CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys) {
static_assert(sizeof...(Ys) > 0, "not enough arguments");
return max(x, max(ys...));
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x)
{
return x;
}
template <typename T> CK_TILE_HOST_DEVICE constexpr T min(T x) { return x; }
template <typename T>
CK_TILE_HOST constexpr T min(T x, T y)
{
template <typename T> CK_TILE_HOST constexpr T min(T x, T y) {
return x < y ? x : y;
}
template <typename T>
CK_TILE_DEVICE constexpr T min(T x, T y)
{
template <typename T> CK_TILE_DEVICE constexpr T min(T x, T y) {
return x < y ? x : y;
}
template <>
CK_TILE_DEVICE float min(float x, float y)
{
template <> CK_TILE_DEVICE float min(float x, float y) {
return __builtin_fminf(x, y);
}
template <>
CK_TILE_DEVICE double min(double x, double y)
{
template <> CK_TILE_DEVICE double min(double x, double y) {
return __builtin_fmin(x, y);
}
template <typename X, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys)
{
CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys) {
static_assert(sizeof...(Ys) > 0, "not enough arguments");
return min(x, min(ys...));
}
}
} // namespace ck_tile
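The variadic max/min overloads fold the tail first, so max(a, b, c) becomes max(a, max(b, c)); per the comments above, the two binary __builtin_fmaxf calls can then fuse into a single v_max3_f32. A host-side sketch of the expansion:

```cpp
#include <cstdio>

template <typename T> constexpr T max3(T x) { return x; }
template <typename T, typename... Ts> constexpr T max3(T x, Ts... ys) {
  T m = max3(ys...);    // fold the tail first
  return x > m ? x : m; // then compare against the head
}

int main() { printf("%d\n", max3(3, 7, 5)); }  // prints 7
```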
......@@ -189,4 +189,3 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
(int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, safe_buf_name,
index, var);
}
......@@ -69,7 +69,7 @@ template <int M, int N, int K, int num_warp_n, int num_warp_m, bool TransposeA,
typename B_type, typename C_type, typename AccDataType = float>
class GemmTensorOp {
public:
//static_assert(!clear_accum, "clear_accum=true is not supported yet");
// static_assert(!clear_accum, "clear_accum=true is not supported yet");
static constexpr int micro_size_x = 16;
static constexpr int micro_size_y = 16;
......@@ -156,8 +156,8 @@ public:
C_type *C_local) {
auto tid = threadIdx.x;
auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps;
auto warp_m = warp_id % block_row_warps;
auto warp_m = warp_id / block_col_warps;
auto warp_n = warp_id % block_col_warps;
auto warp_row_tiles = warp_rows * micro_size_x;
auto warp_col_tiles = warp_cols * micro_size_y;
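This looks like the core indexing fix: assuming warps are laid out row-major, warp_id = warp_m * block_col_warps + warp_n, so recovering the pair requires dividing by block_col_warps. The old code divided by block_row_warps, which transposes the mapping whenever the two warp counts differ. A quick check with an illustrative 2x4 warp grid:

```cpp
#include <cstdio>

int main() {
  const int block_row_warps = 2, block_col_warps = 4;  // illustrative shape
  for (int warp_id = 0; warp_id < block_row_warps * block_col_warps; ++warp_id) {
    int warp_m = warp_id / block_col_warps;  // fixed: row index
    int warp_n = warp_id % block_col_warps;  // fixed: column index
    printf("warp %d -> (m=%d, n=%d)\n", warp_id, warp_m, warp_n);
  }
}
```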
......@@ -165,8 +165,8 @@ public:
auto tx = lane_id;
auto alane_id = lane_id;
auto blane_id = ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);
auto blane_id =
((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);
constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
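The blane_id remap is the new B-operand lane permutation this commit introduces. A sketch verifying its structure: within each 16-lane group it swaps the two 2-bit halves of the low 4 bits (lane 16h + 4a + b maps to 16h + 4b + a, a 4x4 transpose of lanes), and it is its own inverse, which is what allows B to be stored pre-permuted:

```cpp
#include <cassert>
#include <cstdio>

int remap(int lane) {
  return ((lane & 15) >> 2) + ((lane & 3) << 2) + ((lane >> 4) << 4);
}

int main() {
  for (int lane = 0; lane < 64; ++lane)  // one 64-lane wavefront
    assert(remap(remap(lane)) == lane);  // the permutation is an involution
  printf("lane 6 -> %d, lane 33 -> %d\n", remap(6), remap(33));  // 9, 36
}
```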
......@@ -194,7 +194,6 @@ public:
B_local[i * kPack * local_size_b + local_id] =
B_shared[make_swizzle_layout<last_dim_b, sizeof(B_type)>(
r + row, l + col)];
}
}
}
......@@ -237,8 +236,8 @@ public:
C_type *C_local) {
auto tid = threadIdx.x;
auto warp_id = tid / warp_size;
auto warp_n = warp_id / block_row_warps;
auto warp_m = warp_id % block_row_warps;
auto warp_m = warp_id / block_col_warps;
auto warp_n = warp_id % block_col_warps;
auto warp_row_tiles = warp_rows * micro_size_x;
auto warp_col_tiles = warp_cols * micro_size_y;
......@@ -246,7 +245,8 @@ public:
auto tx = lane_id;
auto alane_id = lane_id;
auto blane_id = ((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);
auto blane_id =
((lane_id & 15) >> 2) + ((lane_id & 3) << 2) + ((lane_id >> 4) << 4);
constexpr auto local_size_a = (micro_size_x * micro_size_k) / warp_size;
constexpr auto local_size_b = (micro_size_y * micro_size_k) / warp_size;
......@@ -321,4 +321,3 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
}
} // namespace tl
......@@ -72,4 +72,3 @@ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
return res;
}
......@@ -22,14 +22,11 @@ struct MinOp {
}
};
// Detect half types
template <typename T>
struct is_half_type : std::false_type {};
template <typename T> struct is_half_type : std::false_type {};
template <>
struct is_half_type<__half> : std::true_type {};
template <> struct is_half_type<__half> : std::true_type {};
template <>
struct is_half_type<_Float16> : std::true_type {};
template <> struct is_half_type<_Float16> : std::true_type {};
template <typename T>
inline constexpr bool is_half_v = is_half_type<std::decay_t<T>>::value;
......@@ -56,7 +53,10 @@ struct AllReduce {
if constexpr (std::is_same_v<std::decay_t<T>, __half>) {
x_raw = __half_as_ushort(x);
} else { // _Float16
union { _Float16 f; unsigned short s; } u;
union {
_Float16 f;
unsigned short s;
} u;
u.f = x;
x_raw = u.s;
}
......@@ -67,7 +67,10 @@ struct AllReduce {
if constexpr (std::is_same_v<std::decay_t<T>, __half>) {
shuffled_x = __ushort_as_half(shuffled_raw);
} else { // _Float16
union { unsigned short s; _Float16 f; } u;
union {
unsigned short s;
_Float16 f;
} u;
u.s = shuffled_raw;
shuffled_x = u.f;
}
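The unions above exist because _Float16, unlike __half, has no bit-cast intrinsic such as __half_as_ushort. A hedged sketch of the same round trip using memcpy punning (the strictly portable form in C++), assuming a compiler that exposes _Float16 as the device code above does; these helpers are hypothetical, not part of the source:

```cpp
#include <cstring>

// Move the 16 raw bits of a _Float16 into an unsigned short for the shuffle,
// and back afterwards, equivalent to the unions above.
inline unsigned short f16_bits(_Float16 f) {
  unsigned short s;
  std::memcpy(&s, &f, sizeof(s));
  return s;
}

inline _Float16 bits_f16(unsigned short s) {
  _Float16 f;
  std::memcpy(&f, &s, sizeof(f));
  return f;
}
```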
......@@ -116,7 +119,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
T val = (col < W) ? src[real_row * W + real_col] : (T)0;
#pragma unroll
#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off);
if (lane < SEG - off)
......@@ -142,7 +145,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
T val = (col < W) ? src[real_row * W + real_col] : (T)0;
#pragma unroll
#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_up_sync(MASK, val, off);
if (lane >= off)
......@@ -164,4 +167,3 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
}
};
} // namespace tl
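For reference on the shuffle loops in CumSum2D above: they are Hillis-Steele segmented scans, log2(SEG) shuffle rounds in which each lane accumulates the value `off` lanes away. A generic device-side sketch of the forward (shfl_up) form, assuming the same MASK/SEG/lane conventions as the surrounding code:

```cpp
// Illustrative only; mirrors the shuffle loop in CumSum2D above.
template <int SEG, typename T>
__device__ T inclusive_segment_scan(T val, unsigned long long MASK, int lane) {
#pragma unroll
  for (int off = 1; off < SEG; off <<= 1) {
    T n = (T)__shfl_up_sync(MASK, val, off);
    if (lane >= off)  // lanes without a predecessor at this stride keep val
      val += n;
  }
  return val;
}
```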
......@@ -43,4 +43,3 @@ template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
}
} // namespace tl
......@@ -12,6 +12,7 @@ from tilelang.transform import simplify_prim_func
tilelang.testing.set_random_seed(0)
tilelang.disable_cache()
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
......@@ -63,7 +64,7 @@ def tl_matmul(
chunk = 32 * k_pack
shared_scope = "shared"
cache_write_shared = False
# cache_write_shared = False
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
......
from __future__ import annotations
from tilelang import tvm as tvm
import tilelang.language as T
from typing import Tuple
from tvm import DataType
from tvm.tir import PrimExpr
from tvm.runtime import convert
from typing import Optional
from .utils import (
mfma_store_index_map,)
lift = convert
class MatrixCoreIntrinEmitter(object):
class MatrixCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
......@@ -51,9 +50,9 @@ class MatrixCoreIntrinEmitter(object):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
k_pack: Optional[int] = None,
is_m_first: Optional[bool] = False,
b_preshuffle: Optional[bool] = False,
k_pack: int | None = None,
is_m_first: bool | None = False,
b_preshuffle: bool | None = False,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
......@@ -119,7 +118,7 @@ class MatrixCoreIntrinEmitter(object):
"float16": "f16",
"float32": "f32",
"int8": "i8",
"bfloat16" : "bf16"
"bfloat16": "bf16"
}[in_dtype]
self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
......@@ -129,15 +128,15 @@ class MatrixCoreIntrinEmitter(object):
self.micro_size_y = n_dim
self.micro_size_k = k_dim
def _initialize_k_pack(self, k_pack: Optional[int] = None):
def _initialize_k_pack(self, k_pack: int | None = None):
if k_pack is not None:
self.k_pack = k_pack
def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False):
def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False):
if b_preshuffle is not None:
self.b_preshuffle = b_preshuffle
......@@ -197,7 +196,7 @@ class MatrixCoreIntrinEmitter(object):
def extract_thread_binding(self,
thread_id,
is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
'''
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
......@@ -290,7 +289,9 @@ class MatrixCoreIntrinEmitter(object):
if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
row, col = T.meta_var(
reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
local_id))
l, r = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k),
......@@ -301,7 +302,9 @@ class MatrixCoreIntrinEmitter(object):
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
row, col = T.meta_var(
reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
local_id))
l, r = (
rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,
......@@ -412,10 +415,10 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
k_pack: Optional[int] = None,
is_m_first: Optional[bool] = False,
a_preshuffle: Optional[bool] = False,
b_preshuffle: Optional[bool] = False,
k_pack: int | None = None,
is_m_first: bool | None = False,
a_preshuffle: bool | None = False,
b_preshuffle: bool | None = False,
):
self.a_dtype = a_dtype
......@@ -579,7 +582,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
row, col = T.meta_var(
reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4),
local_id))
l, r = (
warp_n * warp_cols + j,
rk * (chunk // micro_size_k) + ki,
......@@ -589,7 +594,9 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
row, col = T.meta_var(
reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4),
local_id))
l, r = (
rk * (chunk // micro_size_k) + ki,
warp_n * warp_cols + j,
......@@ -600,4 +607,3 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding,
rk) if is_global else _warp_ldmatrix_b_shared(
B_local_buf, B_buf, ki, thread_binding, rk)