"tests/pipelines/animatediff/__init__.py" did not exist on "d03c9099bc690870f151035393484c3f5dea2d80"
Commit 14dc326e authored by Tri Dao

Use Cutlass gemm as WarpMma

parent e78e7c95
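
This commit swaps fmha's hand-written warp-level GEMM helper for a wrapper around Cutlass's warp-level tensor-op MMA (cutlass::gemm::warp::DefaultMmaTensorOp), added below as fmha::gemm_cl. Since gemm_cl keeps the signature of fmha::gemm, every call site is a mechanical rename. A minimal before/after sketch, with names taken from the hunks below (the surrounding kernel scaffolding is assumed):

    // acc_p accumulates one warp tile of Q * K^T; frag_q / frag_k are fmha
    // register fragments already loaded from shared memory.
    fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);     // before
    fmha::gemm_cl(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);  // after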
@@ -29,6 +29,13 @@
 #include <fmha/utils.h>
 
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/warp/default_mma_tensor_op.h"
+#include "cutlass/layout/layout.h"
+#include <cutlass/arch/mma.h>
+#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
+
 namespace fmha {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -247,6 +254,49 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename Acc, typename A, typename B, int M, int N>
+inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
+    // One warp-level GEMM covering all M x N fmha MMA tiles; each tile is a
+    // 16x16x16 HMMA with fp32 accumulation, built from 16x8x16 mma.sync.
+    using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
+    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+    using Element = cutlass::half_t;
+    using ElementC = float;
+    using LayoutA = cutlass::layout::RowMajor;
+    using LayoutB = cutlass::layout::ColumnMajor;
+    using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
+        Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
+        cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;
+    using FragmentA = typename WarpMma::FragmentA;
+    using FragmentB = typename WarpMma::FragmentB;
+    using FragmentC = typename WarpMma::FragmentC;
+    // The Cutlass fragments must be the same size as the fmha fragments.
+    static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
+    static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
+    static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
+    const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
+    const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
+    FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);
+    WarpMma mma_op;
+    mma_op(c_cl, a_cl, b_cl, c_cl);
+    // c_cl is a by-value copy of acc (copy-initialization from the cast
+    // reference copies, it does not alias), so the modified accumulator has
+    // to be copied back into acc explicitly.
+    #pragma unroll
+    for (int mi = 0; mi < M; mi++) {
+        #pragma unroll
+        for (int ni = 0; ni < N; ni++) {
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i];
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template<
     // The number of rows in the CTA tile.
     int M_,
...
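
A side note on the copy-back loop in gemm_cl above: `FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);` copy-initializes c_cl, so the WarpMma writes into a private copy. A minimal sketch of an alternative, not part of this commit and assuming the fragment layouts really do alias element-for-element (the static_asserts only check sizes), would bind a reference instead:

    // Sketch (assumption, not in the commit): accumulate in place by binding
    // FragmentC as a reference, making the explicit copy-back unnecessary.
    FragmentC &c_ref = reinterpret_cast<FragmentC &>(acc);
    WarpMma mma_op;
    mma_op(c_ref, a_cl, b_cl, c_ref);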
@@ -408,9 +408,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             smem_do.load(frag_do[ki & 1], ki);
             if (!Kernel_traits::V_IN_REGS) {
                 smem_v.load(frag_v[ki & 1], ki);
-                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
             } else {
-                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
+                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
             }
             // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
             // float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
@@ -424,9 +424,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
         {
             int ki = Mma_tile_p::MMAS_K;
             if (!Kernel_traits::V_IN_REGS) {
-                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
             } else {
-                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
+                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
             }
         }
@@ -515,14 +515,14 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             // Trigger the load from shared memory for the next series of Q values.
             smem_kt.load(frag_kt[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_dq::MMAS_K;
-            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
         static_assert(Gmem_tile_dq::LOOPS == 1);
@@ -555,13 +555,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             // Trigger the load from shared memory for the next series of Q values.
             smem_dot.load(frag_dot[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_dkv::MMAS_K;
-            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
         }
         // __syncthreads();
@@ -613,13 +613,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             // Trigger the load from shared memory for the next series of Q values.
             smem_qt.load(frag_qt[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_dkv::MMAS_K;
-            fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
         }
         // Make sure dQ is in shared memory.
...
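
For orientation (my reading of the accumulator names, not text from the commit), the converted GEMMs in this backward kernel map onto the FlashAttention backward-pass matmuls roughly as:

    acc_dp: dP = dO * V^T    (frag_do x frag_v)
    acc_dq: dQ = dS * K      (frag_p x frag_kt, frag_p holding dS at this point)
    acc_dv: dV = S^T * dO    (frag_s x frag_dot)
    acc_dk: dK = dS^T * Q    (frag_dpt x frag_qt)

The same four call sites recur in the non-block variant, compute_dq_dk_dv_1xN_one_iter, further below.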
@@ -365,7 +365,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
         // Do this part of O = P^T * V^T.
         #pragma unroll
         for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
-            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
+            fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
         }
         // The mapping from tidx to rows changes between the softmax and the O-reduction.
...
@@ -383,9 +383,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
             smem_do.load(frag_do[ki & 1], ki);
             if (!Kernel_traits::V_IN_REGS) {
                 smem_v.load(frag_v[ki & 1], ki);
-                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
             } else {
-                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
+                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
             }
             // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
             // float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
@@ -399,9 +399,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         {
             int ki = Mma_tile_p::MMAS_K;
             if (!Kernel_traits::V_IN_REGS) {
-                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
             } else {
-                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
+                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
             }
         }
@@ -484,14 +484,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
             // Trigger the load from shared memory for the next series of Q values.
             smem_kt.load(frag_kt[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_dq::MMAS_K;
-            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
         static_assert(Gmem_tile_dq::LOOPS == 1);
@@ -524,13 +524,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
             // Trigger the load from shared memory for the next series of Q values.
             smem_dot.load(frag_dot[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_dkv::MMAS_K;
-            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
         }
         // __syncthreads();
@@ -579,13 +579,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
             // Trigger the load from shared memory for the next series of Q values.
             smem_qt.load(frag_qt[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_dkv::MMAS_K;
-            fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
         }
         // Make sure dQ is in shared memory.
...
@@ -115,12 +115,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
             // Trigger the load from shared memory for the next series of Q values.
             Base::smem_q.load(Base::frag_q[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+            fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_p::MMAS_K;
-            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+            fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
         }
     }
@@ -175,12 +175,12 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
             Base::smem_q.load(Base::frag_q[ki & 1], ki);
             Base::smem_k.load(frag_k[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_p::MMAS_K;
-            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
+            fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
         }
     }
@@ -497,7 +497,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Do this part of O = P^T * V^T.
         #pragma unroll
         for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
-            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
+            fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
         }
         // The mapping from tidx to rows changes between the softmax and the O-reduction.
...