Commit ae9668a8 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Feature] Implement Swizzle 32B (#566)

* [Feature] Add Quarter Bank Swizzle Layout and Update GEMM Layout Logic

- Introduced a new `makeQuarterBankSwizzleLayout` function for layout swizzling of 32 bytes.
- Updated `makeGemmABLayout` to include an `enable_padding` parameter, allowing for conditional layout selection between padded and quarter bank swizzle layouts.
- Adjusted layout inference in GEMM operations to utilize the new quarter bank swizzle layout when appropriate.
- Enhanced bulk copy operations to recognize and handle the new layout type, improving memory access patterns.

* lint fix

* [Refactor] Update GEMM Layout Functions and Inference Logic

- Removed the `enable_padding` parameter from `makeGemmABLayout` to simplify its signature.
- Introduced `makeGemmABLayoutHopper` for enhanced layout handling specific to Hopper architecture.
- Updated layout inference in GEMM operations to utilize the new `makeGemmABLayoutHopper` function, improving clarity and maintainability in layout selection.
- Adjusted related layout functions to ensure consistent behavior across different architectures.

* Update bulk_copy.cc

* Update __init__.py
parent ae386a7b
...@@ -315,6 +315,27 @@ PrimExpr xor8x8(const PrimExpr &i, const PrimExpr j) { ...@@ -315,6 +315,27 @@ PrimExpr xor8x8(const PrimExpr &i, const PrimExpr j) {
return 2 * xor4x4(i1, j1) + xor2x2(i0, j0); return 2 * xor4x4(i1, j1) + xor2x2(i0, j0);
} }
// Layout swizzling for 32 bytes
Layout makeQuarterBankSwizzleLayout(int stride, int continuous,
int element_size) {
// Swizzle 1 bit
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
int vector_size = 128 / element_size;
ICHECK(stride % 8 == 0) << "stride=" << stride;
ICHECK(continuous % (vector_size * 2) == 0)
<< "continuous=" << continuous << ", vector_size=" << vector_size;
PrimExpr ts = FloorDiv(i, 8);
PrimExpr s = FloorMod(i, 8);
PrimExpr tc = FloorDiv(FloorDiv(j, vector_size), 2);
PrimExpr c = FloorMod(FloorDiv(j, vector_size), 2);
PrimExpr vec = FloorMod(j, vector_size);
PrimExpr c_swizzle = xor2x2(c, FloorDiv(s, 4));
PrimExpr index = vec + (c_swizzle + s * 2) * vector_size;
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
}
// Layout swizzling for 64 bytes
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size) { Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size) {
// Swizzle 2 bit // Swizzle 2 bit
Var i = InputPlaceholder(0); Var i = InputPlaceholder(0);
...@@ -333,6 +354,7 @@ Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size) { ...@@ -333,6 +354,7 @@ Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size) {
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index}); return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
} }
// Layout swizzling for 128 bytes
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) { Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size) {
// Swizzle 3 bit // Swizzle 3 bit
Var i = InputPlaceholder(0); Var i = InputPlaceholder(0);
...@@ -552,6 +574,29 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, ...@@ -552,6 +574,29 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
} }
} }
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor) {
if (element_size == 64) {
if (kfactor == 1 && continuity % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
if (kfactor == 2 && continuity % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
}
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kPack) { int kPack) {
int vector_size = 128 / element_size; int vector_size = 128 / element_size;
......
...@@ -161,6 +161,8 @@ Layout makeGemmLayoutLinear(int stride, int continuous); ...@@ -161,6 +161,8 @@ Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size); Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor); int element_size, int kfactor);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kfactor); int kfactor);
...@@ -175,6 +177,8 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, ...@@ -175,6 +177,8 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size); Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size); Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeQuarterBankSwizzleLayout(int stride, int continuous,
int element_size);
namespace attr { namespace attr {
// BlockAttr, Containing the layout for all the buffers in the block // BlockAttr, Containing the layout for all the buffers in the block
......
...@@ -196,6 +196,11 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -196,6 +196,11 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
*stride, *continuous, *stride, *continuous,
shared_tensor->dtype.bits()))) { shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else if (StructuralEqual()(
shared_layout,
makeQuarterBankSwizzleLayout(*stride, *continuous,
shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
} else if (StructuralEqual()( } else if (StructuralEqual()(
shared_layout, shared_layout,
makeHalfBankSwizzleLayout(*stride, *continuous, makeHalfBankSwizzleLayout(*stride, *continuous,
...@@ -357,7 +362,11 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T, ...@@ -357,7 +362,11 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
auto continuous = as_const_int(shared_layout->InputShape()[1]); auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr); ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout, if (StructuralEqual()(shared_layout,
makeHalfBankSwizzleLayout(*stride, *continuous, makeQuarterBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
} else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(
*stride, *continuous,
dst->dtype.bits()))) { dst->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout( } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(
...@@ -457,4 +466,4 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col) ...@@ -457,4 +466,4 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -347,9 +347,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -347,9 +347,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]); const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
const int64_t continuity = const int64_t continuity =
trans_A ? 4 * mat_continuous / warp_m : mat_continuous; trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
results.Set(A, results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, mat_continuous, A->dtype.bits(),
A->dtype.bits(), trans_A ? 1 : 2)); trans_A ? 1 : 2));
} else { } else {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A); A->dtype.bits(), trans_A);
...@@ -361,8 +361,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -361,8 +361,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]); const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
const int64_t continuity = const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n; trans_B ? mat_continuous : mat_continuous / warp_n;
results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity, results.Set(B,
B->dtype.bits(), trans_B ? 2 : 1)); makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1));
} else { } else {
ICHECK(0) << "WGMMA only support B in shared."; ICHECK(0) << "WGMMA only support B in shared.";
} }
......
...@@ -35,7 +35,7 @@ class TimeoutException(Exception): ...@@ -35,7 +35,7 @@ class TimeoutException(Exception):
def timeout_handler(signum, frame): def timeout_handler(signum, frame):
raise TimeoutException() raise TimeoutException("Operation timed out")
def run_with_timeout(func, timeout, *args, **kwargs): def run_with_timeout(func, timeout, *args, **kwargs):
...@@ -43,6 +43,8 @@ def run_with_timeout(func, timeout, *args, **kwargs): ...@@ -43,6 +43,8 @@ def run_with_timeout(func, timeout, *args, **kwargs):
signal.alarm(timeout) signal.alarm(timeout)
try: try:
result = func(*args, **kwargs) result = func(*args, **kwargs)
except Exception as e:
raise e
finally: finally:
signal.alarm(0) signal.alarm(0)
return result return result
...@@ -101,7 +103,7 @@ class AutoTuner: ...@@ -101,7 +103,7 @@ class AutoTuner:
_kernel_parameters: Optional[Tuple[str, ...]] = None _kernel_parameters: Optional[Tuple[str, ...]] = None
_lock = threading.Lock() # For thread safety _lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary _memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner" cache_dir: Path = Path(TILELANG_CACHE_DIR)
def __init__(self, fn: Callable, configs): def __init__(self, fn: Callable, configs):
self.fn = fn self.fn = fn
...@@ -350,7 +352,6 @@ class AutoTuner: ...@@ -350,7 +352,6 @@ class AutoTuner:
max_mismatched_ratio=max_mismatched_ratio) max_mismatched_ratio=max_mismatched_ratio)
latency = profiler.do_bench( latency = profiler.do_bench(
warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors) warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
if self.ref_latency_cache is None and ref_prog is not None: if self.ref_latency_cache is None and ref_prog is not None:
self.ref_input_tensors = ref_input_tensors_supply() self.ref_input_tensors = ref_input_tensors_supply()
self.ref_latency_cache = profiler.do_bench( self.ref_latency_cache = profiler.do_bench(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment