"vscode:/vscode.git/clone" did not exist on "2db9e84d3d9d62b2164e3d0345677eaba7a3670f"
Unverified commit 8fbe1b3a authored by Lei Wang, committed by GitHub

[Refactor] Add kernel selection option for GEMM v1 in environment settings (#1200)

* Add kernel selection option for GEMM v1 in environment settings

- Introduced `TILELANG_USE_GEMM_V1` environment variable to control the selection of GEMM version.
- Added `use_gemm_v1` method in the `Environment` class to determine if GEMM v1 should be used based on the environment variable.
- Updated GEMM function assignment to default to v2, allowing for v1 to be forced via the new environment variable.

* bug fix

* Add kernel selection option for GEMM in environment settings

- Introduced `TILELANG_USE_GEMM_V1` environment variable to allow users to select between GEMM v1 and v2 implementations.
- Updated `gemm` function to default to v2 but switch to v1 if the environment variable is set to a truthy value.
- Added a method `use_gemm_v1` in the `Environment` class to facilitate this selection based on the environment variable.
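
For reference, a minimal sketch of the selection logic, assuming a simple truthy-string check (the real helper is `Environment.use_gemm_v1`; the accepted value spellings here are an assumption):

```python
import os

def use_gemm_v1() -> bool:
    # Values such as "1", "true", "on" force the legacy GEMM v1 path (assumed spellings).
    return os.environ.get("TILELANG_USE_GEMM_V1", "0").lower() in ("1", "true", "on", "yes")

# Dispatch defaults to v2 and only falls back to v1 on request:
#   gemm = gemm_v1 if use_gemm_v1() else gemm_v2
# Usage: TILELANG_USE_GEMM_V1=1 python my_kernel.py
```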

* Refactor GEMM macro generator to use BufferRegion instead of Buffer

- Updated `wgmma` and `wgmma_rs` methods in `TensorCoreIntrinEmitter` to accept `BufferRegion` parameters instead of `Buffer`.
- Adjusted related calls in `GemmWGMMA` to ensure compatibility with the new parameter types.
- Simplified buffer access logic for better clarity and maintainability.

* Refactor GEMM functions to utilize BufferRegion for improved memory handling

- Updated `run_gemm`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` to set `num_stages` based on the block dimensions (multi-stage pipelining is disabled once any block dimension reaches 256).
- Simplified calls to GEMM functions by removing redundant parameters and ensuring compatibility with BufferRegion.
- Introduced utility functions for converting between Buffer, BufferLoad, and BufferRegion, improving code clarity and maintainability (see the sketch after this list).
- Enhanced error handling for full region checks in GEMM operations to ensure correctness in memory access.
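
A minimal sketch of such conversion and full-region helpers against TVM's Python API (helper names are illustrative, not necessarily the ones added by this PR):

```python
from typing import Union
from tvm import tir
from tvm.ir import Range, structural_equal

def to_buffer_region(arg: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]) -> tir.BufferRegion:
    """Normalize a GEMM operand into a BufferRegion."""
    if isinstance(arg, tir.BufferRegion):
        return arg
    if isinstance(arg, tir.Buffer):
        # Full region: one [0, extent) range per dimension.
        return tir.BufferRegion(arg, [Range.from_min_extent(0, s) for s in arg.shape])
    if isinstance(arg, tir.BufferLoad):
        # Point region: each scalar index becomes a size-1 slice.
        return tir.BufferRegion(arg.buffer, [Range.from_min_extent(i, 1) for i in arg.indices])
    raise TypeError(f"unsupported GEMM operand: {type(arg)}")

def is_full_region(region: tir.BufferRegion) -> bool:
    """True when the region spans every dimension of its buffer."""
    return all(
        isinstance(r.min, tir.IntImm) and r.min.value == 0 and structural_equal(r.extent, s)
        for r, s in zip(region.region, region.buffer.shape))
```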

* Refactor GEMM code for improved readability and consistency

- Cleaned up formatting and spacing in GEMM-related files for better readability.
- Standardized comments and code structure across various GEMM functions and macros.
- Enhanced error messages for clarity in buffer region checks.
- Removed redundant lines and improved overall code maintainability.

* Update GEMM correctness evaluation and macro generator for improved functionality

- Modified `N_VALUES` in `correctness_evaluation_sm70.py` to include only relevant sizes for tests.
- Updated test function call in `correctness_evaluation.py` to use `test_gemm_false_true` for better accuracy in testing.
- Refactored buffer handling in `mma_sm70_macro_generator.py` to improve clarity and consistency in shared buffer access.
- Enhanced `gemm_mma_sm70.py` to ensure full region checks for input and output buffers, improving correctness in GEMM operations.

* Refactor GEMM and intrinsic files for improved clarity and functionality

- Removed unused variable `A_stride_last` in `mma_sm70_macro_generator.py` to streamline code.
- Adjusted function signature formatting in `swizzle.py` for better readability.
- Restored the return of `GemmWGMMA` in `__init__.py` for correct GEMM instantiation.
- Removed unused variable `B_buf` in `gemm_mma_sm70.py` to enhance code cleanliness.
- Improved function signature formatting in `language.py` for consistency.

* Enhance GEMM and MMA functionality for FP64 support

- Refactored `GemmNode` to streamline the decision-making process for GEMM instruction selection.
- Added support for FP64 inputs in the MMA dispatcher, enabling new tensor operations.
- Introduced a new layout function for FP64 in `mma_layout.py` to facilitate shared memory storage.
- Updated `TensorCoreIntrinEmitter` to handle FP64 data types, including adjustments for micro tile dimensions and loading mechanisms (sketched after this list).
- Enhanced utility functions to accommodate FP64 index mapping for shared memory operations.
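
A hypothetical sketch of the micro-tile adjustment mentioned above; the shapes follow the hardware MMA instructions (m8n8k4 for double precision, m16n8k16 for FP16/BF16), but the emitter's actual structure and field names may differ:

```python
def mma_micro_tile(dtype: str):
    # Double-precision tensor core MMA uses the smaller m8n8k4 shape.
    if dtype == "float64":
        return (8, 8, 4)
    # FP16/BF16 default to m16n8k16.
    return (16, 8, 16)
```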

* lint fix

* Refactor GEMM correctness evaluation and shared memory alignment handling

- Reverted the GEMM function call in `correctness_evaluation.py` to the original implementation for consistency.
- Added a helper function in `merge_shared_memory_allocations.cc` to streamline the marking of shared variables under alignment scope.
- Enhanced the `VisitExpr_` methods to ensure proper handling of shared memory alignment for `BufferLoadNode` and `VarNode` types.
- Cleaned up commented-out test code in `correctness_evaluation.py` for better readability.

* Enhance GEMM and MMA implementations with region-based memory handling

- Updated GEMM and MMA classes to utilize BufferRegion for input and output buffers, improving memory management and supporting strided GEMM operations.
- Added checks to ensure full region compliance for input buffers, enhancing correctness in matrix multiplication.
- Implemented clear accumulation functionality to reset output buffers before accumulation, ensuring accurate results in GEMM operations.

* Refactor test_tilelang_example_deepseek_v32.py to improve import structure and function calls

- Updated import statements to directly reference modules instead of individual test functions, enhancing clarity.
- Modified function calls to use the new module structure for better organization and maintainability in testing examples.

* Enhance OnArrayDeclaration method to handle repeated buffer declarations

- Updated the OnArrayDeclaration method to merge metadata for buffers that may appear in multiple Allocate statements, improving robustness against upstream transformations.
- Added logic to prefer concrete element data types and record extents when previously unknown, enhancing the handling of buffer declarations.
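
Rendered in Python for clarity (the actual method lives in the C++ codegen; names here are illustrative):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ArrayInfo:
    dtype: Optional[str] = None   # element dtype; None or "handle" means not yet known
    extent: Optional[int] = None  # number of elements; None means not yet known

def on_array_declaration(info: ArrayInfo, dtype: str, extent: int) -> ArrayInfo:
    # Prefer a concrete element dtype over an unknown or opaque one.
    if info.dtype is None or info.dtype == "handle":
        info.dtype = dtype
    # Record the extent when it was previously unknown.
    if not info.extent:
        info.extent = extent
    return info
```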

* Add abbreviation for bfloat16 data type in mfma_macro_generator.py

- Introduced a new abbreviation "bf16" for the bfloat16 data type in the mfma_macro_generator.py file, enhancing clarity and consistency in data type representation.

* Refactor CodeGenTileLangHIP to enhance dtype handling and mfma call generation

- Introduced a mapping function to normalize input data types to their corresponding scalar types, improving compatibility with MfmaTraits.
- Updated the mfma call generation to utilize the new mapping, streamlining the code and enhancing clarity.
- Removed outdated dtype mapping and replaced it with a more flexible approach to support additional data types like FP8.

* lint fix

* Enhance backend configuration in CMakeLists.txt and improve dtype handling in CodeGenTileLangHIP

- Introduced a macro to define backend options for CUDA, ROCM, and Metal, allowing user overrides and caching of settings.
- Updated logic to track user-selected backends and conditionally enable defaults based on environment variables.
- Refactored dtype handling in CodeGenTileLangHIP to streamline mfma call generation and improve clarity.
- Added support for bfloat16 in the mfma_macro_generator.py, enhancing data type representation consistency.

* Update bfloat16 handling in CodeGenTileLangHIP and mfma_macro_generator.py

- Changed the representation of bfloat16 in CodeGenTileLangHIP from "bfloat16x4" to "bfloat16x4_vec" for improved clarity.
- Adjusted the mfma_suffix generation in mfma_macro_generator.py to remove the underscore before "bf16", aligning with HIP intrinsic requirements.

* Change logging level from WARNING to DLOG in LegalizeNegativeIndex for non-negative index checks to reduce log verbosity.

* Refactor attention sink examples to simplify index calculations

- Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline logic for determining start and end indices.
- Improved readability by using direct calculations instead of local variables for index bounds in pipelined loops.

* Refactor attention sink examples to streamline index calculations

- Simplified index handling in `example_gqa_sink_bwd_bhsd.py`, `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py`, `example_mha_sink_bwd_bhsd.py`, `example_mha_sink_fwd_bhsd_wgmma_pipelined.py`, and `example_mha_sink_fwd_bhsd.py` by removing unnecessary local allocations for start and end indices.
- Enhanced readability by directly calculating index bounds for pipelined loops, improving overall code clarity.
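
Schematically (variable names taken from the examples in the diff below; not a standalone kernel):

```python
# Before: the loop bound lives in a one-element local buffer.
start = T.alloc_local([1], "int32")
if window_size is not None:
    start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
    start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages):
    ...

# After: the bound is a plain expression; `window_size is not None` is resolved at
# trace time, so no local allocation is needed.
start = T.max(0, (bx * block_M - window_size) // block_N) if window_size is not None else 0
for k in T.Pipelined(start, end, num_stages=num_stages):
    ...
```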

* lint fix

* bugfix

* Refactor reduce operation handling in CUDA and Python

- Removed outdated shared memory reduction logic from `reduce.cc`.
- Introduced fragment allocation and improved buffer handling in `reduce.py` to support shared and fragment scopes.
- Updated CUDA header to define a wider accumulator type for better numerical accuracy.
- Enhanced error handling for buffer scope validation in the reduction process.

* Fix ReduceOpNode to correctly compute AbsMax by using absolute values of inputs
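
Reference semantics of the fix, expressed in NumPy for clarity: AbsMax reduces over magnitudes, not over the raw (possibly negative) values.

```python
import numpy as np

def absmax(x: np.ndarray, axis: int) -> np.ndarray:
    return np.max(np.abs(x), axis=axis)

# absmax(np.array([-3.0, 2.0]), axis=0) == 3.0, whereas the max of the raw values is 2.0
```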

* Enhance unit loop handling by refining annotation checks

- Updated the condition for identifying effectively empty annotations in unit loops to include cases where only the `pragma_unroll_explicit` hint is present.
- Introduced a new method, `IsEffectivelyEmptyAnnotation`, to encapsulate this logic, improving code clarity and maintainability.
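
The check, rendered in Python (the actual helper is C++ inside the lowering pass):

```python
def is_effectively_empty_annotation(annotations: dict) -> bool:
    # Annotations are "effectively empty" when there are none at all, or when the
    # only entry is the default `pragma_unroll_explicit` hint.
    return not annotations or set(annotations) == {"pragma_unroll_explicit"}
```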

* clean code
parent 2b1f5990
...@@ -105,3 +105,6 @@ cmake-build-*/ ...@@ -105,3 +105,6 @@ cmake-build-*/
# Git version for sdist # Git version for sdist
.git_commit.txt .git_commit.txt
# pre-commit cache
.pre-commit-cache/*
...@@ -65,9 +65,50 @@ else() ...@@ -65,9 +65,50 @@ else()
endif() endif()
# Configs # Configs
set(USE_CUDA OFF) set(TILELANG_BACKENDS CUDA ROCM METAL)
set(USE_ROCM OFF)
set(USE_METAL OFF) set(TILELANG_BACKEND_DOC_CUDA "Enable CUDA backend (ON/OFF/or CUDA SDK path)")
set(TILELANG_BACKEND_DOC_ROCM "Enable ROCm backend (ON/OFF/or ROCm SDK path)")
set(TILELANG_BACKEND_DOC_METAL "Enable Metal backend")
# TVM's config.cmake redefines USE_* options later, so we cache the user's choice
# (including explicit -DUSE_XXX arguments) before we include TVM and restore it
# afterwards.
macro(tilelang_define_backend_option BACKEND)
set(_backend_var "USE_${BACKEND}")
set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
set(_user_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")
set(_user_override OFF)
if(DEFINED ${_user_override_var})
set(_user_override "${${_user_override_var}}")
endif()
if(DEFINED CACHE{${_backend_var}})
get_property(_cache_type CACHE ${_backend_var} PROPERTY TYPE)
if(_cache_type STREQUAL "UNINITIALIZED")
set(_user_override ON)
endif()
endif()
set(_default OFF)
if(DEFINED ${_backend_var})
set(_default "${${_backend_var}}")
endif()
option(${_backend_var} "${_doc}" "${_default}")
# Remember if the user explicitly set this option so that later logic
# won't auto-toggle backends they configured on the command line.
set(${_user_override_var} ${_user_override} CACHE INTERNAL
"User explicitly set ${_backend_var} during configuration" FORCE)
set(TILELANG_OPTION_${_backend_var} "${${_backend_var}}")
endmacro()
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
tilelang_define_backend_option(${BACKEND})
endforeach()
set(PREBUILD_CYTHON ON) set(PREBUILD_CYTHON ON)
# Configs end # Configs end
...@@ -78,6 +119,14 @@ if(EXISTS ${TVM_SOURCE}/cmake/config.cmake) ...@@ -78,6 +119,14 @@ if(EXISTS ${TVM_SOURCE}/cmake/config.cmake)
else() else()
message(FATAL_ERROR "Nor tvm provided or submodule checkout-ed.") message(FATAL_ERROR "Nor tvm provided or submodule checkout-ed.")
endif() endif()
# Re-apply TileLang's preferred backend settings after TVM's config may have
# overridden the USE_* cache entries.
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
set(_backend_var "USE_${BACKEND}")
set(_doc "${TILELANG_BACKEND_DOC_${BACKEND}}")
set(${_backend_var} ${TILELANG_OPTION_${_backend_var}} CACHE STRING "${_doc}" FORCE)
set(${_backend_var} ${TILELANG_OPTION_${_backend_var}})
endforeach()
# Include directories for TileLang # Include directories for TileLang
set(TILE_LANG_INCLUDES ${TVM_INCLUDES}) set(TILE_LANG_INCLUDES ${TVM_INCLUDES})
...@@ -95,15 +144,26 @@ file(GLOB TILE_LANG_SRCS ...@@ -95,15 +144,26 @@ file(GLOB TILE_LANG_SRCS
src/target/intrin_rule*.cc src/target/intrin_rule*.cc
) )
# Backend-specific checks and configs # Track if the user explicitly selected a backend via cache options.
if($ENV{USE_METAL}) set(TILELANG_BACKEND_USER_SELECTED OFF)
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
set(_backend_var "USE_${BACKEND}")
set(_override_var "TILELANG_USER_OVERRIDE_${_backend_var}")
if(${_backend_var} OR ${_override_var})
set(TILELANG_BACKEND_USER_SELECTED ON)
endif()
endforeach()
# Only auto-select a backend when the user didn't specify one explicitly.
if(NOT TILELANG_BACKEND_USER_SELECTED)
if($ENV{USE_METAL})
set(USE_METAL ON) set(USE_METAL ON)
elseif(APPLE) elseif(APPLE)
message(STATUS "Enable Metal support by default.") message(STATUS "Enable Metal support by default.")
set(USE_METAL ON) set(USE_METAL ON)
elseif($ENV{USE_ROCM}) elseif($ENV{USE_ROCM})
set(USE_ROCM ON) set(USE_ROCM ON)
else() else()
if($ENV{USE_CUDA}) if($ENV{USE_CUDA})
set(USE_CUDA ON) set(USE_CUDA ON)
elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA}) elseif(DEFINED ENV{USE_CUDA} AND NOT $ENV{USE_CUDA})
...@@ -113,6 +173,7 @@ else() ...@@ -113,6 +173,7 @@ else()
message(STATUS "Enable CUDA support by default.") message(STATUS "Enable CUDA support by default.")
set(USE_CUDA ON) set(USE_CUDA ON)
endif() endif()
endif()
endif() endif()
if(USE_METAL) if(USE_METAL)
...@@ -125,7 +186,7 @@ if(USE_METAL) ...@@ -125,7 +186,7 @@ if(USE_METAL)
elseif(USE_ROCM) elseif(USE_ROCM)
set(CMAKE_HIP_STANDARD 17) set(CMAKE_HIP_STANDARD 17)
include(${TVM_SOURCE}/cmake/utils/FindROCM.cmake) include(${TVM_SOURCE}/cmake/utils/FindROCM.cmake)
find_rocm($ENV{USE_ROCM}) find_rocm(${USE_ROCM})
add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1) add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1)
file(GLOB TILE_LANG_HIP_SRCS file(GLOB TILE_LANG_HIP_SRCS
......
...@@ -81,13 +81,10 @@ def flashattn_fwd( ...@@ -81,13 +81,10 @@ def flashattn_fwd(
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0,
if window_size is not None: (bx * block_M - window_size) // block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages): for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by // groups, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i q_idx = bx * block_M + i
...@@ -266,14 +263,11 @@ def flashattn_bwd(batch, ...@@ -266,14 +263,11 @@ def flashattn_bwd(batch,
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32') loop_ed = T.min(
if window_size is not None: T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
loop_ed[0] = T.min( seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
T.ceildiv((by + 1) * block_M + window_size, block_N),
T.ceildiv(seq_len, block_N)) for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......
...@@ -172,14 +172,11 @@ def flashattn( ...@@ -172,14 +172,11 @@ def flashattn(
end = T.min( end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0, (bx * block_M + past_len - window_size) //
if window_size is not None: block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined( for k in T.Pipelined(
start[0], start,
end, end,
num_stages=num_stages, num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2], order=[-1, 0, 3, 1, -1, 2],
......
...@@ -78,13 +78,10 @@ def flashattn_fwd( ...@@ -78,13 +78,10 @@ def flashattn_fwd(
sinks[i] = Sinks[by] sinks[i] = Sinks[by]
end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) end = T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0,
if window_size is not None: (bx * block_M - window_size) // block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages): for k in T.Pipelined(start, end, num_stages=num_stages):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i q_idx = bx * block_M + i
...@@ -267,14 +264,10 @@ def flashattn_bwd( ...@@ -267,14 +264,10 @@ def flashattn_bwd(
T.clear(dk) T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) loop_st = T.floordiv(by * block_M, block_N)
loop_ed = T.alloc_local([1], 'int32') loop_ed = T.min(
if window_size is not None: T.ceildiv((by + 1) * block_M + window_size, block_N), T.ceildiv(
loop_ed[0] = T.min( seq_len, block_N)) if window_size is not None else T.ceildiv(seq_len, block_N)
T.ceildiv((by + 1) * block_M + window_size, block_N), for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages):
T.ceildiv(seq_len, block_N))
else:
loop_ed[0] = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed[0], num_stages=num_stages):
T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q) T.copy(Q[bz, bx, k * block_N:(k + 1) * block_N, :], q)
T.clear(qkT) T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......
...@@ -162,13 +162,10 @@ def flashattn( ...@@ -162,13 +162,10 @@ def flashattn(
end = T.min( end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0, (bx * block_M + past_len - window_size) //
if window_size is not None: block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined(start[0], end, num_stages=num_stages): for k in T.Pipelined(start, end, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum) logsum)
......
...@@ -165,14 +165,11 @@ def flashattn( ...@@ -165,14 +165,11 @@ def flashattn(
end = T.min( end = T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N)) T.ceildiv(seq_kv, block_N), T.ceildiv((bx + 1) * block_M + past_len, block_N))
start = T.alloc_local([1], 'int32') start = T.max(0, (bx * block_M + past_len - window_size) //
if window_size is not None: block_N) if window_size is not None else 0
start[0] = T.max(0, (bx * block_M + past_len - window_size) // block_N)
else:
start[0] = 0
for k in T.Pipelined( for k in T.Pipelined(
start[0], start,
end, end,
num_stages=num_stages, num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2], order=[-1, 0, 3, 1, -1, 2],
......
# ruff: noqa # ruff: noqa
import tilelang.testing import tilelang.testing
from topk_selector import test_topk_selector import topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer import fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd import sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined import sparse_mla_fwd_pipelined
from sparse_mla_bwd import test_sparse_mla_bwd import sparse_mla_bwd
def test_example_topk_selector(): def test_example_topk_selector():
test_topk_selector() topk_selector.test_topk_selector()
def test_example_fp8_lighting_indexer(): def test_example_fp8_lighting_indexer():
test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1) fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd(): def test_example_sparse_mla_fwd():
# small shapes for testing # small shapes for testing
test_sparse_mla_fwd( sparse_mla_fwd.test_sparse_mla_fwd(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
...@@ -28,14 +28,14 @@ def test_example_sparse_mla_fwd(): ...@@ -28,14 +28,14 @@ def test_example_sparse_mla_fwd():
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined(): def test_example_sparse_mla_fwd_pipelined():
# small shapes for testing # small shapes for testing
test_sparse_mla_fwd_pipelined( sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd(): def test_example_sparse_mla_bwd():
test_sparse_mla_bwd( sparse_mla_bwd.test_sparse_mla_bwd(
S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
......
...@@ -80,7 +80,6 @@ def tl_fused_chunk_fwd_kernel( ...@@ -80,7 +80,6 @@ def tl_fused_chunk_fwd_kernel(
T.atomic_add( T.atomic_add(
O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
o_shared) o_shared)
#TODO: consider using vectorized atomic add or tma reduce for sm90
# Output final state # Output final state
T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV]) T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
...@@ -91,6 +90,7 @@ def tl_fused_chunk_fwd_kernel( ...@@ -91,6 +90,7 @@ def tl_fused_chunk_fwd_kernel(
def tl_fused_chunk_fwd(q, k, v): def tl_fused_chunk_fwd(q, k, v):
B, S, H, D = q.shape B, S, H, D = q.shape
kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
print(kernel.get_kernel_source())
o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32) o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
h = kernel(q, k, v, o) h = kernel(q, k, v, o)
return o, h return o, h
......
...@@ -51,13 +51,6 @@ def chunk_retention_fwd_kernel( ...@@ -51,13 +51,6 @@ def chunk_retention_fwd_kernel(
o = T.alloc_fragment([chunk_size, BV], accum_dtype) o = T.alloc_fragment([chunk_size, BV], accum_dtype)
T.clear(h) T.clear(h)
T.annotate_layout({
q: tl.layout.make_swizzled_layout(q),
k: tl.layout.make_swizzled_layout(k),
v: tl.layout.make_swizzled_layout(v),
h_shared: tl.layout.make_swizzled_layout(h_shared),
s_shared: tl.layout.make_swizzled_layout(s_shared),
})
T.use_swizzle(10) T.use_swizzle(10)
for i in T.Pipelined(0, NT): for i in T.Pipelined(0, NT):
......
import tilelang import tilelang
import tilelang.language as T import tilelang.language as T
tilelang.disable_cache()
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
...@@ -52,11 +54,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ...@@ -52,11 +54,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
def main(M=16384, N=16384, K=16384): def main(M=16384, N=16384, K=16384):
tilelang.disable_cache()
block_M = 128 block_M = 128
block_N = 128 block_N = 128
block_K = 64 block_K = 64
jit_kernel = matmul(M, N, K, block_M, block_N, block_K) jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
print(jit_kernel.get_kernel_source())
import torch import torch
a = torch.randn(M, K, device="cuda", dtype=torch.float16) a = torch.randn(M, K, device="cuda", dtype=torch.float16)
......
...@@ -46,8 +46,7 @@ def matmul( ...@@ -46,8 +46,7 @@ def matmul(
T.copy(B[bx * block_N, k * block_K], B_shared) T.copy(B[bx * block_N, k * block_K], B_shared)
else: else:
T.copy(B[k * block_K, bx * block_N], B_shared) T.copy(B[k * block_K, bx * block_N], B_shared)
# T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.gemm_v2(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N]) T.copy(C_local, C[by * block_M, bx * block_N])
return main return main
...@@ -103,9 +102,11 @@ def run_gemm( ...@@ -103,9 +102,11 @@ def run_gemm(
block_M, block_M,
block_N, block_N,
block_K, block_K,
num_stages=3, num_stages=2,
num_threads=128, num_threads=128,
): ):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul( program = matmul(
M, M,
N, N,
...@@ -189,9 +190,11 @@ def run_gemm_rs( ...@@ -189,9 +190,11 @@ def run_gemm_rs(
block_M, block_M,
block_N, block_N,
block_K, block_K,
num_stages=3, num_stages=2,
num_threads=128, num_threads=128,
): ):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul_rs( program = matmul_rs(
M, M,
N, N,
...@@ -273,9 +276,11 @@ def run_gemm_sr( ...@@ -273,9 +276,11 @@ def run_gemm_sr(
block_M, block_M,
block_N, block_N,
block_K, block_K,
num_stages=3, num_stages=2,
num_threads=128, num_threads=128,
): ):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul_sr( program = matmul_sr(
M, M,
N, N,
...@@ -361,9 +366,11 @@ def run_gemm_rr( ...@@ -361,9 +366,11 @@ def run_gemm_rr(
block_M, block_M,
block_N, block_N,
block_K, block_K,
num_stages=3, num_stages=2,
num_threads=128, num_threads=128,
): ):
if block_N >= 256 or block_M >= 256 or block_K >= 256:
num_stages = 0
program = matmul_rr( program = matmul_rr(
M, M,
N, N,
...@@ -429,51 +436,51 @@ def _ensure_torch_dtypes(*dtype_names): ...@@ -429,51 +436,51 @@ def _ensure_torch_dtypes(*dtype_names):
def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128) run_gemm_rs(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)
def run_gemm_rs_false_false(m, n, k): def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128) run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rs_true_false(m, n, k): def run_gemm_rs_true_false(m, n, k):
run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128) run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rs_true_true(m, n, k): def run_gemm_rs_true_true(m, n, k):
run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128) run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_sr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128) run_gemm_sr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)
def run_gemm_sr_false_false(m, n, k): def run_gemm_sr_false_false(m, n, k):
run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128) run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
def run_gemm_sr_true_false(m, n, k): def run_gemm_sr_true_false(m, n, k):
run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128) run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
def run_gemm_sr_true_true(m, n, k): def run_gemm_sr_true_true(m, n, k):
run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128) run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
run_gemm_rr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k, 2, 128) run_gemm_rr(m, n, k * 3, False, True, in_dtype, out_dtype, accum_dtype, m, n, k)
def run_gemm_rr_false_false(m, n, k): def run_gemm_rr_false_false(m, n, k):
run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128) run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rr_true_false(m, n, k): def run_gemm_rr_true_false(m, n, k):
run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128) run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
def run_gemm_rr_true_true(m, n, k): def run_gemm_rr_true_true(m, n, k):
run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128) run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
TRANS_CASES = [ TRANS_CASES = [
...@@ -516,8 +523,6 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): ...@@ -516,8 +523,6 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
m, m,
n, n,
k, k,
2,
128,
) )
...@@ -537,8 +542,6 @@ def test_gemm_false_false(m, n, k): ...@@ -537,8 +542,6 @@ def test_gemm_false_false(m, n, k):
m, m,
n, n,
k, k,
2,
128,
) )
...@@ -558,8 +561,6 @@ def test_gemm_true_false(m, n, k): ...@@ -558,8 +561,6 @@ def test_gemm_true_false(m, n, k):
m, m,
n, n,
k, k,
2,
128,
) )
...@@ -579,8 +580,6 @@ def test_gemm_true_true(m, n, k): ...@@ -579,8 +580,6 @@ def test_gemm_true_true(m, n, k):
m, m,
n, n,
k, k,
2,
128,
) )
...@@ -724,3 +723,13 @@ if __name__ == "__main__": ...@@ -724,3 +723,13 @@ if __name__ == "__main__":
# print(f"======================= Test {m} {n} {k} False True =============================") # print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm_rs(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128) # run_gemm_rs(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass") # print(f"Test {m} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm_rs(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
...@@ -211,7 +211,7 @@ def run_gemm_rs( ...@@ -211,7 +211,7 @@ def run_gemm_rs(
M_VALUES = [64, 128] M_VALUES = [64, 128]
N_VALUES = [16, 32, 64, 128] N_VALUES = [32, 64, 128]
K_VALUES = [16, 32, 64] K_VALUES = [16, 32, 64]
FALSE_TRUE_CASES = ([ FALSE_TRUE_CASES = ([
pytest.param( pytest.param(
......
...@@ -40,7 +40,7 @@ public: ...@@ -40,7 +40,7 @@ public:
.def_ro("n_warp", &GemmWarpPolicyNode::n_warp); .def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
} }
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size, std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
Target target, Target target,
GemmInst gemm_inst) const; GemmInst gemm_inst) const;
...@@ -84,47 +84,47 @@ public: ...@@ -84,47 +84,47 @@ public:
class GemmNode : public TileOperatorNode { class GemmNode : public TileOperatorNode {
public: public:
bool CheckWGMMA() const; bool checkWgmma() const;
tir::Buffer A, B, C; tir::Buffer a_, b_, c_;
// pointer to the A, B, C // BufferRegion for A, B and C
PrimExpr Aptr, Bptr, Cptr; BufferRegion aRegion_, bRegion_, cRegion_;
bool trans_A, trans_B; bool transA_, transB_;
int M, N, K; int m_, n_, k_;
int stride_A, stride_B; int strideA_, strideB_;
int offset_A, offset_B; int offsetA_, offsetB_;
PrimExpr clear_accum = const_false(); PrimExpr clearAccum_ = const_false();
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions // only will be enabled under cdna mfma instructions
int kPack = 1; int kPack_ = 1;
int wg_wait = 0; int wgWait_ = 0;
PrimExpr mbarptr; PrimExpr mbarPtr_;
std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> C_coords; Array<PrimExpr> cCoords_;
mutable GemmWarpPolicy policy; mutable GemmWarpPolicy policy_;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode); TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmNode>() refl::ObjectDef<GemmNode>()
.def_ro("A", &GemmNode::A) .def_ro("a", &GemmNode::a_)
.def_ro("B", &GemmNode::B) .def_ro("b", &GemmNode::b_)
.def_ro("C", &GemmNode::C) .def_ro("c", &GemmNode::c_)
.def_ro("Aptr", &GemmNode::Aptr) .def_ro("aRegion", &GemmNode::aRegion_)
.def_ro("Bptr", &GemmNode::Bptr) .def_ro("bRegion", &GemmNode::bRegion_)
.def_ro("Cptr", &GemmNode::Cptr) .def_ro("cRegion", &GemmNode::cRegion_)
.def_ro("trans_A", &GemmNode::trans_A) .def_ro("transA", &GemmNode::transA_)
.def_ro("trans_B", &GemmNode::trans_B) .def_ro("transB", &GemmNode::transB_)
.def_ro("M", &GemmNode::M) .def_ro("m", &GemmNode::m_)
.def_ro("N", &GemmNode::N) .def_ro("n", &GemmNode::n_)
.def_ro("K", &GemmNode::K) .def_ro("k", &GemmNode::k_)
.def_ro("stride_A", &GemmNode::stride_A) .def_ro("strideA", &GemmNode::strideA_)
.def_ro("stride_B", &GemmNode::stride_B) .def_ro("strideB", &GemmNode::strideB_)
.def_ro("offset_A", &GemmNode::offset_A) .def_ro("offsetA", &GemmNode::offsetA_)
.def_ro("offset_B", &GemmNode::offset_B) .def_ro("offsetB", &GemmNode::offsetB_)
.def_ro("clear_accum", &GemmNode::clear_accum) .def_ro("clearAccum", &GemmNode::clearAccum_)
.def_ro("kPack", &GemmNode::kPack) .def_ro("kPack", &GemmNode::kPack_)
.def_ro("wg_wait", &GemmNode::wg_wait) .def_ro("wgWait", &GemmNode::wgWait_)
.def_ro("policy", &GemmNode::policy); .def_ro("policy", &GemmNode::policy_);
} }
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
...@@ -134,9 +134,9 @@ public: ...@@ -134,9 +134,9 @@ public:
TileOperator Clone() const; TileOperator Clone() const;
private: private:
GemmInst GetGemmInst(int block_size, Target target) const; GemmInst getGemmInst(int block_size, Target target) const;
bool AllowTCGEN5MMA(Target target) const; bool allowTcgen5Mma(Target target) const;
bool AllowWGMMA(int block_size, Target target) const; bool allowWgmma(int block_size, Target target) const;
mutable bool completed_ = false; mutable bool completed_ = false;
}; };
......
...@@ -11,16 +11,102 @@ ...@@ -11,16 +11,102 @@
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../support/ffi_aliases.h"
#include "../target/utils.h" #include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h" #include "tcgen5_meta.h"
#include "tvm/ffi/string.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
// to BufferRegion
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in GEMM region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in GEMM region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap.at(var);
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
throw; // Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
// Compute row-major strides
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
// Extent: last two extents product (elements)
PrimExpr extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
// ptype and return handle
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
/** /**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer * @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map. * map.
...@@ -51,45 +137,42 @@ using namespace tir; ...@@ -51,45 +137,42 @@ using namespace tir;
*/ */
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>(); ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->Aptr = args[0];
node->Bptr = args[1]; node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->Cptr = args[2]; node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->A = vmap[GetVarFromAccessPtr(node->Aptr)]; node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
node->C = vmap[GetVarFromAccessPtr(node->Cptr)]; node->a_ = node->aRegion_->buffer;
node->trans_A = args[3].as<Bool>().value(); node->b_ = node->bRegion_->buffer;
node->trans_B = args[4].as<Bool>().value(); node->c_ = node->cRegion_->buffer;
node->M = args[5].as<IntImm>().value()->value; node->transA_ = args[3].as<Bool>().value();
node->N = args[6].as<IntImm>().value()->value; node->transB_ = args[4].as<Bool>().value();
node->K = args[7].as<IntImm>().value()->value; node->m_ = args[5].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value); node->n_ = args[6].as<IntImm>().value()->value;
node->clear_accum = args[9].as<PrimExpr>().value(); node->k_ = args[7].as<IntImm>().value()->value;
node->stride_A = args[10].as<IntImm>().value()->value; node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->stride_B = args[11].as<IntImm>().value()->value; node->clearAccum_ = args[9].as<PrimExpr>().value();
node->offset_A = args[12].as<IntImm>().value()->value; node->strideA_ = args[10].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value; node->strideB_ = args[11].as<IntImm>().value()->value;
node->offsetA_ = args[12].as<IntImm>().value()->value;
node->offsetB_ = args[13].as<IntImm>().value()->value;
if (args.size() > 14) { if (args.size() > 14) {
node->kPack = args[14].as<IntImm>().value()->value; node->kPack_ = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) { if (node->kPack_ != 1 && node->kPack_ != 2) {
ICHECK(false) << "kPack must be 1 or 2"; ICHECK(false) << "kPack must be 1 or 2";
} }
} }
if (args.size() > 15) { if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value; node->wgWait_ = args[15].as<IntImm>().value()->value;
} }
if (args.size() > 16) { node->mbarPtr_ = args[16];
node->mbarptr = args[16]; if (node->mbarPtr_.as<CallNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
} else { } else {
node->mbarptr = IntImm(DataType::UInt(32), 0); node->mbar_ = std::nullopt;
}
if (args.size() > 18) {
node->C_coords = Array<PrimExpr>({args[17], args[18]});
} else if (args.size() > 17) {
node->C_coords = Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)});
} else {
node->C_coords = Array<PrimExpr>(
{IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)});
} }
node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
data_ = std::move(node); data_ = std::move(node);
} }
...@@ -106,28 +189,28 @@ TileOperator GemmPyNode::Clone() const { ...@@ -106,28 +189,28 @@ TileOperator GemmPyNode::Clone() const {
return GemmPy(op); return GemmPy(op);
} }
bool GemmPyNode::AllowTCGEN5MMA(Target target) const { bool GemmPyNode::allowTcgen5Mma(Target target) const {
return TargetIsSm100(target) && return TargetIsSm100(target) &&
((A.scope() == "shared.dyn" || A.scope() == "shared" || ((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
A.scope() == "shared.tmem") && a_.scope() == "shared.tmem") &&
(B.scope() == "shared.dyn" || B.scope() == "shared") && (b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
C.scope() == "shared.tmem") && c_.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first; GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
} }
bool GemmPyNode::AllowWGMMA(int block_size, Target target) const { bool GemmPyNode::allowWgmma(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
int warp_size = TargetGetWarpSize(target); int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size; int num_warps = block_size / warp_size;
return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) && return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) && TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA(); checkWgmma();
} }
GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const { GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target); bool allow_tcgen5mma = allowTcgen5Mma(target);
bool allow_wgmma = AllowWGMMA(block_size, target); bool allow_wgmma = allowWgmma(block_size, target);
if (allow_tcgen5mma) { if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA; return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) { } else if (allow_wgmma) {
...@@ -175,51 +258,52 @@ GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const { ...@@ -175,51 +258,52 @@ GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
* @return true if WGMMA is supported for the current buffers, dtypes, and * @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise. * transpose/shape constraints; false otherwise.
*/ */
bool GemmPyNode::CheckWGMMA() const { bool GemmPyNode::checkWgmma() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") { if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
return false; return false;
} }
if (C->dtype == DataType::Float(16)) { if (c_->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return K % 16 == 0; return k_ % 16 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
} else if (C->dtype == DataType::Float(32)) { } else if (c_->dtype == DataType::Float(32)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return K % 16 == 0; return k_ % 16 == 0;
else if (A->dtype == DataType::BFloat(16) && else if (a_->dtype == DataType::BFloat(16) &&
B->dtype == DataType::BFloat(16)) b_->dtype == DataType::BFloat(16))
return K % 16 == 0; return k_ % 16 == 0;
else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32)) else if (a_->dtype == DataType::Float(32) &&
return (!trans_A) && trans_B && K % 8 == 0; b_->dtype == DataType::Float(32))
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3()) return (!transA_) && transB_ && k_ % 8 == 0;
return (!trans_A) && trans_B && K % 32 == 0; else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2()) return (!transA_) && transB_ && k_ % 32 == 0;
return (!trans_A) && trans_B && K % 32 == 0; else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3()) return (!transA_) && transB_ && k_ % 32 == 0;
return (!trans_A) && trans_B && K % 32 == 0; else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2()) return (!transA_) && transB_ && k_ % 32 == 0;
return (!trans_A) && trans_B && K % 32 == 0; else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
} else if (C->dtype == DataType::Int(32)) { } else if (c_->dtype == DataType::Int(32)) {
if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8)) if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8)) else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8)) else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8)) else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
} else { } else {
...@@ -256,10 +340,10 @@ static int GetArchInt(Target target) { ...@@ -256,10 +340,10 @@ static int GetArchInt(Target target) {
Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent); auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target); GemmInst gemm_inst = getGemmInst(block_size, T.target);
auto [warp_m, warp_n] = auto [warp_m, warp_n] =
policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst); policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func = auto prim_func =
...@@ -302,6 +386,14 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, ...@@ -302,6 +386,14 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) { if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
results = Downcast<LayoutMap>( results = Downcast<LayoutMap>(
(*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds)); (*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds));
// Bind all fragment layouts with the provided thread range
for (auto kv : results) {
const Buffer &buf = kv.first;
const Layout &layout = kv.second;
if (auto frag = layout.as<Fragment>()) {
results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds));
}
}
} else { } else {
LOG(FATAL) << "No infer layout function found for gemm_py"; LOG(FATAL) << "No infer layout function found for gemm_py";
} }
...@@ -321,7 +413,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { ...@@ -321,7 +413,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmPyGemmInst", refl::GlobalDef().def("tl.GemmPyGemmInst",
[](GemmPy gemm_py, int block_size, Target target) { [](GemmPy gemm_py, int block_size, Target target) {
return gemm_py->GetGemmInst(block_size, target); return gemm_py->getGemmInst(block_size, target);
}); });
} }
......
...@@ -18,51 +18,52 @@ using namespace tir; ...@@ -18,51 +18,52 @@ using namespace tir;
class GemmPyNode : public TileOperatorNode { class GemmPyNode : public TileOperatorNode {
public: public:
bool CheckWGMMA() const; bool checkWgmma() const;
bool AllowTCGEN5MMA(Target target) const; bool allowTcgen5Mma(Target target) const;
bool AllowWGMMA(int block_size, Target target) const; bool allowWgmma(int block_size, Target target) const;
tir::Buffer A, B, C; tir::Buffer a_, b_, c_;
// pointer to the A, B, C // BufferRegion for A, B and C
PrimExpr Aptr, Bptr, Cptr; BufferRegion aRegion_, bRegion_, cRegion_;
bool trans_A, trans_B; bool transA_, transB_;
int M, N, K; int m_, n_, k_;
int stride_A, stride_B; int strideA_, strideB_;
int offset_A, offset_B; int offsetA_, offsetB_;
PrimExpr clear_accum = const_false(); PrimExpr clearAccum_ = const_false();
PrimExpr mbarptr; PrimExpr mbarPtr_;
Array<PrimExpr> C_coords; std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions // only will be enabled under cdna mfma instructions
int kPack = 1; int kPack_ = 1;
int wg_wait = 0; int wgWait_ = 0;
mutable GemmWarpPolicy policy; mutable GemmWarpPolicy policy_;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode); TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode);
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmPyNode>() refl::ObjectDef<GemmPyNode>()
.def_ro("A", &GemmPyNode::A) .def_ro("a", &GemmPyNode::a_)
.def_ro("B", &GemmPyNode::B) .def_ro("b", &GemmPyNode::b_)
.def_ro("C", &GemmPyNode::C) .def_ro("c", &GemmPyNode::c_)
.def_ro("Aptr", &GemmPyNode::Aptr) .def_ro("aRegion", &GemmPyNode::aRegion_)
.def_ro("Bptr", &GemmPyNode::Bptr) .def_ro("bRegion", &GemmPyNode::bRegion_)
.def_ro("Cptr", &GemmPyNode::Cptr) .def_ro("cRegion", &GemmPyNode::cRegion_)
.def_ro("trans_A", &GemmPyNode::trans_A) .def_ro("transA", &GemmPyNode::transA_)
.def_ro("trans_B", &GemmPyNode::trans_B) .def_ro("transB", &GemmPyNode::transB_)
.def_ro("M", &GemmPyNode::M) .def_ro("m", &GemmPyNode::m_)
.def_ro("N", &GemmPyNode::N) .def_ro("n", &GemmPyNode::n_)
.def_ro("K", &GemmPyNode::K) .def_ro("k", &GemmPyNode::k_)
.def_ro("stride_A", &GemmPyNode::stride_A) .def_ro("strideA", &GemmPyNode::strideA_)
.def_ro("stride_B", &GemmPyNode::stride_B) .def_ro("strideB", &GemmPyNode::strideB_)
.def_ro("offset_A", &GemmPyNode::offset_A) .def_ro("offsetA", &GemmPyNode::offsetA_)
.def_ro("offset_B", &GemmPyNode::offset_B) .def_ro("offsetB", &GemmPyNode::offsetB_)
.def_ro("clear_accum", &GemmPyNode::clear_accum) .def_ro("clearAccum", &GemmPyNode::clearAccum_)
.def_ro("mbarptr", &GemmPyNode::mbarptr) .def_ro("mbarPtr", &GemmPyNode::mbarPtr_)
.def_ro("C_coords", &GemmPyNode::C_coords) .def_ro("cCoords", &GemmPyNode::cCoords_)
.def_ro("kPack", &GemmPyNode::kPack) .def_ro("kPack", &GemmPyNode::kPack_)
.def_ro("wg_wait", &GemmPyNode::wg_wait) .def_ro("wgWait", &GemmPyNode::wgWait_)
.def_ro("policy", &GemmPyNode::policy); .def_ro("policy", &GemmPyNode::policy_);
} }
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
...@@ -72,7 +73,7 @@ public: ...@@ -72,7 +73,7 @@ public:
TileOperator Clone() const; TileOperator Clone() const;
// Target GEMM instruction // Target GEMM instruction
GemmInst GetGemmInst(int block_size, Target target) const; GemmInst getGemmInst(int block_size, Target target) const;
private: private:
mutable bool completed_ = false; mutable bool completed_ = false;
......
@@ -18,14 +18,14 @@
 namespace tvm {
 namespace tl {
-std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
+std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N,
                                                                int block_size,
                                                                Target target,
                                                                bool use_wgmma,
                                                                int bits) const {
   int num_warps = block_size / TargetGetWarpSize(target);
-  auto [m_warp, n_warp] = GemmWarpPolicyNode::ComputeWarpPartition(
+  auto [m_warp, n_warp] = GemmWarpPolicyNode::computeWarpPartition(
       M, N, block_size, target, use_wgmma ? GemmInst::kWGMMA : GemmInst::kMMA);
   // Special handling for gemm_sp when the tiling size is not a multiple
@@ -85,25 +85,25 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
  */
 GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
-  node->A = vmap[GetVarFromAccessPtr(args[0])];
-  node->E = vmap[GetVarFromAccessPtr(args[1])];
-  node->B = vmap[GetVarFromAccessPtr(args[2])];
-  node->C = vmap[GetVarFromAccessPtr(args[3])];
-  node->trans_A = args[4].as<Bool>().value();
-  node->trans_B = args[5].as<Bool>().value();
-  node->M = args[6].as<IntImm>().value()->value;
-  node->N = args[7].as<IntImm>().value()->value;
-  node->K = args[8].as<IntImm>().value()->value;
-  node->policy = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
-  node->clear_accum = args[10].as<Bool>().value();
+  node->a_ = vmap[GetVarFromAccessPtr(args[0])];
+  node->e_ = vmap[GetVarFromAccessPtr(args[1])];
+  node->b_ = vmap[GetVarFromAccessPtr(args[2])];
+  node->c_ = vmap[GetVarFromAccessPtr(args[3])];
+  node->transA_ = args[4].as<Bool>().value();
+  node->transB_ = args[5].as<Bool>().value();
+  node->m_ = args[6].as<IntImm>().value()->value;
+  node->n_ = args[7].as<IntImm>().value()->value;
+  node->k_ = args[8].as<IntImm>().value()->value;
+  node->policy_ = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
+  node->clearAccum_ = args[10].as<Bool>().value();
   if (args.size() > 11) {
-    node->kPack = args[11].as<IntImm>().value()->value;
-    if (node->kPack != 1 && node->kPack != 2) {
+    node->kPack_ = args[11].as<IntImm>().value()->value;
+    if (node->kPack_ != 1 && node->kPack_ != 2) {
       ICHECK(false) << "kPack must be 1 or 2";
     }
   }
   if (args.size() > 12) {
-    node->wg_wait = args[12].as<IntImm>().value()->value;
+    node->wgWait_ = args[12].as<IntImm>().value()->value;
   }
   data_ = std::move(node);
 }
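For reference, the positional argument layout consumed by the rewritten GemmSP constructor, read directly from the indices above (the TIR call site that produces these args is not part of this diff):
// args[0..3]  -> a_, e_, b_, c_    (buffers resolved via vmap[GetVarFromAccessPtr(...)])
// args[4..5]  -> transA_, transB_  (Bool)
// args[6..8]  -> m_, n_, k_        (IntImm)
// args[9]     -> policy_           (IntImm wrapped into GemmSPWarpPolicy)
// args[10]    -> clearAccum_       (Bool)
// args[11]    -> kPack_            (optional; must be 1 or 2)
// args[12]    -> wgWait_           (optional)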
...@@ -144,37 +144,37 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -144,37 +144,37 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32; int warp_size = 32;
auto block_size = *as_const_int(T.thread_bounds->extent); auto block_size = *as_const_int(T.thread_bounds->extent);
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) && bool maybe_wgmma = TargetIsHopper(T.target) && (this->m_ >= 64) &&
(block_size / warp_size % 4 == 0); (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] = policy->ComputeWarpPartition( auto [warp_m, warp_n] = policy_->computeWarpPartition(
M, N, block_size, T.target, maybe_wgmma, A->dtype.bits()); m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits());
std::stringstream ss; std::stringstream ss;
std::string op_name = "tl::gemm_sp_ss"; std::string op_name = "tl::gemm_sp_ss";
ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") && ICHECK((a_.scope() == "shared" || a_.scope() == "shared.dyn") &&
(B.scope() == "shared" || B.scope() == "shared.dyn")) (b_.scope() == "shared" || b_.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for A and B, but received " << A.scope() << "Only support shared.dyn scope for A and B, but received "
<< " and " << B.scope(); << a_.scope() << " and " << b_.scope();
ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn")) ICHECK((e_.scope() == "shared" || e_.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for E as copy from smem to rmem are " << "Only support shared.dyn scope for E as copy from smem to rmem are "
"delegated to cute implementation, found " "delegated to cute implementation, found "
<< E.scope(); << e_.scope();
ss << op_name << "<" << M << ", " << N << ", " << K << ", "; ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
ss << warp_m << ", " << warp_n << ", "; ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B; ss << transA_ << ", " << transB_;
ss << ", " << clear_accum; ss << ", " << clearAccum_;
if (TargetIsHopper(T.target)) { if (TargetIsHopper(T.target)) {
ss << ", " << (maybe_wgmma ? "true" : "false"); ss << ", " << (maybe_wgmma ? "true" : "false");
} }
if (wg_wait != 0) { if (wgWait_ != 0) {
ss << ", " << wg_wait; ss << ", " << wgWait_;
} }
ss << ">"; ss << ">";
auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A; auto A_buffer = T.buffer_remap.count(a_) ? T.buffer_remap[a_] : a_;
auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B; auto B_buffer = T.buffer_remap.count(b_) ? T.buffer_remap[b_] : b_;
auto C_buffer = T.buffer_remap[C]; auto C_buffer = T.buffer_remap[c_];
auto E_buffer = T.buffer_remap.count(E) ? T.buffer_remap[E] : E; auto E_buffer = T.buffer_remap.count(e_) ? T.buffer_remap[e_] : e_;
auto new_call = auto new_call =
Call(DataType::Handle(), tl::tl_gemm_sp(), Call(DataType::Handle(), tl::tl_gemm_sp(),
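For illustration, the instantiation string assembled in the hunk above would read as follows under hypothetical values m_ = 128, n_ = 128, k_ = 64, warp_m = 2, warp_n = 2, transA_ = 0, transB_ = 1, clearAccum_ = 0 on a Hopper target with maybe_wgmma = true and wgWait_ = 0 (bools stream as 0/1; the trailing wait argument is omitted because wgWait_ == 0):
// tl::gemm_sp_ss<128, 128, 64, 2, 2, 0, 1, 0, true>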
@@ -217,59 +217,59 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
   if (completed_)
     return {};
   LayoutMap results;
-  ICHECK(C.scope() == "local.fragment");
+  ICHECK(c_.scope() == "local.fragment");
   auto thread_range = T.thread_bounds;
   auto block_size = *as_const_int(thread_range->extent);
   if (TargetIsHopper(T.target)) {
     const int warp_size = 32;
     constexpr int wgmma_m = 16 * 4;
     bool maybe_wgmma =
-        (this->M >= wgmma_m) && (block_size / warp_size % 4 == 0);
-    auto [warp_m, warp_n] = policy->ComputeWarpPartition(
-        M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());
-    auto fragment =
-        maybe_wgmma
-            ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
-                                      C->dtype.bits())
-            : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment->BindThreadRange(thread_range));
-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
-      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
-      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
-      results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous,
-                                            mat_continuous, A->dtype.bits(),
-                                            trans_A ? 1 : 2));
+        (this->m_ >= wgmma_m) && (block_size / warp_size % 4 == 0);
+    auto [warp_m, warp_n] = policy_->computeWarpPartition(
+        m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits());
+    auto fragment = maybe_wgmma
+                        ? makeGemmFragmentCHopper(m_, n_, m_ / warp_m,
+                                                  n_ / warp_n, c_->dtype.bits())
+                        : makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
+                                            c_->dtype.bits());
+    results.Set(c_, fragment->BindThreadRange(thread_range));
+    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
+      int dim_A = a_->shape.size();
+      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
+      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
+      results.Set(a_, makeGemmABLayoutHopper(mat_stride, mat_continuous,
+                                             mat_continuous, a_->dtype.bits(),
+                                             transA_ ? 1 : 2));
     } else {
       ICHECK(false) << "Not implemented";
     }
-    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      int dim_B = B->shape.size();
-      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
-      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
+    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
+      int dim_B = b_->shape.size();
+      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
+      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
       const int64_t continuity =
-          trans_B ? mat_continuous : mat_continuous / warp_n;
-      results.Set(B,
-                  makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
-                                         B->dtype.bits(), trans_B ? 2 : 1));
+          transB_ ? mat_continuous : mat_continuous / warp_n;
+      results.Set(b_,
+                  makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
+                                         b_->dtype.bits(), transB_ ? 2 : 1));
     } else {
       ICHECK(false) << "WGMMA only support B in shared.";
     }
   } else if (TargetIsAmpere(T.target)) {
-    auto [warp_m, warp_n] = policy->ComputeWarpPartition(
-        M, N, block_size, T.target, false, A->dtype.bits());
-    auto fragment =
-        makeGemmSparseFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment->BindThreadRange(thread_range));
-    if (A.scope() == "shared" || A.scope() == "shared.dyn") {
-      int dim_A = A->shape.size();
-      const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
-      const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
-      results.Set(A, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                  A->dtype.bits()));
-    } else if (A.scope() == "local.fragment") {
+    auto [warp_m, warp_n] = policy_->computeWarpPartition(
+        m_, n_, block_size, T.target, false, a_->dtype.bits());
+    auto fragment = makeGemmSparseFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
+                                            c_->dtype.bits());
+    results.Set(c_, fragment->BindThreadRange(thread_range));
+    if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
+      int dim_A = a_->shape.size();
+      const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
+      const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
+      results.Set(a_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
+                                                   a_->dtype.bits()));
+    } else if (a_.scope() == "local.fragment") {
       // auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
       //                                   A->dtype.bits(), trans_A);
       // results.Set(A, fragment->BindThreadRange(thread_range));
@@ -277,13 +277,13 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
     } else {
      ICHECK(0);
    }
-    if (B.scope() == "shared" || B.scope() == "shared.dyn") {
-      int dim_B = B->shape.size();
-      const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
-      const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
-      results.Set(B, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
-                                                  B->dtype.bits()));
-    } else if (B.scope() == "local.fragment") {
+    if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
+      int dim_B = b_->shape.size();
+      const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
+      const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
+      results.Set(b_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
+                                                   b_->dtype.bits()));
+    } else if (b_.scope() == "local.fragment") {
       // auto fragment =
       //     makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
       // results.Set(B, fragment->BindThreadRange(thread_range));
...
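As a concrete, purely hypothetical example of the Hopper layout path above: for a shared-memory a_ of shape [64, 128] holding 16-bit elements with transA_ = false, the values fed into the layout call would be
// mat_stride     = 64   (a_->shape[dim_A - 2])
// mat_continuous = 128  (a_->shape[dim_A - 1])
// makeGemmABLayoutHopper(64, 128, 128, 16, /*transA_ ? 1 : 2*/ 2)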
@@ -18,7 +18,7 @@ using namespace tir;
 class GemmSPWarpPolicyNode : public GemmWarpPolicyNode {
 public:
-  std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
+  std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
                                            Target target, bool use_wgmma,
                                            int bits) const;
   TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
@@ -53,16 +53,16 @@ public:
 class GemmSPNode : public TileOperatorNode {
 public:
-  tir::Buffer A, B, C, E;
-  bool trans_A, trans_B;
-  int M, N, K;
-  bool clear_accum = false;
+  tir::Buffer a_, b_, c_, e_;
+  bool transA_, transB_;
+  int m_, n_, k_;
+  bool clearAccum_ = false;
   // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
   // only will be enabled under cdna mfma instructions
-  int kPack = 1;
-  int wg_wait = 0;
-  mutable GemmSPWarpPolicy policy;
+  int kPack_ = 1;
+  int wgWait_ = 0;
+  mutable GemmSPWarpPolicy policy_;
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSP", GemmSPNode, TileOperatorNode);
   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
@@ -74,19 +74,19 @@ public:
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<GemmSPNode>()
-        .def_ro("policy", &GemmSPNode::policy)
-        .def_ro("A", &GemmSPNode::A)
-        .def_ro("B", &GemmSPNode::B)
-        .def_ro("C", &GemmSPNode::C)
-        .def_ro("E", &GemmSPNode::E)
-        .def_ro("trans_A", &GemmSPNode::trans_A)
-        .def_ro("trans_B", &GemmSPNode::trans_B)
-        .def_ro("M", &GemmSPNode::M)
-        .def_ro("N", &GemmSPNode::N)
-        .def_ro("K", &GemmSPNode::K)
-        .def_ro("clear_accum", &GemmSPNode::clear_accum)
-        .def_ro("kPack", &GemmSPNode::kPack)
-        .def_ro("wg_wait", &GemmSPNode::wg_wait);
+        .def_ro("policy", &GemmSPNode::policy_)
+        .def_ro("a", &GemmSPNode::a_)
+        .def_ro("b", &GemmSPNode::b_)
+        .def_ro("c", &GemmSPNode::c_)
+        .def_ro("e", &GemmSPNode::e_)
+        .def_ro("transA", &GemmSPNode::transA_)
+        .def_ro("transB", &GemmSPNode::transB_)
+        .def_ro("m", &GemmSPNode::m_)
+        .def_ro("n", &GemmSPNode::n_)
+        .def_ro("k", &GemmSPNode::k_)
+        .def_ro("clearAccum", &GemmSPNode::clearAccum_)
+        .def_ro("kPack", &GemmSPNode::kPack_)
+        .def_ro("wgWait", &GemmSPNode::wgWait_);
   }
 private:
...
@@ -39,7 +39,6 @@ struct LowerArgs {
   AddWorkspaceCallback AddWorkspace;
   LayoutMap layout_map;
   Map<Buffer, Buffer> buffer_remap;
-  Array<Var> buffer_var_gemm;
 };
 struct LayoutInferArgs {
...