"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "ce3f2db7c1c09799e374aaeaf45b3a2500e23fd8"
Unverified commit e68fdab8, authored by Lei Wang, committed by GitHub

[Enhancement] Add DispatchInstruction specialization for fp8 types in gemm_sm90.h (#751)

- Introduced specialized DispatchInstruction templates for the fp8_e4_t and fp8_e5_t types, extending CUDA GEMM support to these fp8 formats.
- Each specialization selects the matching SM89 fp8 MMA atom and defines an MMA_Group tile sized to the warp configuration (Int<std::min(num_warp_n * 16, N)> on the N mode).
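For reference, a minimal sketch of how such a specialization is typically consumed when building a CUTE tiled MMA. The warp counts, the tile size, and the `make_tiled_mma` construction below are illustrative assumptions, not code from this commit; `fp8_e4_t` and `DispatchInstruction` come from gemm_sm90.h itself:

```cpp
#include <cute/atom/mma_atom.hpp> // assumption: provides MMA_Atom / make_tiled_mma
#include <cute/tensor.hpp>

using namespace cute;

// Resolve the dispatch for an e4m3 GEMM with hypothetical 2x2 warps and N = 128.
using Instruction = DispatchInstruction<fp8_e4_t, fp8_e4_t, float,
                                        /*num_warp_m=*/2, /*num_warp_n=*/2,
                                        /*N=*/128>;

// The selected SM89 atom and group tile then parameterize a tiled MMA.
using TiledMma = decltype(make_tiled_mma(
    typename Instruction::MMA{},
    Layout<Shape<_2, _2, _1>>{},       // warp layout: num_warp_m x num_warp_n
    typename Instruction::MMA_Group{}  // N-mode permutation tile
    ));
```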
parent 796b3bbe
@@ -153,6 +153,19 @@ struct DispatchInstruction;
using _X = Underscore;

template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
@@ -533,55 +546,56 @@ public:
} // namespace tl_mma
} // namespace cute

/**
* Execute a tiled GEMM where both A and B tiles are sourced from shared memory.
*
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body to perform the computation.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
/**
* Execute a tiled GEMM where A is read from global memory and B is staged in
* shared memory.
*
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the
* computation.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
/**
* Execute a tiled GEMM where A is staged in shared memory and B is read from
* global memory.
*
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the
* computation.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
/**
* Perform a tiled GEMM (both operands in shared memory or selected backend) and
* write to accum.
*
* If use_wgmma is true, validates wgmma constraints (strides and offsets) and
* dispatches to the Hopper wgmma implementation; otherwise dispatches to the
* tl_mma implementation.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
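The constraint validation described above can be sketched as a compile-time guard. The following free function is hypothetical; the `if constexpr` / `static_assert` pattern and assertion text mirror what appears verbatim in this diff further down:

```cpp
// Hypothetical sketch of the use_wgmma dispatch guard (not the file's exact
// function); parameter names follow the template headers shown in this diff.
template <int M, int K, bool trans_A, int lda, int offset_a, int offset_b,
          bool use_wgmma>
__device__ void gemm_dispatch_sketch() {
  if constexpr (use_wgmma) {
    // wgmma requires canonical strides and zero offsets.
    static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                  "Hopper wgmma doesn't support custom stride for A");
    static_assert(offset_a == 0 && offset_b == 0,
                  "Hopper wgmma doesn't support offsets");
    // ... dispatch to the Hopper wgmma GemmTensorOp ...
  } else {
    // ... dispatch to the tl_mma GemmTensorOp ...
  }
}
```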
/**
* Perform a tiled GEMM with A in global memory and B in shared memory (or
* selected backend).
*
* If use_wgmma is true, validates wgmma constraints (strides and offsets) and
* dispatches to the Hopper wgmma read-share implementation; otherwise
* dispatches to the tl_mma read-share.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
/**
* Perform a tiled GEMM with A staged in shared memory and B in global memory
* (tl_mma only).
*
* wgmma does not support this variant; caller must set use_wgmma == false.
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr.
@@ -601,16 +615,19 @@ public:
* Calls cutlass::arch::NamedBarrier::sync with the canonical warp-group id.
*/
/**
* Arrive at a named barrier for NumMmaThreads MMA threads using
* architecture-aware mapping.
*
* Supported NumMmaThreads values: 256 or 384. The function issues one or two
* barrier arrives depending on the thread-group topology to ensure proper
* rendezvous ordering.
*/
/**
* Initialize named-barrier state for multi-warp MMA execution.
*
* For NumMmaThreads == 256 or 384, performs the required initial barrier
* arrivals for non-zero canonical warp-group indices to set up subsequent
* barrier synchronization.
*/
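As a concrete illustration of the arrive/sync protocol these comments describe, here is a minimal hedged sketch for the 256-thread (two warp-group) case using the CUTLASS named-barrier API. The helper function is hypothetical; the use of the canonical warp-group index as barrier id follows the comments above:

```cpp
#include <cutlass/arch/barrier.h> // NamedBarrier
#include <cutlass/cutlass.h>      // assumption: declares canonical_warp_group_idx()

// Hypothetical rendezvous for NumMmaThreads == 256: each warp group arrives
// on its peer's barrier, then syncs on its own, so each barrier reaches its
// 256 expected arrivals only once both groups have checked in.
__device__ void mma_rendezvous_sketch() {
  int wg = cutlass::canonical_warp_group_idx(); // 0 or 1 in this sketch
  // Signal arrival on the peer warp group's named barrier.
  cutlass::arch::NamedBarrier::arrive(/*num_threads=*/256,
                                      /*barrier_id=*/1 - wg);
  // Block on our own barrier until the peer group has arrived too.
  cutlass::arch::NamedBarrier::sync(/*num_threads=*/256,
                                    /*barrier_id=*/wg);
}
```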
namespace tl {
@@ -682,22 +699,29 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE /**
* Perform a read-share (B in shared memory, A in global) tiled GEMM
* and accumulate into `accum`.
*
* Dispatches at compile time to either the Hopper wgmma
* implementation or the fallback MMA implementation depending on
* `use_wgmma`. The selected GemmTensorOp::body_rs performs the
* region-tiled GEMM loop and updates the accumulator in-place.
*
* When `use_wgmma == true`, this function enforces wgmma constraints
* at compile time:
* - A's leading dimension must equal (trans_A ? M : K)
* - B's leading dimension must equal (trans_B ? K : N)
* - offset_a and offset_b must be zero
*
* @param pA Pointer to operand A (global memory). Layout/stride
* expectations depend on template parameters.
* @param pB Pointer to operand B (base for shared-memory staging).
* Layout/stride expectations depend on template parameters.
* @param accum Pointer to the accumulator/output C buffer updated
* in-place.
*/
void
gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
if constexpr (use_wgmma) {
static_assert((trans_A && lda == M) || (!trans_A && lda == K),
"Hopper wgmma doesn't support custom stride for A");
@@ -723,17 +747,23 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE /**
* Perform a non-wgmma tiled GEMM where A regions are staged into
* shared memory and B is read directly from global memory,
* accumulating into `accum`.
*
* This overload dispatches to the tl_mma::GemmTensorOp::body_sr
* implementation. Must be instantiated with `use_wgmma = false`
* (enforced via static_assert).
*
* @param pA Pointer to the A operand in global memory (source that
* will be staged to shared memory).
* @param pB Pointer to the B operand in global memory (read
* directly).
* @param accum Pointer to the output accumulator matrix in global
* memory.
*/
void
gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
@@ -742,13 +772,17 @@ void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
MMA::body_sr(pA, pB, accum);
}
template <int num_mma>
TL_DEVICE /**
* Wait for all WMMA/MMA warps in the current warp-group to
* synchronize.
*
* Blocks until the warp-group-wide rendezvous for `num_mma` MMA lanes
* completes, ensuring all participating warps have arrived before
* proceeding.
*/
void
wait_wgmma() {
cute::warpgroup_wait<num_mma>();
}
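For completeness, a small sketch of what the underlying wait does. The semantics follow CUTE's `warpgroup_wait`, which this wrapper calls; the drain helper itself is hypothetical:

```cpp
#include <cute/arch/mma_sm90_gmma.hpp> // assumption: declares cute::warpgroup_wait

// Hypothetical illustration: warpgroup_wait<N> blocks until at most N wgmma
// commit groups remain in flight, so <0> drains the entire pipeline. This is
// what calling tl::wait_wgmma<0>() amounts to.
__device__ void drain_all_wgmma() {
  cute::warpgroup_wait<0>();
}
```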