Unverified commit a13cde28, authored by Lei Wang, committed by GitHub

[TileOp] Implement WGMMA for T.gemm_v2 (#813)

* [Feature] Introduce WGMMA support and enhance GEMM layout handling

- Added support for the WGMMA intrinsic in the TileLang framework, enabling efficient matrix multiplication on Hopper-class (sm_90) architectures.
- Refactored GEMM layout functions to accept a boolean parameter for K dimension handling, improving flexibility in layout generation.
- Updated layout inference logic to accommodate new WGMMA configurations and ensure compatibility with existing GEMM operations.
- Enhanced Python bindings for layout functions, allowing for better integration and usability in user-defined operations.
- Improved documentation for layout functions and GEMM operations to clarify usage and parameters.

These changes enhance the performance and usability of GEMM operations, particularly for advanced architectures, while maintaining backward compatibility with existing implementations.
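
For orientation, here is a hedged sketch of the user-level path this enables. It assumes `T.gemm_v2` accepts the same operands as the existing `T.gemm`; the tile sizes are illustrative only:

```python
# Minimal sketch, not the canonical example from this PR: a tiled FP16 matmul
# whose T.gemm_v2 call should lower to WGMMA on Hopper-class targets.
import tilelang
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                # Both operands staged in shared memory -> the "gemm_ss" WGMMA path.
                T.gemm_v2(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```

With 128 threads (one warpgroup) and both operands in shared memory, this call should take the `gemm_ss` WGMMA path implemented below.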

* [Refactor] Clean up code formatting and enhance layout function readability

- Improved code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated layout function signatures to enhance clarity, particularly in `gemm_layouts.cc`, `layout.cc`, and `layout.h`.
- Refactored lambda functions in `builtin.cc` and `gemm_py.cc` for improved structure and maintainability.
- Enhanced comments and documentation in layout-related files to clarify usage and parameters.

These changes contribute to a cleaner codebase and improved maintainability of layout functions in the TileLang framework.

* [Feature] Add descriptor initialization and offset manipulation for WGMMA

- Introduced new TileLang builtins `initialize_descriptor` and `increase_descriptor_offset` to facilitate descriptor management for WGMMA operations.
- Updated `builtin.cc` and `builtin.h` to define and document the new builtins, enhancing the framework's capabilities for descriptor handling.
- Modified `codegen_cuda.cc` and `ptx.cc` to integrate the new builtins into the code generation process, ensuring proper assembly generation for WGMMA operations.
- Enhanced the `GemmWGMMA` class to utilize the new descriptor functionalities, improving the efficiency of matrix multiplication operations.
- Updated related tests and documentation to reflect the new features and ensure comprehensive coverage.

These changes enhance the TileLang framework's support for advanced matrix operations on newer architectures, improving performance and usability.
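
The exact signatures of `initialize_descriptor` and `increase_descriptor_offset` live in `builtin.py`/`builtin.h`; as background, here is a hedged Python sketch of the 64-bit shared-memory matrix descriptor that WGMMA consumes. The field widths follow the PTX ISA encoding, but the helper names are invented for illustration:

```python
# Hypothetical helper mirroring the PTX shared-memory matrix descriptor:
# bits [0,14)  : matrix start address in shared memory, encoded as (addr >> 4)
# bits [16,30) : leading dimension byte offset, encoded as (lbo >> 4)
# bits [32,46) : stride dimension byte offset, encoded as (sbo >> 4)
# bits [62,64) : swizzle mode (0 = none, 1 = 128B, 2 = 64B, 3 = 32B)
def make_smem_descriptor(addr: int, lbo: int, sbo: int, swizzle_mode: int) -> int:
    desc = (addr >> 4) & 0x3FFF
    desc |= ((lbo >> 4) & 0x3FFF) << 16
    desc |= ((sbo >> 4) & 0x3FFF) << 32
    desc |= (swizzle_mode & 0x3) << 62
    return desc


# Because the address field sits at bit 0, bumping the descriptor to the next
# tile can be as cheap as an integer add (assuming no carry into bits 14-15):
def increase_descriptor_offset_sketch(desc: int, byte_offset: int) -> int:
    return desc + (byte_offset >> 4)
```

In the actual builtins these operations are emitted on the device side during code generation; the sketch only illustrates the bit layout they manage.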

* [Refactor] Improve code formatting and readability in various files

- Enhanced code formatting across multiple files for better readability, including consistent indentation and line breaks.
- Updated function signatures and comments in `builtin.h`, `codegen_cuda.cc`, and `ptx.cc` to improve clarity.
- Refactored descriptor initialization and offset manipulation functions in `builtin.py` and `wgmma_macro_generator.py` for improved structure.
- Cleaned up unnecessary whitespace and improved alignment in `common.h` and `allocate.py`.

These changes contribute to a cleaner and more maintainable codebase in the TileLang framework.

* [Update] Update subproject commit and refactor layout function call

- Updated the subproject commit for `cutlass` to indicate a dirty state.
- Refactored the `UpdateAnalyzer` function in `layout.cc` to call `LayoutNode::getVarMap()` instead of `getVarMap()`, improving clarity and ensuring proper context for variable mapping.

These changes enhance the maintainability and clarity of the layout handling in the TileLang framework.

* support more data types

* gemm_rs support

* lint fix

* wgmma wrapper

* Remove debug logging for wgmma assembly code and refactor the swizzle byte-size calculations in the wgmma macro generator. Enhance the handling of leading and stride byte offsets based on swizzle mode, improving the clarity and performance of the tensor-core intrinsic emission.
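
As a rough sketch of the byte-offset logic described here (the atom sizes follow the PTX swizzle-mode definitions; the helper name and the fixed K-major constants are assumptions, not the generator's exact code):

```python
# Hedged sketch: the swizzle mode fixes the swizzle-atom width, and the
# descriptor's leading/stride byte offsets are derived from it. The constants
# assume a K-major operand with 8-row core matrices; the real generator
# handles more cases.
SWIZZLE_ATOM_BYTES = {"none": 16, "32B": 32, "64B": 64, "128B": 128}


def descriptor_byte_offsets(swizzle_mode: str) -> tuple[int, int]:
    atom = SWIZZLE_ATOM_BYTES[swizzle_mode]
    lbo = 16        # assumed: one 16B core-matrix row along K
    sbo = 8 * atom  # assumed: eight rows per swizzle atom in the strided dim
    return lbo, sbo
```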

* Refactor GEMM layout functions to replace 'kfactor' with 'k_inner' for improved clarity and consistency, updating the corresponding error messages for the Hopper and Sm100 layouts. Additionally, include a new header for CUTE utilities in common.h.

* Comprehensively support WGMMA GEMM SS

* remove debug print

* lint fix

* remove debug print

* reduce bwd test shape

* lint fix

* clear cache for pytest

* lint fix

* Update sparse MLA examples to support SKV adjustment and correctness checks

- Changed SKV parameter from 32768 to 8192 in sparse MLA backward and forward tests.
- Added check_correctness parameter to test functions for validation of outputs.
- Updated test cases to reflect new SKV values and correctness checks.

* test fix

* adjust test case

* test fix

* skip some tests for now
parent 10adb79f
@@ -5,6 +5,7 @@ from tvm import tir
 from tilelang.utils.language import is_shared, is_fragment
 from tilelang.ir import GemmWarpPolicy
 from tvm.ir.base import Node
+from tvm.ir import PrimExpr
 
 
 @dataclass
@@ -103,7 +104,7 @@ class GemmBase(object):
         return self.gemm_node.offset_B
 
     @property
-    def clear_accum(self) -> bool:
+    def clear_accum(self) -> PrimExpr:
         return self.gemm_node.clear_accum
 
     @property
...
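
The `bool` → `PrimExpr` widening of `clear_accum` above is what lets the accumulator reset become a runtime condition rather than a compile-time flag. A hedged sketch of the kind of call this enables, assuming `T.gemm_v2` forwards the argument (the loop is a fragment in the style of the first example):

```python
# Inside a pipelined reduction loop: clear the accumulator only on the
# first K iteration, expressed as a PrimExpr instead of a Python bool.
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
    T.copy(A[by * block_M, k * block_K], A_shared)
    T.copy(B[k * block_K, bx * block_N], B_shared)
    T.gemm_v2(A_shared, B_shared, C_local, clear_accum=(k == 0))
```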
...@@ -57,7 +57,7 @@ class GemmMMA(GemmBase): ...@@ -57,7 +57,7 @@ class GemmMMA(GemmBase):
raise ValueError( raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}") f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
def lower(self, target: Target, thread_nums: int, thread_var: tir.Var): def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
False) False)
warp_row_tiles = int(self.M // m_warp) warp_row_tiles = int(self.M // m_warp)
@@ -87,6 +87,8 @@ class GemmMMA(GemmBase):
         B_shared = self.B
         C_local = self.C
 
+        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+
         if self.is_gemm_ss():
 
             @T.prim_func
...
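
Both `GemmMMA.lower` above and `GemmWGMMA.lower` in the new file below now take a `layout_map`. The intended flow, sketched under the assumption that the surrounding lowering pass drives it this way:

```python
# Hypothetical driver flow: infer_layout produces buffer -> layout bindings,
# and lower consumes them so the emitter agrees with the actual placement.
layout_map = gemm_op.infer_layout(target, thread_nums)
stmt = gemm_op.lower(layout_map, target, thread_nums, thread_var)
```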
from .gemm_base import GemmBase
from tilelang.layout import make_wgmma_swizzled_layout
from tilelang.intrinsics.wgmma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify


class GemmWGMMA(GemmBase):
    def infer_layout(self, target: Target, thread_nums: int):
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                            True)
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        mma_emitter = TensorCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=self.chunk,
        )

        # A is K-major unless it is transposed; B is K-major only when transposed.
        a_is_k_major = not self.trans_A
        b_is_k_major = self.trans_B
        if self.is_gemm_ss():
            a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp
            b_continuity = self.K if b_is_k_major else self.N // n_warp
            return {
                # WGMMA does not support padding
                self.A:
                    make_wgmma_swizzled_layout(
                        self.A, continuity=a_continuity, k_major=a_is_k_major),
                self.B:
                    make_wgmma_swizzled_layout(
                        self.B, continuity=b_continuity, k_major=b_is_k_major),
                self.C:
                    mma_emitter.make_mma_store_layout(self.C),
            }
        elif self.is_gemm_rs():
            b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp
            return {
                # A stays in registers; only B needs a swizzled shared layout.
                self.A:
                    mma_emitter.make_mma_load_layout(self.A, matrix="A"),
                self.B:
                    make_wgmma_swizzled_layout(
                        self.B, continuity=b_continuity, k_major=b_is_k_major),
                self.C:
                    mma_emitter.make_mma_store_layout(self.C),
            }
        else:
            raise ValueError(
                f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
        m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                            True)
        warp_row_tiles = int(self.M // m_warp)
        warp_col_tiles = int(self.N // n_warp)
        mma_emitter = TensorCoreIntrinEmitter(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=warp_row_tiles,
            warp_col_tiles=warp_col_tiles,
            chunk=self.chunk,
            thread_var=thread_var,
        )

        # Reuse the swizzled layouts inferred earlier so that the generated
        # WGMMA descriptors match the actual shared-memory placement.
        if self.A in layout_map:
            mma_emitter._assign_a_shared_layout(layout_map[self.A])
        if self.B in layout_map:
            mma_emitter._assign_b_shared_layout(layout_map[self.B])

        A_shared = self.A
        B_shared = self.B
        C_local = self.C
        clear_accum = self.clear_accum

        if self.is_gemm_ss():

            @T.prim_func
            def _gemm_ssr() -> None:
                """
                The inner macro that issues Tensor Core wgmma ops on the shared
                buffers A_shared and B_shared, accumulating into C_local.
                """
                # Perform Matrix Multiplication
                mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_ssr, inline_let=True)
        elif self.is_gemm_rs():
            A_local = self.A

            @T.prim_func
            def _gemm_rsr() -> None:
                """
                The inner macro that issues Tensor Core wgmma ops on the local
                fragment A_local and the shared buffer B_shared, accumulating
                into C_local.
                """
                mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum)

            # Simplify to optimize the index computing
            # Must inline let statements to simplify the analysis
            return _Simplify(_gemm_rsr, inline_let=True)
        raise ValueError(
            f"Unsupported gemm combination for wgmma, A: {self.A.scope()}, B: {self.B.scope()}")
    def is_gemm_ss(self) -> bool:
        return is_shared(self.A) and is_shared(self.B)

    def is_gemm_sr(self) -> bool:
        return is_shared(self.A) and is_fragment(self.B)

    def is_gemm_rs(self) -> bool:
        return is_fragment(self.A) and is_shared(self.B)

    def is_gemm_rr(self) -> bool:
        return is_fragment(self.A) and is_fragment(self.B)
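
Finally, a hedged end-to-end sketch that reuses the `matmul` builder from the first example; `tilelang.compile` is the project's usual JIT entry point, but the shapes and tolerances here are illustrative:

```python
import torch
import tilelang

# Compile for the current CUDA device; on Hopper (sm_90a) the T.gemm_v2 call
# should dispatch to GemmWGMMA since both operands live in shared memory.
kernel = tilelang.compile(matmul(1024, 1024, 1024, 128, 128, 64))

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = torch.empty(1024, 1024, device="cuda", dtype=torch.float16)
kernel(a, b, c)

# Loose tolerances: the accumulator is fp32 but the output is cast to fp16.
torch.testing.assert_close(c, (a @ b).to(torch.float16), rtol=1e-2, atol=1e-2)
```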