Unverified commit a5074fd5 authored by coderabbitai[bot], committed by GitHub

📝 Add docstrings to `fix` (#726)

Docstring generation was requested by @LeiWang1999.

* https://github.com/tile-ai/tilelang/pull/712#issuecomment-3190680851



The following files were modified:

* `src/op/gemm.cc`
* `src/tl_templates/cuda/gemm_sm90.h`
* `src/transform/warp_specialized_rewriter.cc`
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent f4a828f6
@@ -78,6 +78,46 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
}
}
/**
* @brief Compute how warps are partitioned between the M and N GEMM dimensions.
*
* Determines the number of warps assigned to the M (rows) and N (columns)
* dimensions for a block given the selected GEMM implementation and target.
* The function enforces constraints required by the implementations (e.g.,
* per-warp tile sizes) and adapts the partition according to the configured
* GemmWarpPolicy (FullRow, FullCol, Square).
*
* @param block_size Total number of threads in the block (used to derive num_warps).
* @param gemm_inst The chosen GEMM implementation (e.g., kWGMMA, kMFMA, kMMA).
* @param target Target device information (used for warp size and target-specific rules).
* @return std::pair<int, int> {m_warp, n_warp} where m_warp * n_warp == num_warps.
*
* Constraints and behavior:
* - Each warp is assumed to cover 16 rows (M) and 8 columns (N). The function
* checks that M % 16 == 0 and N % 8 == 0.
* - num_warps is computed as block_size / warp_size(target).
* - For WGMMA (kWGMMA):
* - num_warps must be a multiple of 4 (warp-groups of 4).
* - m_warp is always a multiple of 4.
* - The warp partition respects the GemmWarpPolicy:
* - FullRow: maximize warps on M (in multiples of 4) while keeping divisibility.
* - FullCol: maximize warps on N, but if N is not evenly divisible, move
* whole warp-groups to M to achieve feasibility.
* - Square: choose a multiple-of-4 m_warp that best balances per-warp work
* between M and N.
* - For non-WGMMA implementations:
* - FullRow: favor allocating warps to M first; if M cannot use all warps,
* remaining warps are placed on N.
* - FullCol: favor allocating warps to N first; if N cannot use all warps,
* remaining warps are placed on M.
* - Square: search for the m/n split that best balances per-warp work given
* integer warp counts and the per-warp tile sizes.
*
* Error handling:
* - The function performs internal checks (ICHECK) and will fail if required
* divisibility or policy conditions are not met (e.g., M/N tile divisibility,
* invalid policy, or WGMMA-specific warp-group requirements).
*/
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
GemmInst gemm_inst,
Target target) const {
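As a rough illustration of the Square policy described in the docstring above, the sketch below picks the {m_warp, n_warp} split that best balances per-warp work, assuming the 16x8 per-warp tile stated there. The function name and structure are hypothetical and the WGMMA-specific multiple-of-4 requirement on m_warp is omitted; this is not the actual tilelang implementation.

```cpp
#include <cmath>
#include <utility>

// Illustrative Square-policy search: choose {m_warp, n_warp} with
// m_warp * n_warp == num_warps so that per-warp work along M and N is balanced.
std::pair<int, int> pick_square_partition(int M, int N, int num_warps) {
  int best_m = 1, best_n = num_warps;
  double best_imbalance = 1e30;
  for (int m_warp = 1; m_warp <= num_warps; ++m_warp) {
    if (num_warps % m_warp != 0)
      continue;  // keep m_warp * n_warp == num_warps
    int n_warp = num_warps / m_warp;
    // Per-warp work along each dimension, assuming a 16-row x 8-column warp tile.
    double work_m = static_cast<double>(M) / (m_warp * 16);
    double work_n = static_cast<double>(N) / (n_warp * 8);
    double imbalance = std::abs(work_m - work_n);
    if (imbalance < best_imbalance) {
      best_imbalance = imbalance;
      best_m = m_warp;
      best_n = n_warp;
    }
  }
  return {best_m, best_n};  // {m_warp, n_warp}
}
```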
@@ -240,6 +280,34 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
return {m_warp, n_warp};
}
/**
* @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
*
* Evaluates device-memory placement, data-type combinations, transpose flags,
* and K divisibility constraints required for the Hopper WGMMA code path.
*
* The check returns true only when:
* - B resides in shared memory ("shared" or "shared.dyn"); and
* - (C, A, B) dtypes match one of the supported combinations below and K
* satisfies the required alignment; and
* - for combinations that require specific orientations, A is not transposed
* and B is transposed.
*
* Supported combinations and constraints:
* - C=float16:
* - A=float16, B=float16: K % 16 == 0
* - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % 32 == 0
* - C=float32:
* - A=float16, B=float16: K % 16 == 0
* - A=bfloat16, B=bfloat16: K % 16 == 0
* - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
* - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
* - C=int32:
* - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) and K % 32 == 0
*
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
*/
bool Gemm::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
return false;
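To make the combination table above concrete, here is a self-contained restatement of the same constraints as a plain predicate. The enum and function name are illustrative simplifications (e4m3/e5m2 are collapsed into one Float8 case, Int8/UInt8 into Int8); the real check operates on TVM DataType objects inside Gemm::CheckWGMMA().

```cpp
enum class DType { Float16, BFloat16, Float32, Float8, Int8, Int32 };

bool wgmma_supported(DType c, DType a, DType b, int K,
                     bool trans_A, bool trans_B, bool b_in_shared) {
  if (!b_in_shared)
    return false;  // B must reside in shared memory ("shared" / "shared.dyn")
  const bool nt = !trans_A && trans_B;  // orientation required by several cases
  if (c == DType::Float16) {
    if (a == DType::Float16 && b == DType::Float16) return K % 16 == 0;
    if (a == DType::Float8 && b == DType::Float8)   return nt && K % 32 == 0;
  }
  if (c == DType::Float32) {
    if (a == DType::Float16 && b == DType::Float16)   return K % 16 == 0;
    if (a == DType::BFloat16 && b == DType::BFloat16) return K % 16 == 0;
    if (a == DType::Float32 && b == DType::Float32)   return nt && K % 8 == 0;
    if (a == DType::Float8 && b == DType::Float8)     return nt && K % 32 == 0;
  }
  if (c == DType::Int32) {
    if (a == DType::Int8 && b == DType::Int8) return nt && K % 32 == 0;
  }
  return false;
}
```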
@@ -342,6 +410,29 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Evaluate(new_call);
}
/**
* @brief Infer memory/layout mappings for A, B, and C buffers for this GEMM op.
*
* Generates and returns a LayoutMap that binds buffer A, B, and C to
* target- and architecture-specific fragment or shared-memory layouts based
* on the current target, thread bounds, warp partitioning, data types, and
* transpose flags. This performs target dispatch (Volta, Ampere/Turing/SM120,
* Hopper, CDNA), selects the appropriate fragment or shared layout creators,
* and binds fragment layouts to the thread range when buffers are local
* fragments.
*
* Preconditions:
* - C.scope() must be "local.fragment".
*
* Postconditions / side effects:
* - Marks the operator's layout inference as completed (sets completed_ = true).
* - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
* incompatible shape constraints.
*
* @param T Layout inference inputs (thread bounds and target).
* @param level Inference level (unused for side effects but retained for API).
* @return LayoutMap mapping each of A, B, and C to their inferred layouts.
*/
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (completed_)
return {};
......
@@ -533,7 +533,85 @@ public:
} // namespace tl_mma
} // namespace cute
}
/**
 * Execute a tiled GEMM where both the A and B tiles are sourced from shared memory.
 *
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body to perform the computation.
 *
 * @param pA Pointer to the A tile in shared memory.
 * @param pB Pointer to the B tile in shared memory.
 * @param accum Pointer to the accumulator fragment, updated in place.
 */
/**
 * Execute a tiled GEMM where A is held in per-thread register fragments and B is
 * sourced from shared memory.
 *
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the computation.
 *
 * @param pA Pointer to the A register fragment.
 * @param pB Pointer to the B tile in shared memory.
 * @param accum Pointer to the accumulator fragment, updated in place.
 */
/**
 * Execute a tiled GEMM where A is sourced from shared memory and B is held in
 * per-thread register fragments.
 *
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the computation.
 *
 * @param pA Pointer to the A tile in shared memory.
 * @param pB Pointer to the B register fragment.
 * @param accum Pointer to the accumulator fragment, updated in place.
 */
/**
 * Perform a tiled GEMM with both operands in shared memory and accumulate into accum.
 *
 * If use_wgmma is true, validates the wgmma constraints (strides and offsets) and dispatches
 * to the Hopper wgmma implementation; otherwise dispatches to the tl_mma implementation.
 *
 * @param pA Pointer to the A tile in shared memory.
 * @param pB Pointer to the B tile in shared memory.
 * @param accum Pointer to the accumulator fragment, updated in place.
 */
/**
 * Perform a tiled GEMM with A in per-thread register fragments and B in shared memory.
 *
 * If use_wgmma is true, validates the wgmma constraints (strides and offsets) and dispatches
 * to the Hopper wgmma register-shared (RS) implementation; otherwise dispatches to the
 * tl_mma register-shared implementation.
 *
 * @param pA Pointer to the A register fragment.
 * @param pB Pointer to the B tile in shared memory.
 * @param accum Pointer to the accumulator fragment, updated in place.
 */
/**
 * Perform a tiled GEMM with A in shared memory and B in per-thread register fragments
 * (tl_mma only).
 *
 * wgmma does not support this variant; the caller must set use_wgmma == false.
 * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr.
 *
 * @param pA Pointer to the A tile in shared memory.
 * @param pB Pointer to the B register fragment.
 * @param accum Pointer to the accumulator fragment, updated in place.
 */
/**
 * Wait on previously committed wgmma operations issued by this warp-group.
 *
 * Wrapper around cute::warpgroup_wait, which blocks until at most the specified
 * number of wgmma groups remain outstanding.
 */
/**
* Synchronize a named barrier across NumMmaThreads MMA threads.
*
* Calls cutlass::arch::NamedBarrier::sync with the canonical warp-group id.
*/
/**
* Arrive at a named barrier for NumMmaThreads MMA threads using architecture-aware mapping.
*
* Supported NumMmaThreads values: 256 or 384. The function issues one or two barrier arrives
* depending on the thread-group topology to ensure proper rendezvous ordering.
*/
/**
* Initialize named-barrier state for multi-warp MMA execution.
*
* For NumMmaThreads == 256 or 384, performs the required initial barrier arrivals for
* non-zero canonical warp-group indices to set up subsequent barrier synchronization.
*/
namespace tl {
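For orientation, a hypothetical call site for the register-shared entry point documented above might look like the following. Tile sizes, warp counts, element types, and buffer names are illustrative assumptions (not taken from this PR), and the snippet presumes gemm_sm90.h and the CUTLASS headers are already included.

```cpp
#include <cutlass/numeric_types.h>  // cutlass::half_t (assumed available alongside these headers)

// Hypothetical caller: one warp-group (4x1 warps) computing a 64x128x64 tile with
// A in register fragments, B in shared memory, and a float accumulator fragment.
__device__ void gemm_rs_example(cutlass::half_t *A_frag, cutlass::half_t *B_smem,
                                float *C_frag) {
  // With trans_A = false the wgmma path requires lda == K, and with trans_B = true
  // it requires ldb == K (see the stride checks in gemm_rs below).
  tl::gemm_rs<64, 128, 64, /*num_warp_m=*/4, /*num_warp_n=*/1,
              /*trans_A=*/false, /*trans_B=*/true,
              /*clear_accum=*/false, /*lda=*/64, /*ldb=*/64,
              /*offset_a=*/0, /*offset_b=*/0,
              /*use_wgmma=*/true, /*wg_wait=*/0>(A_frag, B_smem, C_frag);
}
```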
@@ -603,7 +681,23 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
TL_DEVICE /**
 * Perform a register-shared (RS) tiled GEMM, with A held in per-thread register fragments
 * and B sourced from shared memory, and accumulate into `accum`.
 *
 * Dispatches at compile time to either the Hopper wgmma implementation or the fallback MMA
 * implementation depending on `use_wgmma`. The selected GemmTensorOp::body_rs performs the
 * tiled GEMM loop and updates the accumulator in place.
 *
 * When `use_wgmma == true`, this function enforces wgmma constraints at compile time:
 * - A's leading dimension must equal (trans_A ? M : K)
 * - B's leading dimension must equal (trans_B ? K : N)
 * - offset_a and offset_b must be zero
 *
 * @param pA Pointer to the A register fragment. Layout/stride expectations depend on template parameters.
 * @param pB Pointer to the B tile in shared memory. Layout/stride expectations depend on template parameters.
 * @param accum Pointer to the accumulator/output C buffer, updated in place.
 */
void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
if constexpr (use_wgmma) {
static_assert((trans_A && lda == M) || (!trans_A && lda == K),
"Hopper wgmma doesn't support custom stride for A");
@@ -628,7 +722,18 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
TL_DEVICE /**
 * Perform a non-wgmma tiled GEMM where A is sourced from shared memory and B is held in
 * per-thread register fragments, accumulating into `accum`.
 *
 * This overload dispatches to the tl_mma::GemmTensorOp::body_sr implementation.
 * Must be instantiated with `use_wgmma = false` (enforced via static_assert).
 *
 * @param pA Pointer to the A tile in shared memory.
 * @param pB Pointer to the B register fragment.
 * @param accum Pointer to the accumulator fragment, updated in place.
 */
void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
@@ -637,7 +742,13 @@ TL_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
MMA::body_sr(pA, pB, accum);
}
template <int num_mma> TL_DEVICE void wait_wgmma() {
template <int num_mma> TL_DEVICE /**
 * Wait on previously committed wgmma operation groups.
 *
 * Thin wrapper around cute::warpgroup_wait<num_mma>(), which blocks until at most
 * `num_mma` wgmma groups issued by this warp-group remain outstanding; passing 0
 * waits for all of them to complete.
 */
void wait_wgmma() {
cute::warpgroup_wait<num_mma>();
}
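As a sketch of how this helper is typically used (an assumption about usage, not code from this file): CuTe's commit/wait pair drains asynchronous wgmma groups. The header path below is assumed.

```cpp
#include <cute/arch/mma_sm90_gmma.hpp>  // assumed location of warpgroup_commit_batch / warpgroup_wait

// Close the current batch of issued wgmma operations, then block until no
// committed batch remains outstanding (a count of 0 means "wait for all").
__device__ void drain_wgmma_sketch() {
  cute::warpgroup_commit_batch();
  cute::warpgroup_wait<0>();  // same effect as tl::wait_wgmma<0>()
}
```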
......
@@ -569,6 +569,29 @@ public:
class WSCodeEmitter : public StmtMutator {
public:
/**
* @brief Construct a warp-specialized code emitter configured for producer or consumer emission.
*
* Initializes a WSCodeEmitter that will emit barrier-aware, role-filtered code for a single
* warp-specialized block. The emitter is configured with the loop/thread iteration variable,
* buffer mapping, role marker used to classify statements, and two flags that control emission
* behavior:
*
* - `mbarrier_only`: when true, emission is restricted to barrier-related operations only.
* - `only_has_wgmma`: when true, the emitter will account for the presence of WgMMA
 *   (warp-group MMA) operations when computing barrier/thread gating behavior.
*
* @param is_emitting_producer True to emit producer-side groups; false to emit consumer-side groups.
* @param thread_iv IterVar representing the thread iteration variable (threadIdx.*) whose Var is used
* for thread-index rewrites and gating.
* @param buffer_data_to_buffer Map from buffer data Var to the corresponding Buffer (used to resolve
* buffer references during emission).
* @param marker Role marker that classifies statements as producer/consumer/both; used to filter
* which statements are emitted on this path.
* @param mbarrier_only If true, restrict emission to mbarrier-related statements and helpers.
* @param only_has_wgmma If true, adjust emission and barrier-thread-count logic for blocks that
* contain WgMMA operations.
*/
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker,
@@ -578,7 +601,15 @@ public:
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only),
only_has_wgmma_(only_has_wgmma) {}
bool hasSimtCopy() const { return has_simt_copy_; }
/**
 * @brief Whether a SIMT-style (per-thread) copy was detected.
 *
 * Returns true when a SIMT (thread-parallel, non-bulk) copy pattern was observed
 * during analysis/emission, which affects barrier insertion and copy emission.
 *
 * @return true if a SIMT copy was detected; false otherwise.
 */
bool hasSimtCopy() const { return has_simt_copy_; }
private:
template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
@@ -596,7 +627,47 @@ private:
}
}
// TODO: only need to add block for ops in the loop
/**
* @brief Visit and transform a SeqStmt node, emitting grouped blocks with barrier
* synchronization according to producer/consumer roles.
*
* This method examines the sequence to determine whether producer-side
* synchronization is required (based on marker_ roles). If no producer sync is
* needed it delegates to FilterByRole. Otherwise it:
* - Recursively visits and transforms each child statement.
* - Extracts an acquire/release sync pattern for the sequence via
* ExtractSyncPattern.
* - For producer emission (is_emitting_producer_ == true):
* - Skips consumer-only statements unless marker_ marks a statement as Both,
* in which case the statement is emitted as its own group.
* - For each statement, inserts parity waits for acquire patterns, rewrites
* release statements with MbarrierRewriter using a computed barrier id,
 *     collects SIMT-copy presence (setting has_simt_copy_ and inserting
* cp.async barriers when found), optionally emits arrive barriers for
* release-after events, and emits each resulting set of statements as a
* group block annotated with "stmt_group".
* - For consumer emission (is_emitting_producer_ == false):
* - Skips producer-only statements.
* - Inserts parity waits for acquire patterns, appends the transformed
* statement, and emits arrive barriers for release-after events. When
* only_has_wgmma_ is set, the arrive barrier uses a per-thread predicate
* (FloorMod(thread_var_,128)==0) with CTA=0; otherwise a full arrive is
* emitted.
* - Recomputes pipeline_info_ to drop producer-only ops.
*
* Side effects / state updates:
* - Increments num_barriers_ by (number of extracted patterns * num_stages_).
 * - May set has_simt_copy_ when a SIMT copy is detected in producer rewrites.
* - Inserts barrier ids into released_barrier_ for release-after events.
* - Updates pipeline_info_ for the consumer path to remove producer ops.
*
* The resulting statements are emitted as grouped blocks (via MakeGroupBlock)
* with the annotation "stmt_group" and returned as either a single Stmt (if
* there's only one group) or a SeqStmt containing the grouped blocks.
*
* @return Stmt The transformed statement (either a single group block or a
* SeqStmt of group blocks).
*/
Stmt VisitStmt_(const SeqStmtNode *op) final {
bool has_producer = false;
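The "parity waits" and barrier-id arithmetic mentioned in the docstring follow the usual multi-stage mbarrier convention. The sketch below illustrates that bookkeeping under the assumption (not verified against this file) that barriers are allocated as pattern_index * num_stages + stage and that a barrier's wait phase flips once per full trip through its stage; the struct and member names are hypothetical.

```cpp
// Illustrative bookkeeping only; not part of the rewriter itself.
struct BarrierSlot {
  int stage;       // pipeline stage guarded by this barrier
  int num_stages;  // total software-pipeline stages

  // Flat barrier id when each extracted sync pattern owns num_stages barriers.
  int BarrierId(int pattern_index) const { return pattern_index * num_stages + stage; }

  // Phase parity the waiting side must pass on loop iteration `iter`: the barrier
  // is reused once every num_stages iterations, flipping its phase each reuse.
  int WaitParity(int iter) const { return (iter / num_stages) & 1; }
};
```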
@@ -1176,6 +1247,38 @@ private:
return for_node;
}
/**
* @brief Rewrite a BlockRealize for warp specialization, inserting barriers and
* emitting producer/consumer bodies.
*
* This visitor handles BlockRealize nodes when a thread IterVar (thread_iv_)
* is defined and warp-specialization is applicable. It:
* - Determines producer/consumer roles via WarpSpecializedRoleMarker and
* returns the original block if no producer is detected.
* - If warp specialization is disabled, emits only mbarrier initialization and
* the mbarrier-only transformed body.
* - Otherwise, detects WgMMA usage for the block body and constructs separate
* WSCodeEmitter instances for producer and consumer paths (propagating the
* WgMMA flag to the consumer emitter).
* - Generates producer/consumer code, applies register hint calls (set_max_nreg)
* when available, and rewrites thread indices with ThreadIdxRewriter to
* partition threads between producer and consumer roles.
* - Computes and initializes a list of mbarrier handles with per-barrier
* arrive thread counts (taking SIMT-copy and WgMMA cases into account).
* - Wraps the transformed body in an IfThenElse that dispatches producer vs
* consumer based on thread index, and annotates the region with the
* "kWarpSpecializationScope" attribute that contains producer/consumer
* thread extents.
*
* Side effects:
* - May update member state: only_has_wgmma_, updated_thread_extent_,
* need_update_thread_extent_.
* - May abort via ICHECK if invariants (e.g., matching barrier counts) are
* violated.
*
* @return The possibly rewritten BlockRealize statement (original when no
* warp-specialization is applied or thread_iv_ is undefined).
*/
Stmt VisitStmt_(const BlockRealizeNode *op) final {
BlockRealize block_realize =
Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
......
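The overall shape of the code this rewrite produces, producer and consumer bodies dispatched on the thread index with mbarriers initialized up front, resembles the CUDA sketch below. It is illustrative only: the thread extents, barrier setup, and register hints are placeholders, and the pass itself emits TIR rather than CUDA source.

```cpp
__global__ void warp_specialized_shape_sketch() {
  constexpr int kConsumerThreads = 256;  // assumed consumer extent (e.g. two MMA warp-groups)

  // mbarrier handles would be allocated and initialized here, each with the
  // per-barrier arrive thread count described in the docstring above.

  if (threadIdx.x >= kConsumerThreads) {
    // Producer path: issue async/TMA copies, wait on "buffer free" parity,
    // arrive on "data ready" barriers; may lower its register budget (set_max_nreg).
  } else {
    // Consumer path: run MMA/WGMMA compute, wait on "data ready" parity,
    // arrive on "buffer free" barriers; may raise its register budget.
  }
}
```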