Unverified Commit 8fbe1b3a authored by Lei Wang, committed by GitHub

[Refactor] Add kernel selection option for GEMM v1 in environment settings (#1200)

* Add kernel selection option for GEMM v1 in environment settings

- Introduced the `TILELANG_USE_GEMM_V1` environment variable to control the selection of the GEMM version.
- Added a `use_gemm_v1` method in the `Environment` class to determine whether GEMM v1 should be used, based on the environment variable.
- Updated the GEMM function assignment to default to v2, allowing v1 to be forced via the new environment variable.

* bug fix

* Add kernel selection option for GEMM in environment settings

- Introduced `TILELANG_USE_GEMM_V1` environment variable to allow users to select between GEMM v1 and v2 implementations.
- Updated `gemm` function to default to v2 but switch to v1 if the environment variable is set to a truthy value.
- Added a method `use_gemm_v1` in the `Environment` class to facilitate this selection based on the environment variable.
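
A minimal Python sketch of the selection logic described above, assuming a module-level `gemm` alias; the exact names and module layout in tilelang may differ:

```python
# Hypothetical sketch of environment-driven GEMM selection (not the actual
# tilelang implementation).
import os


class Environment:
    """Reads TILELANG_USE_GEMM_V1 to decide which GEMM implementation to use."""

    def use_gemm_v1(self) -> bool:
        # Treat "1", "true", "on", "yes" (case-insensitive) as truthy.
        value = os.environ.get("TILELANG_USE_GEMM_V1", "0").strip().lower()
        return value in ("1", "true", "on", "yes")


def gemm_v1(*args, **kwargs):
    ...  # placeholder for the legacy v1 implementation


def gemm_v2(*args, **kwargs):
    ...  # placeholder for the default v2 implementation


# Default to v2; force v1 only when the environment variable is truthy.
gemm = gemm_v1 if Environment().use_gemm_v1() else gemm_v2
```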

* Refactor GEMM macro generator to use BufferRegion instead of Buffer

- Updated `wgmma` and `wgmma_rs` methods in `TensorCoreIntrinEmitter` to accept `BufferRegion` parameters instead of `Buffer`.
- Adjusted related calls in `GemmWGMMA` to ensure compatibility with the new parameter types.
- Simplified buffer access logic for better clarity and maintainability.

* Refactor GEMM functions to utilize BufferRegion for improved memory handling

- Updated `run_gemm`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` functions to set `num_stages` based on block dimensions, enhancing performance for larger matrices (see the sketch after this list).
- Simplified calls to GEMM functions by removing redundant parameters and ensuring compatibility with BufferRegion.
- Introduced utility functions for converting between Buffer, BufferLoad, and BufferRegion, improving code clarity and maintainability.
- Enhanced error handling for full region checks in GEMM operations to ensure correctness in memory access.
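
For the first bullet, the staging rule visible in the `run_gemm*` changes further down in this diff amounts to roughly the following (the helper name is made up for illustration):

```python
def pick_num_stages(block_M: int, block_N: int, block_K: int,
                    num_stages: int = 2) -> int:
    """Illustrative helper mirroring the updated run_gemm* logic: the default
    drops from 3 to 2 stages, and pipelining is disabled entirely for large
    tiles (presumably to bound shared-memory usage; that rationale is an
    assumption, not stated in the diff)."""
    if block_M >= 256 or block_N >= 256 or block_K >= 256:
        return 0
    return num_stages
```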

* Refactor GEMM code for improved readability and consistency

- Cleaned up formatting and spacing in GEMM-related files for better readability.
- Standardized comments and code structure across various GEMM functions and macros.
- Enhanced error messages for clarity in buffer region checks.
- Removed redundant lines and improved overall code maintainability.

* Update GEMM correctness evaluation and macro generator for improved functionality

- Modified `N_VALUES` in `correctness_evaluation_sm70.py` to include only relevant sizes for tests.
- Updated test function call in `correctness_evaluation.py` to use `test_gemm_false_true` for better accuracy in testing.
- Refactored buffer handling in `mma_sm70_macro_generator.py` to improve clarity and consistency in shared buffer access.
- Enhanced `gemm_mma_sm70.py` to ensure full region checks for input and output buffers, improving correctness in GEMM operations.
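
A hedged sketch of what a "full region" check means here; the real check operates on TVM `BufferRegion` objects, while plain tuples are used below for brevity:

```python
def is_full_region(region_mins, region_extents, buffer_shape) -> bool:
    # A region covers the whole buffer only if every dimension starts at 0
    # and spans the full extent of that dimension.
    return all(
        lo == 0 and ext == dim
        for lo, ext, dim in zip(region_mins, region_extents, buffer_shape)
    )


assert is_full_region((0, 0), (128, 64), (128, 64))       # whole buffer
assert not is_full_region((0, 32), (128, 32), (128, 64))  # partial slice
```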

* Refactor GEMM and intrinsic files for improved clarity and functionality

- Removed unused variable `A_stride_last` in `mma_sm70_macro_generator.py` to streamline code.
- Adjusted function signature formatting in `swizzle.py` for better readability.
- Restored the return of `GemmWGMMA` in `__init__.py` for correct GEMM instantiation.
- Removed unused variable `B_buf` in `gemm_mma_sm70.py` to enhance code cleanliness.
- Improved function signature formatting in `language.py` for consistency.

* Enhance GEMM and MMA functionality for FP64 support

- Refactored `GemmNode` to streamline the decision-making process for GEMM instruction selection.
- Added support for FP64 inputs in the MMA dispatcher, enabling new tensor operations.
- Introduced a new layout function for FP64 in `mma_layout.py` to facilitate shared memory storage.
- Updated `TensorCoreIntrinEmitter` to handle FP64 data types, including adjustments for micro tile dimensions and loading mechanisms.
- Enhanced utility functions to accommodate FP64 index mapping for shared memory operations.

* lint fix

* Refactor GEMM correctness evaluation and shared memory alignment handling

- Reverted the GEMM function call in `correctness_evaluation.py` to the original implementation for consistency.
- Added a helper function in `merge_shared_memory_allocations.cc` to streamline the marking of shared variables under alignment scope.
- Enhanced the `VisitExpr_` methods to ensure proper handling of shared memory alignment for `BufferLoadNode` and `VarNode` types.
- Cleaned up commented-out test code in `correctness_evaluation.py` for better readability.

* Enhance GEMM and MMA implementations with region-based memory handling

- Updated GEMM and MMA classes to utilize BufferRegion for input and output buffers, improving memory management and supporting strided GEMM operations.
- Added checks to ensure full region compliance for input buffers, enhancing correctness in matrix multiplication.
- Implemented clear accumulation functionality to reset output buffers before accumulation, ensuring accurate results in GEMM operations.
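
A reference-semantics sketch of the clear-accumulation flag (NumPy stands in for the on-device buffers):

```python
import numpy as np


def gemm_reference(A, B, C, clear_accum: bool):
    # clear_accum=True resets the output before accumulating; otherwise the
    # product is accumulated on top of C's existing contents.
    if clear_accum:
        C[...] = 0
    C += A @ B
    return C


A, B, C = np.random.rand(4, 8), np.random.rand(8, 3), np.ones((4, 3))
np.testing.assert_allclose(gemm_reference(A, B, C, clear_accum=True), A @ B)
```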

* Refactor test_tilelang_example_deepseek_v32.py to improve import structure and function calls

- Updated import statements to directly reference modules instead of individual test functions, enhancing clarity.
- Modified function calls to use the new module structure for better organization and maintainability in testing examples.

* Enhance OnArrayDeclaration method to handle repeated buffer declarations

- Updated the OnArrayDeclaration method to merge metadata for buffers that may appear in multiple Allocate statements, improving robustness against upstream transformations.
- Added logic to prefer concrete element data types and record extents when previously unknown, enhancing the handling of buffer declarations.
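
The merge rule can be illustrated with a small Python sketch (the actual helper lives in a C++ pass; the dictionary representation here is only a stand-in for the buffer metadata):

```python
def merge_array_declaration(existing: dict, incoming: dict) -> dict:
    # A buffer seen in multiple Allocate statements refines, rather than
    # replaces, its earlier metadata.
    merged = dict(existing)
    # Prefer a concrete element dtype over an unknown placeholder.
    if merged.get("dtype") is None and incoming.get("dtype") is not None:
        merged["dtype"] = incoming["dtype"]
    # Record the extent when it was previously unknown.
    if merged.get("extent") is None and incoming.get("extent") is not None:
        merged["extent"] = incoming["extent"]
    return merged
```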

* Add abbreviation for bfloat16 data type in mfma_macro_generator.py

- Introduced a new abbreviation "bf16" for the bfloat16 data type in the mfma_macro_generator.py file, enhancing clarity and consistency in data type representation.
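
Schematically, the abbreviation table gains one entry (the surrounding entries are illustrative; only the `bf16` mapping is taken from this change):

```python
DTYPE_ABBRV = {
    "float16": "f16",
    "bfloat16": "bf16",  # newly added abbreviation
    "float32": "f32",
    "int8": "i8",
}
```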

* Refactor CodeGenTileLangHIP to enhance dtype handling and mfma call generation

- Introduced a mapping function to normalize input data types to their corresponding scalar types, improving compatibility with MfmaTraits.
- Updated the mfma call generation to utilize the new mapping, streamlining the code and enhancing clarity.
- Removed outdated dtype mapping and replaced it with a more flexible approach to support additional data types like FP8.
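
Conceptually, the normalization strips the vector-lane suffix so that vectorized inputs resolve to the scalar element type expected by `MfmaTraits`; a Python rendering of the idea (the real mapping is implemented in the C++ HIP code generator):

```python
def normalize_to_scalar_dtype(dtype: str) -> str:
    # e.g. "float16x4" -> "float16"; dtypes without a lane suffix pass through.
    return dtype.split("x", 1)[0]
```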

* lint fix

* Enhance backend configuration in CMakeLists.txt and improve dtype handling in CodeGenTileLangHIP

- Introduced a macro to define backend options for CUDA, ROCM, and Metal, allowing user overrides and caching of settings.
- Updated logic to track user-selected backends and conditionally enable defaults based on environment variables.
- Refactored dtype handling in CodeGenTileLangHIP to streamline mfma call generation and improve clarity.
- Added support for bfloat16 in the mfma_macro_generator.py, enhancing data type representation consistency.

* Update bfloat16 handling in CodeGenTileLangHIP and mfma_macro_generator.py

- Changed the representation of bfloat16 in CodeGenTileLangHIP from "bfloat16x4" to "bfloat16x4_vec" for improved clarity.
- Adjusted the mfma_suffix generation in mfma_macro_generator.py to remove the underscore before "bf16", aligning with HIP intrinsic requirements.

* Change logging level from WARNING to DLOG in LegalizeNegativeIndex for non-negative index checks to reduce log verbosity.

* Refactor attention sink examples to simplify index calculations

- Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline logic for determining start and end indices.
- Improved readability by using direct calculations instead of local variables for index bounds in pipelined loops.

* Refactor attention sink examples to streamline index calculations

- Simplified index handling in `example_gqa_sink_bwd_bhsd.py`, `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py`, `example_mha_sink_bwd_bhsd.py`, `example_mha_sink_fwd_bhsd_wgmma_pipelined.py`, and `example_mha_sink_fwd_bhsd.py` by removing unnecessary local allocations for start and end indices.
- Enhanced readability by directly calculating index bounds for pipelined loops, improving overall code clarity.
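
Because the rendered diff below interleaves removed and added lines without +/- markers, here is the before/after pattern consolidated from those hunks (excerpted, not self-contained; `T`, `bx`, `block_M`, `block_N`, and `window_size` come from the surrounding kernels):

```python
# Before: a one-element local buffer holds the loop start index.
start = T.alloc_local([1], "int32")
if window_size is not None:
    start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
    start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
    ...

# After: the bound is a plain conditional expression.
start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start, end, num_stages=num_stages):
    ...
```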

* lint fix

* bugfix

* Refactor reduce operation handling in CUDA and Python

- Removed outdated shared memory reduction logic from `reduce.cc`.
- Introduced fragment allocation and improved buffer handling in `reduce.py` to support shared and fragment scopes.
- Updated CUDA header to define a wider accumulator type for better numerical accuracy.
- Enhanced error handling for buffer scope validation in the reduction process.
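
The value of a wider accumulator can be seen with a small NumPy experiment (illustration only; the actual change is in the CUDA reduction header):

```python
import numpy as np

x = np.full(10_000, 0.1, dtype=np.float16)

narrow = np.float16(0)
for v in x:
    narrow = np.float16(narrow + v)   # accumulate in fp16: the sum stalls once
                                      # the increment drops below fp16 spacing
wide = x.astype(np.float32).sum()     # accumulate in fp32

print(narrow, wide)  # roughly 256.0 vs ~999.8; the true total is ~1000
```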

* Fix ReduceOpNode to correctly compute AbsMax by using absolute values of inputs
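
A one-line reminder of why the absolute value matters for AbsMax:

```python
xs = [0.5, -3.0, 2.0]
max(xs)                    # 2.0 -- wrong when the largest magnitude is negative
max(abs(x) for x in xs)    # 3.0 -- AbsMax over |x|
```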

* Enhance unit loop handling by refining annotation checks

- Updated the condition for identifying effectively empty annotations in unit loops to include cases where only the `pragma_unroll_explicit` hint is present.
- Introduced a new method, `IsEffectivelyEmptyAnnotation`, to encapsulate this logic, improving code clarity and maintainability.
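
A conceptual Python rendering of the new predicate (the actual helper is C++):

```python
def is_effectively_empty_annotation(annotations: dict) -> bool:
    # Empty, or carrying nothing but the pragma_unroll_explicit hint.
    return not annotations or set(annotations) == {"pragma_unroll_explicit"}


assert is_effectively_empty_annotation({})
assert is_effectively_empty_annotation({"pragma_unroll_explicit": 1})
assert not is_effectively_empty_annotation({"pragma_unroll_explicit": 1, "other": 0})
```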

* clean code
parent 2b1f5990
......@@ -105,3 +105,6 @@ cmake-build-*/
# Git version for sdist
.git_commit.txt
# pre-commit cache
.pre-commit-cache/*
......@@ -65,9 +65,50 @@ else()
endif()
# Configs
set(USE_CUDA OFF)
set(USE_ROCM OFF)
set(USE_METAL OFF)
set(TILELANG_BACKENDS CUDA ROCM METAL)
set(TILELANG_BACKEND_DOC_CUDA "Enable CUDA backend (ON/OFF/or CUDA SDK path)")
set(TILELANG_BACKEND_DOC_ROCM "Enable ROCm backend (ON/OFF/or ROCm SDK path)")
set(TILELANG_BACKEND_DOC_METAL "Enable Metal backend")
# TVM's config.cmake redefines USE_* options later, so we cache the user's choice
# (including explicit -DUSE_XXX arguments) before we include TVM and restore it
# afterwards.
macro(tilelang_define_backend_option BACKEND)
set(_backend_var "USE_${BACKEND}")
set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
set(_user_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")
set(_user_override OFF)
if(DEFINED ${_user_override_var})
set(_user_override "${${_user_override_var}}")
endif()
if(DEFINED CACHE{${_backend_var}})
get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
if(_cache_type STREQUAL "UNINITIALIZED")
set(_user_override ON)
endif()
endif()
set(_default OFF)
if(DEFINED ${_backend_var})
set(_default "${${_backend_var}}")
endif()
option(${_backend_var} "${_doc}" "${_default}")
# Remember if the user explicitly set this option so that later logic
# won't auto-toggle backends they configured on the command line.
set(${_user_override_var} ${_user_override} CACHE INTERNAL
"User explicitly set ${_backend_var} during configuration" FORCE)
set(TILELANG_OPTION_${_backend_var} "${${_backend_var}}")
endmacro()
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
tilelang_define_backend_option(${BACKEND})
endforeach()
set(PREBUILD_CYTHON ON)
# Configs end
......@@ -78,6 +119,14 @@ if(EXISTS ${TVM_SOURCE}/cmake/config.cmake)
else()
message(FATAL_ERROR "Nor tvm provided or submodule checkout-ed.")
endif()
# Re-apply TileLang's preferred backend settings after TVM's config may have
# overridden the USE_* cache entries.
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
set(_backend_var "USE_${BACKEND}")
set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
set(${_backend_var} ${TILELANG_OPTION_${_backend_var}} CACHE STRING "${_doc}" FORCE)
set(${_backend_var} ${TILELANG_OPTION_${_backend_var}})
endforeach()
# Include directories for TileLang
set(TILE_LANG_INCLUDES ${TVM_INCLUDES})
......@@ -95,23 +144,35 @@ file(GLOB TILE_LANG_SRCS
src/target/intrin_rule*.cc
)
# Backend-specific checks and configs
if($ENV{USE_METAL})
set(USE_METAL ON)
elseif(APPLE)
message(STATUS "Enable Metal support by default.")
set(USE_METAL ON)
elseif($ENV{USE_ROCM})
set(USE_ROCM ON)
else()
if($ENV{USE_CUDA})
set(USE_CUDA ON)
elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA})
# Build CPU-only when we explicitly disable CUDA
set(USE_CUDA OFF)
# Track if the user explicitly selected a backend via cache options.
set(TILELANG_BACKEND_USER_SELECTED OFF)
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
set(_backend_var "USE_${BACKEND}")
set(_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")
if(${_backend_var} OR ${_override_var})
set(TILELANG_BACKEND_USER_SELECTED ON)
endif()
endforeach()
# Only auto-select a backend when the user didn't specify one explicitly.
if(NOT TILELANG_BACKEND_USER_SELECTED)
if($ENV{USE_METAL})
set(USE_METAL ON)
elseif(APPLE)
message(STATUS "Enable Metal support by default.")
set(USE_METAL ON)
elseif($ENV{USE_ROCM})
set(USE_ROCM ON)
else()
message(STATUS "Enable CUDA support by default.")
set(USE_CUDA ON)
if($ENV{USE_CUDA})
set(USE_CUDA ON)
elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA})
# Build CPU-only when we explicitly disable CUDA
set(USE_CUDA OFF)
else()
message(STATUS "Enable CUDA support by default.")
set(USE_CUDA ON)
endif()
endif()
endif()
......@@ -125,7 +186,7 @@ if(USE_METAL)
elseif(USE_ROCM)
set(CMAKE_HIP_STANDARD 17)
include(${TVM_SOURCE}/cmake/utils/FindROCM.cmake)
find_rocm($ENV{USE_ROCM})
find_rocm(${USE_ROCM})
add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1)
file(GLOB TILE_LANG_HIP_SRCS
......
......@@ -81,13 +81,10 @@ def flashattn_fwd(
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
start = T.max(0,
(bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
......@@ -266,14 +263,11 @@ def flashattn_bwd(batch,
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32')
if window_size is not None:
loop_ed[0] = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
loop_ed = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......
......@@ -172,14 +172,11 @@ def flashattn(
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
for k in T.Pipelined(
start[0],
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
......
......@@ -78,13 +78,10 @@ def flashattn_fwd(
sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
start = T.max(0,
(bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i
......@@ -267,14 +264,10 @@ def flashattn_bwd(
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32')
if window_size is not None:
loop_ed[0] = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
loop_ed = T.min(
T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......
......@@ -162,13 +162,10 @@ def flashattn(
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
for k in T.Pipelined(start, end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
......
......@@ -165,14 +165,11 @@ def flashattn(
end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32')
if window_size is not None:
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
start = T.max(0, (bx * block_M + past_len - window_size) //
block_N) if window_size is not None else 0
for k in T.Pipelined(
start[0],
start,
end,
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
......
# ruff: noqa
import tilelang.testing
from topk_selector import test_topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
from sparse_mla_bwd import test_sparse_mla_bwd
import topk_selector
import fp8_lighting_indexer
import sparse_mla_fwd
import sparse_mla_fwd_pipelined
import sparse_mla_bwd
def test_example_topk_selector():
test_topk_selector()
topk_selector.test_topk_selector()
def test_example_fp8_lighting_indexer():
test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd():
# small shapes for testing
test_sparse_mla_fwd(
sparse_mla_fwd.test_sparse_mla_fwd(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
......@@ -28,14 +28,14 @@ def test_example_sparse_mla_fwd():
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined():
# small shapes for testing
test_sparse_mla_fwd_pipelined(
sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd():
test_sparse_mla_bwd(
sparse_mla_bwd.test_sparse_mla_bwd(
S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
......
......@@ -80,7 +80,6 @@ def tl_fused_chunk_fwd_kernel(
T.atomic_add(
O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
o_shared)
#TODO: consider using vectorized atomic add or tma reduce for sm90
# Output final state
T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
......@@ -91,6 +90,7 @@ def tl_fused_chunk_fwd_kernel(
def tl_fused_chunk_fwd(q, k, v):
B, S, H, D = q.shape
kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
print(kernel.get_kernel_source())
o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
h = kernel(q, k, v, o)
return o, h
......
......@@ -51,13 +51,6 @@ def chunk_retention_fwd_kernel(
o = T.alloc_fragment([chunk_size, BV], accum_dtype)
T.clear(h)
T.annotate_layout({
q: tl.layout.make_swizzled_layout(q),
k: tl.layout.make_swizzled_layout(k),
v: tl.layout.make_swizzled_layout(v),
h_shared: tl.layout.make_swizzled_layout(h_shared),
s_shared: tl.layout.make_swizzled_layout(s_shared),
})
T.use_swizzle(10)
for i in T.Pipelined(0, NT):
......
import tilelang
import tilelang.language as T
tilelang.disable_cache()
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
......@@ -52,11 +54,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
def main(M=16384, N=16384, K=16384):
tilelang.disable_cache()
block_M = 128
block_N = 128
block_K = 64
jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
print(jit_kernel.get_kernel_source())
import torch
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
......
......@@ -46,8 +46,7 @@ def matmul(
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
# T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
......@@ -103,9 +102,11 @@ def run_gemm(
block_M,
block_N,
block_K,
num_stages=3,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul(
M,
N,
......@@ -189,9 +190,11 @@ def run_gemm_rs(
block_M,
block_N,
block_K,
num_stages=3,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul_rs(
M,
N,
......@@ -273,9 +276,11 @@ def run_gemm_sr(
block_M,
block_N,
block_K,
num_stages=3,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul_sr(
M,
N,
......@@ -361,9 +366,11 @@ def run_gemm_rr(
block_M,
block_N,
block_K,
num_stages=3,
num_stages=2,
num_threads=128,
):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul_rr(
M,
N,
......@@ -429,51 +436,51 @@ def _ensure_torch_dtypes(*dtype_names):
def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128)
run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rs_true_false(m, n, k):
run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rs_true_true(m, n, k):
run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_sr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128)
run_gemm_sr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)
def run_gemm_sr_false_false(m, n, k):
run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
def run_gemm_sr_true_false(m, n, k):
run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
def run_gemm_sr_true_true(m, n, k):
run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128)
run_gemm_rr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)
def run_gemm_rr_false_false(m, n, k):
run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rr_true_false(m, n, k):
run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rr_true_true(m, n, k):
run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
TRANS_CASES = [
......@@ -516,8 +523,6 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
m,
n,
k,
2,
128,
)
......@@ -537,8 +542,6 @@ def test_gemm_false_false(m, n, k):
m,
n,
k,
2,
128,
)
......@@ -558,8 +561,6 @@ def test_gemm_true_false(m, n, k):
m,
n,
k,
2,
128,
)
......@@ -579,8 +580,6 @@ def test_gemm_true_true(m, n, k):
m,
n,
k,
2,
128,
)
......@@ -724,3 +723,13 @@ if __name__ == "__main__":
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm_rs(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm_rs(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
......@@ -211,7 +211,7 @@ def run_gemm_rs(
M_VALUES = [64, 128]
N_VALUES = [16, 32, 64, 128]
N_VALUES = [32, 64, 128]
K_VALUES = [16, 32, 64]
FALSE_TRUE_CASES = ([
pytest.param(
......
This diff is collapsed.
......@@ -40,7 +40,7 @@ public:
.def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
}
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
Target target,
GemmInst gemm_inst) const;
......@@ -84,47 +84,47 @@ public:
class GemmNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
tir::Buffer A, B, C;
// pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr;
bool trans_A, trans_B;
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
PrimExpr clear_accum = const_false();
bool checkWgmma() const;
tir::Buffer a_, b_, c_;
// BufferRegion for A, B and C
BufferRegion aRegion_, bRegion_, cRegion_;
bool transA_, transB_;
int m_, n_, k_;
int strideA_, strideB_;
int offsetA_, offsetB_;
PrimExpr clearAccum_ = const_false();
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
PrimExpr mbarptr;
std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> C_coords;
mutable GemmWarpPolicy policy;
int kPack_ = 1;
int wgWait_ = 0;
PrimExpr mbarPtr_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_;
mutable GemmWarpPolicy policy_;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmNode>()
.def_ro("A", &GemmNode::A)
.def_ro("B", &GemmNode::B)
.def_ro("C", &GemmNode::C)
.def_ro("Aptr", &GemmNode::Aptr)
.def_ro("Bptr", &GemmNode::Bptr)
.def_ro("Cptr", &GemmNode::Cptr)
.def_ro("trans_A", &GemmNode::trans_A)
.def_ro("trans_B", &GemmNode::trans_B)
.def_ro("M", &GemmNode::M)
.def_ro("N", &GemmNode::N)
.def_ro("K", &GemmNode::K)
.def_ro("stride_A", &GemmNode::stride_A)
.def_ro("stride_B", &GemmNode::stride_B)
.def_ro("offset_A", &GemmNode::offset_A)
.def_ro("offset_B", &GemmNode::offset_B)
.def_ro("clear_accum", &GemmNode::clear_accum)
.def_ro("kPack", &GemmNode::kPack)
.def_ro("wg_wait", &GemmNode::wg_wait)
.def_ro("policy", &GemmNode::policy);
.def_ro("a", &GemmNode::a_)
.def_ro("b", &GemmNode::b_)
.def_ro("c", &GemmNode::c_)
.def_ro("aRegion", &GemmNode::aRegion_)
.def_ro("bRegion", &GemmNode::bRegion_)
.def_ro("cRegion", &GemmNode::cRegion_)
.def_ro("transA", &GemmNode::transA_)
.def_ro("transB", &GemmNode::transB_)
.def_ro("m", &GemmNode::m_)
.def_ro("n", &GemmNode::n_)
.def_ro("k", &GemmNode::k_)
.def_ro("strideA", &GemmNode::strideA_)
.def_ro("strideB", &GemmNode::strideB_)
.def_ro("offsetA", &GemmNode::offsetA_)
.def_ro("offsetB", &GemmNode::offsetB_)
.def_ro("clearAccum", &GemmNode::clearAccum_)
.def_ro("kPack", &GemmNode::kPack_)
.def_ro("wgWait", &GemmNode::wgWait_)
.def_ro("policy", &GemmNode::policy_);
}
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
......@@ -134,9 +134,9 @@ public:
TileOperator Clone() const;
private:
GemmInst GetGemmInst(int block_size, Target target) const;
bool AllowTCGEN5MMA(Target target) const;
bool AllowWGMMA(int block_size, Target target) const;
GemmInst getGemmInst(int block_size, Target target) const;
bool allowTcgen5Mma(Target target) const;
bool allowWgmma(int block_size, Target target) const;
mutable bool completed_ = false;
};
......
......@@ -11,16 +11,102 @@
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>
#include "../support/ffi_aliases.h"
#include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h"
#include "tvm/ffi/string.h"
namespace tvm {
namespace tl {
using namespace tir;
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
// to BufferRegion
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in GEMM region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in GEMM region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap.at(var);
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
throw; // Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
// Compute row-major strides
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
// Extent: last two extents product (elements)
PrimExpr extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
// ptype and return handle
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
......@@ -51,45 +137,42 @@ using namespace tir;
*/
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
node->trans_A = args[3].as<Bool>().value();
node->trans_B = args[4].as<Bool>().value();
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<PrimExpr>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer;
node->c_ = node->cRegion_->buffer;
node->transA_ = args[3].as<Bool>().value();
node->transB_ = args[4].as<Bool>().value();
node->m_ = args[5].as<IntImm>().value()->value;
node->n_ = args[6].as<IntImm>().value()->value;
node->k_ = args[7].as<IntImm>().value()->value;
node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clearAccum_ = args[9].as<PrimExpr>().value();
node->strideA_ = args[10].as<IntImm>().value()->value;
node->strideB_ = args[11].as<IntImm>().value()->value;
node->offsetA_ = args[12].as<IntImm>().value()->value;
node->offsetB_ = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
node->kPack = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
node->kPack_ = args[14].as<IntImm>().value()->value;
if (node->kPack_ != 1 && node->kPack_ != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
if (args.size() > 16) {
node->mbarptr = args[16];
} else {
node->mbarptr = IntImm(DataType::UInt(32), 0);
node->wgWait_ = args[15].as<IntImm>().value()->value;
}
if (args.size() > 18) {
node->C_coords = Array<PrimExpr>({args[17], args[18]});
} else if (args.size() > 17) {
node->C_coords = Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)});
node->mbarPtr_ = args[16];
if (node->mbarPtr_.as<CallNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
} else {
node->C_coords = Array<PrimExpr>(
{IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)});
node->mbar_ = std::nullopt;
}
node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
data_ = std::move(node);
}
......@@ -106,28 +189,28 @@ TileOperator GemmPyNode::Clone() const {
return GemmPy(op);
}
bool GemmPyNode::AllowTCGEN5MMA(Target target) const {
bool GemmPyNode::allowTcgen5Mma(Target target) const {
return TargetIsSm100(target) &&
((A.scope() == "shared.dyn" || A.scope() == "shared" ||
A.scope() == "shared.tmem") &&
(B.scope() == "shared.dyn" || B.scope() == "shared") &&
C.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
a_.scope() == "shared.tmem") &&
(b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
c_.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
}
bool GemmPyNode::AllowWGMMA(int block_size, Target target) const {
bool GemmPyNode::allowWgmma(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA();
TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
checkWgmma();
}
GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target);
GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = allowTcgen5Mma(target);
bool allow_wgmma = allowWgmma(block_size, target);
if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
......@@ -175,51 +258,52 @@ GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
*/
bool GemmPyNode::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
bool GemmPyNode::checkWgmma() const {
if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
return false;
}
if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Float(32)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype == DataType::BFloat(16) &&
B->dtype == DataType::BFloat(16))
return K % 16 == 0;
else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
return (!trans_A) && trans_B && K % 8 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
} else if (c_->dtype == DataType::Float(32)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
else if (a_->dtype == DataType::BFloat(16) &&
b_->dtype == DataType::BFloat(16))
return k_ % 16 == 0;
else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Int(32)) {
if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
} else if (c_->dtype == DataType::Int(32)) {
if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
} else {
......@@ -256,10 +340,10 @@ static int GetArchInt(Target target) {
Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
GemmInst gemm_inst = getGemmInst(block_size, T.target);
auto [warp_m, warp_n] =
policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func =
......@@ -302,6 +386,14 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
results = Downcast<LayoutMap>(
(*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds));
// Bind all fragment layouts with the provided thread range
for (auto kv : results) {
const Buffer &buf = kv.first;
const Layout &layout = kv.second;
if (auto frag = layout.as<Fragment>()) {
results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds));
}
}
} else {
LOG(FATAL) << "No infer layout function found for gemm_py";
}
......@@ -321,7 +413,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmPyGemmInst",
[](GemmPy gemm_py, int block_size, Target target) {
return gemm_py->GetGemmInst(block_size, target);
return gemm_py->getGemmInst(block_size, target);
});
}
......
......@@ -18,51 +18,52 @@ using namespace tir;
class GemmPyNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
bool AllowTCGEN5MMA(Target target) const;
bool AllowWGMMA(int block_size, Target target) const;
tir::Buffer A, B, C;
// pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr;
bool trans_A, trans_B;
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
PrimExpr clear_accum = const_false();
PrimExpr mbarptr;
Array<PrimExpr> C_coords;
bool checkWgmma() const;
bool allowTcgen5Mma(Target target) const;
bool allowWgmma(int block_size, Target target) const;
tir::Buffer a_, b_, c_;
// BufferRegion for A, B and C
BufferRegion aRegion_, bRegion_, cRegion_;
bool transA_, transB_;
int m_, n_, k_;
int strideA_, strideB_;
int offsetA_, offsetB_;
PrimExpr clearAccum_ = const_false();
PrimExpr mbarPtr_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
mutable GemmWarpPolicy policy;
int kPack_ = 1;
int wgWait_ = 0;
mutable GemmWarpPolicy policy_;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmPyNode>()
.def_ro("A", &GemmPyNode::A)
.def_ro("B", &GemmPyNode::B)
.def_ro("C", &GemmPyNode::C)
.def_ro("Aptr", &GemmPyNode::Aptr)
.def_ro("Bptr", &GemmPyNode::Bptr)
.def_ro("Cptr", &GemmPyNode::Cptr)
.def_ro("trans_A", &GemmPyNode::trans_A)
.def_ro("trans_B", &GemmPyNode::trans_B)
.def_ro("M", &GemmPyNode::M)
.def_ro("N", &GemmPyNode::N)
.def_ro("K", &GemmPyNode::K)
.def_ro("stride_A", &GemmPyNode::stride_A)
.def_ro("stride_B", &GemmPyNode::stride_B)
.def_ro("offset_A", &GemmPyNode::offset_A)
.def_ro("offset_B", &GemmPyNode::offset_B)
.def_ro("clear_accum", &GemmPyNode::clear_accum)
.def_ro("mbarptr", &GemmPyNode::mbarptr)
.def_ro("C_coords", &GemmPyNode::C_coords)
.def_ro("kPack", &GemmPyNode::kPack)
.def_ro("wg_wait", &GemmPyNode::wg_wait)
.def_ro("policy", &GemmPyNode::policy);
.def_ro("a", &GemmPyNode::a_)
.def_ro("b", &GemmPyNode::b_)
.def_ro("c", &GemmPyNode::c_)
.def_ro("aRegion", &GemmPyNode::aRegion_)
.def_ro("bRegion", &GemmPyNode::bRegion_)
.def_ro("cRegion", &GemmPyNode::cRegion_)
.def_ro("transA", &GemmPyNode::transA_)
.def_ro("transB", &GemmPyNode::transB_)
.def_ro("m", &GemmPyNode::m_)
.def_ro("n", &GemmPyNode::n_)
.def_ro("k", &GemmPyNode::k_)
.def_ro("strideA", &GemmPyNode::strideA_)
.def_ro("strideB", &GemmPyNode::strideB_)
.def_ro("offsetA", &GemmPyNode::offsetA_)
.def_ro("offsetB", &GemmPyNode::offsetB_)
.def_ro("clearAccum", &GemmPyNode::clearAccum_)
.def_ro("mbarPtr", &GemmPyNode::mbarPtr_)
.def_ro("cCoords", &GemmPyNode::cCoords_)
.def_ro("kPack", &GemmPyNode::kPack_)
.def_ro("wgWait", &GemmPyNode::wgWait_)
.def_ro("policy", &GemmPyNode::policy_);
}
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
......@@ -72,7 +73,7 @@ public:
TileOperator Clone() const;
// Target GEMM instruction
GemmInst GetGemmInst(int block_size, Target target) const;
GemmInst getGemmInst(int block_size, Target target) const;
private:
mutable bool completed_ = false;
......
......@@ -18,14 +18,14 @@
namespace tvm {
namespace tl {
std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N,
int block_size,
Target target,
bool use_wgmma,
int bits) const {
int num_warps = block_size / TargetGetWarpSize(target);
auto [m_warp, n_warp] = GemmWarpPolicyNode::ComputeWarpPartition(
auto [m_warp, n_warp] = GemmWarpPolicyNode::computeWarpPartition(
M, N, block_size, target, use_wgmma ? GemmInst::kWGMMA : GemmInst::kMMA);
// Special handling for gemm_sp when the tiling size is not a multiple
......@@ -85,25 +85,25 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
*/
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
node->A = vmap[GetVarFromAccessPtr(args[0])];
node->E = vmap[GetVarFromAccessPtr(args[1])];
node->B = vmap[GetVarFromAccessPtr(args[2])];
node->C = vmap[GetVarFromAccessPtr(args[3])];
node->trans_A = args[4].as<Bool>().value();
node->trans_B = args[5].as<Bool>().value();
node->M = args[6].as<IntImm>().value()->value;
node->N = args[7].as<IntImm>().value()->value;
node->K = args[8].as<IntImm>().value()->value;
node->policy = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
node->clear_accum = args[10].as<Bool>().value();
node->a_ = vmap[GetVarFromAccessPtr(args[0])];
node->e_ = vmap[GetVarFromAccessPtr(args[1])];
node->b_ = vmap[GetVarFromAccessPtr(args[2])];
node->c_ = vmap[GetVarFromAccessPtr(args[3])];
node->transA_ = args[4].as<Bool>().value();
node->transB_ = args[5].as<Bool>().value();
node->m_ = args[6].as<IntImm>().value()->value;
node->n_ = args[7].as<IntImm>().value()->value;
node->k_ = args[8].as<IntImm>().value()->value;
node->policy_ = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
node->clearAccum_ = args[10].as<Bool>().value();
if (args.size() > 11) {
node->kPack = args[11].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
node->kPack_ = args[11].as<IntImm>().value()->value;
if (node->kPack_ != 1 && node->kPack_ != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 12) {
node->wg_wait = args[12].as<IntImm>().value()->value;
node->wgWait_ = args[12].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
......@@ -144,37 +144,37 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
auto block_size = *as_const_int(T.thread_bounds->extent);
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
bool maybe_wgmma = TargetIsHopper(T.target) && (this->m_ >= 64) &&
(block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());
auto [warp_m, warp_n] = policy_->computeWarpPartition(
m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits());
std::stringstream ss;
std::string op_name = "tl::gemm_sp_ss";
ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") &&
(B.scope() == "shared" || B.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for A and B, but received " << A.scope()
<< " and " << B.scope();
ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
ICHECK((a_.scope() == "shared" || a_.scope() == "shared.dyn") &&
(b_.scope() == "shared" || b_.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for A and B, but received "
<< a_.scope() << " and " << b_.scope();
ICHECK((e_.scope() == "shared" || e_.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for E as copy from smem to rmem are "
"delegated to cute implementation, found "
<< E.scope();
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
<< e_.scope();
ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
ss << ", " << clear_accum;
ss << transA_ << ", " << transB_;
ss << ", " << clearAccum_;
if (TargetIsHopper(T.target)) {
ss << ", " << (maybe_wgmma ? "true" : "false");
}
if (wg_wait != 0) {
ss << ", " << wg_wait;
if (wgWait_ != 0) {
ss << ", " << wgWait_;
}
ss << ">";
auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
auto C_buffer = T.buffer_remap[C];
auto E_buffer = T.buffer_remap.count(E) ? T.buffer_remap[E] : E;
auto A_buffer = T.buffer_remap.count(a_) ? T.buffer_remap[a_] : a_;
auto B_buffer = T.buffer_remap.count(b_) ? T.buffer_remap[b_] : b_;
auto C_buffer = T.buffer_remap[c_];
auto E_buffer = T.buffer_remap.count(e_) ? T.buffer_remap[e_] : e_;
auto new_call =
Call(DataType::Handle(), tl::tl_gemm_sp(),
......@@ -217,59 +217,59 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
if (completed_)
return {};
LayoutMap results;
ICHECK(C.scope() == "local.fragment");
ICHECK(c_.scope() == "local.fragment");
auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent);
if (TargetIsHopper(T.target)) {
const int warp_size = 32;
constexpr int wgmma_m = 16 * 4;
bool maybe_wgmma =
(this->M >= wgmma_m) && (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());
auto fragment =
maybe_wgmma
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
C->dtype.bits())
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous,
mat_continuous, A->dtype.bits(),
trans_A ? 1 : 2));
(this->m_ >= wgmma_m) && (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] = policy_->computeWarpPartition(
m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits());
auto fragment = maybe_wgmma
? makeGemmFragmentCHopper(m_, n_, m_ / warp_m,
n_ / warp_n, c_->dtype.bits())
: makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
c_->dtype.bits());
results.Set(c_, fragment->BindThreadRange(thread_range));
if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
int dim_A = a_->shape.size();
const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
results.Set(a_, makeGemmABLayoutHopper(mat_stride, mat_continuous,
mat_continuous, a_->dtype.bits(),
transA_ ? 1 : 2));
} else {
ICHECK(false) << "Not implemented";
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
int dim_B = b_->shape.size();
const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;
results.Set(B,
transB_ ? mat_continuous : mat_continuous / warp_n;
results.Set(b_,
makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1));
b_->dtype.bits(), transB_ ? 2 : 1));
} else {
ICHECK(false) << "WGMMA only support B in shared.";
}
} else if (TargetIsAmpere(T.target)) {
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, false, A->dtype.bits());
auto fragment =
makeGemmSparseFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
auto [warp_m, warp_n] = policy_->computeWarpPartition(
m_, n_, block_size, T.target, false, a_->dtype.bits());
auto fragment = makeGemmSparseFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
c_->dtype.bits());
results.Set(c_, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
A->dtype.bits()));
} else if (A.scope() == "local.fragment") {
if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
int dim_A = a_->shape.size();
const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
results.Set(a_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
a_->dtype.bits()));
} else if (a_.scope() == "local.fragment") {
// auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
// A->dtype.bits(), trans_A);
// results.Set(A, fragment->BindThreadRange(thread_range));
......@@ -277,13 +277,13 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
} else {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
results.Set(B, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
B->dtype.bits()));
} else if (B.scope() == "local.fragment") {
if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
int dim_B = b_->shape.size();
const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
results.Set(b_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
b_->dtype.bits()));
} else if (b_.scope() == "local.fragment") {
// auto fragment =
// makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
// results.Set(B, fragment->BindThreadRange(thread_range));
......
......@@ -18,7 +18,7 @@ using namespace tir;
class GemmSPWarpPolicyNode : public GemmWarpPolicyNode {
public:
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
Target target, bool use_wgmma,
int bits) const;
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
......@@ -53,16 +53,16 @@ public:
class GemmSPNode : public TileOperatorNode {
public:
tir::Buffer A, B, C, E;
bool trans_A, trans_B;
int M, N, K;
bool clear_accum = false;
tir::Buffer a_, b_, c_, e_;
bool transA_, transB_;
int m_, n_, k_;
bool clearAccum_ = false;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
int kPack_ = 1;
int wgWait_ = 0;
mutable GemmSPWarpPolicy policy;
mutable GemmSPWarpPolicy policy_;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSP", GemmSPNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
......@@ -74,19 +74,19 @@ public:
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPNode>()
.def_ro("policy", &GemmSPNode::policy)
.def_ro("A", &GemmSPNode::A)
.def_ro("B", &GemmSPNode::B)
.def_ro("C", &GemmSPNode::C)
.def_ro("E", &GemmSPNode::E)
.def_ro("trans_A", &GemmSPNode::trans_A)
.def_ro("trans_B", &GemmSPNode::trans_B)
.def_ro("M", &GemmSPNode::M)
.def_ro("N", &GemmSPNode::N)
.def_ro("K", &GemmSPNode::K)
.def_ro("clear_accum", &GemmSPNode::clear_accum)
.def_ro("kPack", &GemmSPNode::kPack)
.def_ro("wg_wait", &GemmSPNode::wg_wait);
.def_ro("policy", &GemmSPNode::policy_)
.def_ro("a", &GemmSPNode::a_)
.def_ro("b", &GemmSPNode::b_)
.def_ro("c", &GemmSPNode::c_)
.def_ro("e", &GemmSPNode::e_)
.def_ro("transA", &GemmSPNode::transA_)
.def_ro("transB", &GemmSPNode::transB_)
.def_ro("m", &GemmSPNode::m_)
.def_ro("n", &GemmSPNode::n_)
.def_ro("k", &GemmSPNode::k_)
.def_ro("clearAccum", &GemmSPNode::clearAccum_)
.def_ro("kPack", &GemmSPNode::kPack_)
.def_ro("wgWait", &GemmSPNode::wgWait_);
}
private:
......
......@@ -39,7 +39,6 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
Array<Var> buffer_var_gemm;
};
struct LayoutInferArgs {
......