Unverified Commit 95889861 authored by carlushuang, committed by GitHub

support batch & nhead, and scale (#20)

* support batch & nhead

* support scale

* tile scheduler

* rename tile-scheduler to tile-partitioner

* add some exp2 math

* fix a bug when changing tile size
parent b7abe77a
@@ -21,6 +21,7 @@
 #include "reference_batched_gemm.hpp"
 #include "reference_batched_softmax.hpp"
 #include "fmha_fwd_kernel.hpp"
+#include "fmha_fwd_tile_partitioner.hpp"
 #include "fmha_fwd_epilogue.hpp"

 using QDataType = ck::half_t;
@@ -32,9 +33,12 @@ using PDataType = ck::half_t; // data type for A matrix of second gemm
 using OaccDataType = float; // data type for second gemm accumulation
 using ODataType    = ck::half_t;

-using FmhaShape =
-    ck::tile_program::TileFmhaShape<128 /*M0*/, 128 /*N0*/, 32 /*K0*/, 128 /*N1*/, 32 /*K1*/>;
+// M0 N0 K0 N1 K1
+// using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>;
+// using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>;
+using FmhaShape = ck::tile_program::TileFmhaShape<128, 128, 32, 128, 32>;
+using FmhaTilePartitioner = FmhaFwdTilePartitioner<FmhaShape>;

 using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QDataType,
                                                                               KDataType,
                                                                               VDataType,
@@ -48,51 +52,61 @@ using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QD
 using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS<FmhaPipelineProblem>;
 using FmhaEpilogue = FmhaFwdEpilogue<FmhaFwdEpilogueProblem<OaccDataType, ODataType>>;
-using FmhaKernel   = FmhaFwdKernel<FmhaPipeline, FmhaEpilogue>;
+using FmhaKernel   = FmhaFwdKernel<FmhaTilePartitioner, FmhaPipeline, FmhaEpilogue>;

 int main(int argc, char* argv[])
 {
-    ck::index_t Batch = 16;   // batch * nheads
-    ck::index_t M0    = 3328; // seqlen_q
-    ck::index_t N0    = 4096; // seqlen_k
-    ck::index_t K0    = 128;  // hdim_q
-    ck::index_t N1    = 128;  // hdim_v
-
-    if(argc == 6)
-    {
-        Batch = std::stoi(argv[1]);
-        M0    = std::stoi(argv[2]);
-        N0    = std::stoi(argv[3]);
-        K0    = std::stoi(argv[4]);
-        N1    = std::stoi(argv[5]);
-    }
-
-    std::array<ck::index_t, 3> q_lengths{Batch, M0, K0};
-    std::array<ck::index_t, 3> q_strides{M0 * K0, K0, 1};
-    std::array<ck::index_t, 3> k_lengths{Batch, N0, K0};
-    std::array<ck::index_t, 3> k_strides{N0 * K0, K0, 1};
-    std::array<ck::index_t, 3> v_lengths{Batch, N1, N0};
-    std::array<ck::index_t, 3> v_strides{N1 * N0, N0, 1};
-    std::array<ck::index_t, 3> s_lengths{Batch, M0, N0};
-    std::array<ck::index_t, 3> s_strides{M0 * N0, N0, 1};
-    std::array<ck::index_t, 3> p_lengths{Batch, M0, N0};
-    std::array<ck::index_t, 3> p_strides{M0 * N0, N0, 1};
-    std::array<ck::index_t, 3> o_lengths{Batch, M0, N1};
-    std::array<ck::index_t, 3> o_strides{M0 * N1, N1, 1};
+    int do_validation    = 1;
+    ck::index_t batch    = 2;
+    ck::index_t nhead    = 8;
+    ck::index_t seqlen_q = 3328;
+    ck::index_t seqlen_k = 4096;
+    ck::index_t hdim_q   = 128;
+    ck::index_t hdim_v   = 128;
+
+    float scale = .0f;
+
+    bool i_perm = true; // if true, will be batch * nhead * seqlen * hdim
+    bool o_perm = true; // if false, will be batch * seqlen * nhead * hdim
+
+    if(argc >= 2)
+        do_validation = std::stoi(argv[1]);
+    if(argc >= 8)
+    {
+        batch    = std::stoi(argv[2]);
+        nhead    = std::stoi(argv[3]);
+        seqlen_q = std::stoi(argv[4]);
+        seqlen_k = std::stoi(argv[5]);
+        hdim_q   = std::stoi(argv[6]);
+        hdim_v   = std::stoi(argv[7]);
+    }
+    if(argc >= 9)
+        scale = std::stof(argv[8]);
+    if(argc >= 10)
+        i_perm = static_cast<bool>(std::stoi(argv[9]));
+    if(argc >= 11)
+        o_perm = static_cast<bool>(std::stoi(argv[10]));
+
+    if(scale == .0f)
+        scale = 1.0 / ck::math::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
+
+    auto get_lengths = [&](bool permute,
+                           ck::index_t b /*batch*/,
+                           ck::index_t h /*nhead*/,
+                           ck::index_t s /*seqlen*/,
+                           ck::index_t d /*hdim*/) {
+        if(permute)
+            return std::array<ck::index_t, 4>{b, h, s, d};
+        else
+            return std::array<ck::index_t, 4>{b, s, h, d};
+    };

     // host verify
-    Tensor<QDataType> q_host(q_lengths, q_strides);
-    Tensor<KDataType> k_host(k_lengths, k_strides);
-    Tensor<VDataType> v_host(v_lengths, v_strides);
-    Tensor<SMPLComputeDataType> s_host_ref(s_lengths, s_strides);
-    Tensor<PDataType> p_host_ref(p_lengths, p_strides);
-    Tensor<ODataType> o_host_ref(o_lengths, o_strides);
-    Tensor<ODataType> o_host_dev(o_lengths, o_strides);
+    Tensor<QDataType> q_host(get_lengths(i_perm, batch, nhead, seqlen_q, hdim_q));
+    Tensor<KDataType> k_host(get_lengths(i_perm, batch, nhead, seqlen_k, hdim_q));
+    Tensor<VDataType> v_host(get_lengths(i_perm, batch, nhead, hdim_v, seqlen_k));
+    Tensor<ODataType> o_host(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v));

 #if 0
     ck::utils::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f}(q_host);
@@ -104,50 +118,50 @@ int main(int argc, char* argv[])
     ck::utils::FillUniformDistribution<VDataType>{-3.f, 3.f}(v_host);
 #endif

-    // reference
-    reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
-        q_host, k_host, s_host_ref);
-    reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(s_host_ref,
-                                                                                   p_host_ref);
-    reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
-        p_host_ref, v_host, o_host_ref);
-
     DeviceMem q_buf(sizeof(QDataType) * q_host.GetElementSpaceSize());
     DeviceMem k_buf(sizeof(KDataType) * k_host.GetElementSpaceSize());
     DeviceMem v_buf(sizeof(VDataType) * v_host.GetElementSpaceSize());
-    DeviceMem o_buf(sizeof(ODataType) * o_host_ref.GetElementSpaceSize());
+    DeviceMem o_buf(sizeof(ODataType) * o_host.GetElementSpaceSize());

     q_buf.ToDevice(q_host.mData.data());
     k_buf.ToDevice(k_host.mData.data());
     v_buf.ToDevice(v_host.mData.data());

-    dim3 kGridSize            = FmhaKernel::GridSize(Batch, M0, N1);
+    dim3 kGridSize            = FmhaKernel::GridSize(batch, nhead, seqlen_q, hdim_v);
     constexpr dim3 kBlockSize = FmhaKernel::BlockSize();

-    std::cout << "batch:" << Batch << ", seqlen_q:" << M0 << ", seqlen_k:" << N0
-              << ", hdim_q:" << K0 << ", hdim_v:" << N1 << ", grid_size " << kGridSize.x
+    std::cout << "batch:" << batch << ", nhead:" << nhead << ", seqlen_q:" << seqlen_q
+              << ", seqlen_k:" << seqlen_k << ", hdim_q:" << hdim_q << ", hdim_v:" << hdim_v
+              << ", scale:" << scale << ", i_perm:" << i_perm << ", o_perm:" << o_perm
+              << ", grid_size " << kGridSize.x << "x" << kGridSize.y << "x" << kGridSize.z
               << std::endl;

     constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
     constexpr ck::index_t kWarpPerBlock = kBlockSize.x / warpSize;
     constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;

+    // batch * nhead * seqlen * hdim or batch * seqlen * nhead * hdim
     auto kargs = FmhaKernel::MakeKargs(q_buf.GetDeviceBuffer(),
                                        k_buf.GetDeviceBuffer(),
                                        v_buf.GetDeviceBuffer(),
                                        o_buf.GetDeviceBuffer(),
-                                       M0,       // seqlen_q
-                                       N0,       // seqlen_k
-                                       K0,       // hdim_q
-                                       N1,       // hdim_v
-                                       K0,       // stride_q
-                                       K0,       // stride_k
-                                       N0,       // stride_v
-                                       N1,       // stride_o
-                                       M0 * K0,  // batch_stride_q
-                                       N0 * K0,  // batch_stride_k
-                                       N1 * N0,  // batch_stride_v
-                                       M0 * N1); // batch_stride_o
+                                       seqlen_q, // seqlen_q
+                                       seqlen_k, // seqlen_k
+                                       hdim_q,   // hdim_q
+                                       hdim_v,   // hdim_v
+                                       scale,
+                                       i_perm ? hdim_q : nhead * hdim_q,      // stride_q
+                                       i_perm ? hdim_q : nhead * hdim_q,      // stride_k
+                                       i_perm ? seqlen_k : nhead * seqlen_k,  // stride_v
+                                       o_perm ? hdim_v : nhead * hdim_v,      // stride_o
+                                       i_perm ? seqlen_q * hdim_q : hdim_q,   // nhead_stride_q
+                                       i_perm ? seqlen_k * hdim_q : hdim_q,   // nhead_stride_k
+                                       i_perm ? hdim_v * seqlen_k : seqlen_k, // nhead_stride_v
+                                       o_perm ? seqlen_q * hdim_v : hdim_v,   // nhead_stride_o
+                                       nhead * seqlen_q * hdim_q,  // batch_stride_q
+                                       nhead * seqlen_k * hdim_q,  // batch_stride_k
+                                       nhead * hdim_v * seqlen_k,  // batch_stride_v
+                                       nhead * seqlen_q * hdim_v); // batch_stride_o

     float ave_time = launch_kernel<kBlockSize.x, kBlockPerCu>(StreamConfig{nullptr, true},
                                                               FmhaKernel{},
@@ -156,14 +170,13 @@ int main(int argc, char* argv[])
                                                               0,
                                                               kargs); // BatchStrideO

-    o_buf.FromDevice(o_host_dev.mData.data());
-
-    std::size_t flop =
-        std::size_t(2) * Batch * M0 * N0 * K0 + std::size_t(2) * Batch * M0 * N1 * N0;
+    std::size_t flop = std::size_t(2) * batch * nhead * seqlen_q * seqlen_k * hdim_q +
+                       std::size_t(2) * batch * nhead * seqlen_q * hdim_v * seqlen_k;

-    std::size_t num_btype =
-        sizeof(QDataType) * Batch * M0 * K0 + sizeof(KDataType) * Batch * N0 * K0 +
-        sizeof(VDataType) * Batch * N1 * N0 + sizeof(ODataType) * Batch * M0 * N1;
+    std::size_t num_btype = sizeof(QDataType) * batch * nhead * seqlen_q * hdim_q +
+                            sizeof(KDataType) * batch * nhead * seqlen_k * hdim_q +
+                            sizeof(VDataType) * batch * nhead * hdim_v * seqlen_k +
+                            sizeof(ODataType) * batch * nhead * seqlen_q * hdim_v;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -172,5 +185,49 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
               << std::endl;

-    return !ck::utils::check_err(o_host_dev, o_host_ref);
+    if(do_validation)
+    {
+        Tensor<QDataType> q_host_ref({batch * nhead, seqlen_q, hdim_q});
+        Tensor<KDataType> k_host_ref({batch * nhead, seqlen_k, hdim_q});
+        Tensor<VDataType> v_host_ref({batch * nhead, hdim_v, seqlen_k});
+        Tensor<ODataType> o_host_ref({batch * nhead, seqlen_q, hdim_v});
+        Tensor<ODataType> o_host_result_ref(get_lengths(o_perm, batch, nhead, seqlen_q, hdim_v));
+
+        Tensor<SMPLComputeDataType> s_host_ref({batch * nhead, seqlen_q, seqlen_k});
+        Tensor<PDataType> p_host_ref({batch * nhead, seqlen_q, seqlen_k});
+
+        // clang-format off
+        // permute
+        if(i_perm) q_host.ForEach([&](auto& self, auto idx) { q_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); });
+        else       q_host.ForEach([&](auto& self, auto idx) { q_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); });
+
+        if(i_perm) k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); });
+        else       k_host.ForEach([&](auto& self, auto idx) { k_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); });
+
+        if(i_perm) v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]) = self(idx); });
+        else       v_host.ForEach([&](auto& self, auto idx) { v_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]) = self(idx); });
+
+        // reference
+        reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
+            q_host_ref, k_host_ref, s_host_ref,
+            [](const QDataType& x) { return x; },
+            [](const KDataType& x) { return x; },
+            [&scale](const SaccDataType& x) { return scale * x; });
+        reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(s_host_ref,
+                                                                                       p_host_ref);
+        reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
+            p_host_ref, v_host_ref, o_host_ref);
+
+        // permute
+        if(o_perm) o_host_result_ref.ForEach([&](auto& self, auto idx) { self(idx) = o_host_ref(idx[0] * nhead + idx[1], idx[2], idx[3]); });
+        else       o_host_result_ref.ForEach([&](auto& self, auto idx) { self(idx) = o_host_ref(idx[0] * nhead + idx[2], idx[1], idx[3]); });
+        // clang-format on
+
+        o_buf.FromDevice(o_host.mData.data());
+
+        return !ck::utils::check_err(o_host, o_host_result_ref);
+    }
+    else
+    {
+        return 0;
+    }
 }
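The ternary expressions passed to MakeKargs above are the whole layout story: with i_perm the inputs are packed as batch * nhead * seqlen * hdim, without it as batch * seqlen * nhead * hdim, and only the seqlen stride and the nhead stride trade places while the batch stride is the same either way. The standalone sketch below (illustrative names, not part of this commit) derives the same three Q strides for a densely packed row-major tensor and prints them for the example defaults.

// Standalone illustration, not part of this commit: derive the Q strides that the
// example passes to FmhaKernel::MakeKargs for a densely packed row-major tensor.
#include <cstdint>
#include <cstdio>

using index_t = std::int64_t;

struct QStrides
{
    index_t stride_q;       // step between consecutive seqlen rows
    index_t nhead_stride_q; // step between heads
    index_t batch_stride_q; // step between batches
};

// i_perm == true  -> layout [batch, nhead, seqlen, hdim]
// i_perm == false -> layout [batch, seqlen, nhead, hdim]
QStrides make_q_strides(bool i_perm, index_t nhead, index_t seqlen_q, index_t hdim_q)
{
    if(i_perm)
        return {hdim_q, seqlen_q * hdim_q, nhead * seqlen_q * hdim_q};
    else
        return {nhead * hdim_q, hdim_q, nhead * seqlen_q * hdim_q};
}

int main()
{
    // example defaults: nhead = 8, seqlen_q = 3328, hdim_q = 128
    const QStrides s = make_q_strides(true, 8, 3328, 128);
    std::printf("stride_q=%lld nhead_stride_q=%lld batch_stride_q=%lld\n",
                static_cast<long long>(s.stride_q),
                static_cast<long long>(s.nhead_stride_q),
                static_cast<long long>(s.batch_stride_q));
    return 0;
}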
@@ -11,9 +11,12 @@
 // P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]

-template <typename FmhaPipeline_, typename EpiloguePipeline_>
+#define C_LOG2E 1.44269504088896340736 // log2(e)
+
+template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
 struct FmhaFwdKernel
 {
+    using TilePartitioner  = ck::remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline     = ck::remove_cvref_t<FmhaPipeline_>;
     using EpiloguePipeline = ck::remove_cvref_t<EpiloguePipeline_>;

     static constexpr ck::index_t kBlockSize = FmhaPipeline::kBlockSize;
@@ -33,10 +36,19 @@ struct FmhaFwdKernel
         ck::index_t seqlen_k;
         ck::index_t hdim_q;
         ck::index_t hdim_v;
+        float scale;
         ck::index_t stride_q;
         ck::index_t stride_k;
         ck::index_t stride_v;
         ck::index_t stride_o;
+        ck::index_t nhead_stride_q;
+        ck::index_t nhead_stride_k;
+        ck::index_t nhead_stride_v;
+        ck::index_t nhead_stride_o;
         ck::index_t batch_stride_q;
         ck::index_t batch_stride_k;
         ck::index_t batch_stride_v;
@@ -51,37 +63,33 @@ struct FmhaFwdKernel
                                               ck::index_t seqlen_k,
                                               ck::index_t hdim_q,
                                               ck::index_t hdim_v,
+                                              float scale,
                                               ck::index_t stride_q,
                                               ck::index_t stride_k,
                                               ck::index_t stride_v,
                                               ck::index_t stride_o,
+                                              ck::index_t nhead_stride_q,
+                                              ck::index_t nhead_stride_k,
+                                              ck::index_t nhead_stride_v,
+                                              ck::index_t nhead_stride_o,
                                               ck::index_t batch_stride_q,
                                               ck::index_t batch_stride_k,
                                               ck::index_t batch_stride_v,
                                               ck::index_t batch_stride_o)
     {
-        return Kargs{q_ptr,
-                     k_ptr,
-                     v_ptr,
-                     o_ptr,
-                     seqlen_q,
-                     seqlen_k,
-                     hdim_q,
-                     hdim_v,
-                     stride_q,
-                     stride_k,
-                     stride_v,
-                     stride_o,
-                     batch_stride_q,
-                     batch_stride_k,
-                     batch_stride_v,
-                     batch_stride_o};
+        return Kargs{q_ptr,          k_ptr,          v_ptr,          o_ptr,          seqlen_q,
+                     seqlen_k,       hdim_q,         hdim_v,         scale,          stride_q,
+                     stride_k,       stride_v,       stride_o,       nhead_stride_q, nhead_stride_k,
+                     nhead_stride_v, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v,
+                     batch_stride_o};
     }

-    __host__ static constexpr auto
-    GridSize(ck::index_t batch_size_, ck::index_t seqlen_q_, ck::index_t hdim_v_)
+    __host__ static constexpr auto GridSize(ck::index_t batch_size_,
+                                            ck::index_t nhead_,
+                                            ck::index_t seqlen_q_,
+                                            ck::index_t hdim_v_)
     {
-        return dim3(batch_size_ * (seqlen_q_ / FmhaPipeline::kM0) * (hdim_v_ / FmhaPipeline::kN1));
+        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
     }

     __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -101,34 +109,21 @@ struct FmhaFwdKernel
         __shared__ char smem_ptr[GetSmemSize()];

         // divide problem
-        const index_t num_tile_m0 = kargs.seqlen_q / FmhaPipeline::kM0;
-        const index_t num_tile_n1 = kargs.hdim_v / FmhaPipeline::kN1;
-
-        const index_t id_block = ck::get_block_id();
-
-        const auto f = [](index_t dividend, index_t divisor) {
-            index_t quotient = dividend / divisor;
-            index_t modulus  = dividend - quotient * divisor;
-            return ck::make_tuple(quotient, modulus);
-        };
-
-        const auto [itmp, id_tile_n]          = f(id_block, num_tile_n1);
-        const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
+        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
+            TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);

-        const index_t i_batch = __builtin_amdgcn_readfirstlane(id_tile_batch);
-        const index_t i_m0    = __builtin_amdgcn_readfirstlane(id_tile_m * FmhaPipeline::kM0);
-        const index_t i_n1    = __builtin_amdgcn_readfirstlane(id_tile_n * FmhaPipeline::kN1);
+        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
+        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

         // for simplicity, batch stride we just modify the pointer
-        const QDataType* q_ptr =
-            reinterpret_cast<const QDataType*>(kargs.q_ptr) + i_batch * kargs.batch_stride_q;
-        const KDataType* k_ptr =
-            reinterpret_cast<const KDataType*>(kargs.k_ptr) + i_batch * kargs.batch_stride_k;
-        const VDataType* v_ptr =
-            reinterpret_cast<const VDataType*>(kargs.v_ptr) + i_batch * kargs.batch_stride_v;
-        ODataType* o_ptr =
-            reinterpret_cast<ODataType*>(kargs.o_ptr) + i_batch * kargs.batch_stride_o;
+        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
+                                 i_nhead * kargs.nhead_stride_q + i_batch * kargs.batch_stride_q;
+        const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
+                                 i_nhead * kargs.nhead_stride_k + i_batch * kargs.batch_stride_k;
+        const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
+                                 i_nhead * kargs.nhead_stride_v + i_batch * kargs.batch_stride_v;
+        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
+                           i_nhead * kargs.nhead_stride_o + i_batch * kargs.batch_stride_o;

         // Q/K/V DRAM and DRAM window
         // FIXME: assume layout Q[seqlen_q, hdim_q], K[seqlen_k, hdim_q], V[hdim_v, seqlen_k],
@@ -169,6 +164,7 @@ struct FmhaFwdKernel
         auto o_acc_tile = FmhaPipeline{}(q_dram_window,
                                          k_dram_window,
                                          v_dram_window,
+                                         kargs.scale,
                                          kargs.seqlen_k / FmhaPipeline::kN0,
                                          kargs.hdim_q / FmhaPipeline::kK0,
                                          smem_ptr);
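Putting the new scale kernel argument in context: the comments at the top of this header describe the per-(batch, nhead) computation, and the first GEMM result is now scaled before the softmax, with the example defaulting scale to 1/sqrt(hdim_q). Restated compactly in LaTeX, using the stored layouts K[seqlen_k, hdim_q] and V[hdim_v, seqlen_k]:

% S, P : [seqlen_q, seqlen_k],  O : [seqlen_q, hdim_v]
\[
S = \mathrm{scale}\cdot Q K^{\mathsf{T}}, \qquad
P = \operatorname{softmax}_{\text{row}}(S), \qquad
O = P\,V^{\mathsf{T}}, \qquad
\mathrm{scale} = \tfrac{1}{\sqrt{d_q}}\ \text{by default, where } d_q = \texttt{hdim\_q}.
\]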
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"

template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner
{
    using BlockFmhaShape = ck::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck::index_t kK1 = BlockFmhaShape::kK1;

    __host__ static constexpr auto GridSize(ck::index_t batch_size_,
                                            ck::index_t nhead_,
                                            ck::index_t seqlen_q_,
                                            ck::index_t hdim_v_)
    {
        // TODO: this may need tuning
        return dim3((seqlen_q_ / kM0) * (hdim_v_ / kN1), batch_size_, nhead_);
    }

    __device__ auto operator()(ck::index_t /*seqlen_q*/, ck::index_t hdim_v)
    {
        using namespace ck;

        // const index_t num_tile_m0 = seqlen_q / kM0;
        const index_t num_tile_n1 = hdim_v / kN1;

        const index_t i_block = blockIdx.x;
        const index_t i_batch = blockIdx.y;
        const index_t i_nhead = blockIdx.z;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

        return ck::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};
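For the tile shape chosen in the example (kM0 = 128, kN1 = 128) and the default sizes, this partitioner launches (3328 / 128) * (128 / 128) = 26 blocks along x, with batch and nhead carried by grid y and z, and each block recovers its (m, n) tile indices from blockIdx.x. A host-side sketch of the same arithmetic, for illustration only (not part of this commit):

// Illustration only: mirror FmhaFwdTilePartitioner's grid arithmetic on the host
// for the example defaults (kM0 = 128, kN1 = 128, batch = 2, nhead = 8).
#include <cstdio>

int main()
{
    const int kM0 = 128, kN1 = 128;
    const int batch = 2, nhead = 8, seqlen_q = 3328, hdim_v = 128;

    const int grid_x = (seqlen_q / kM0) * (hdim_v / kN1); // 26 * 1 = 26 tiles
    std::printf("grid = %d x %d x %d\n", grid_x, batch, nhead);

    // device side: blockIdx.x -> (i_tile_m, i_tile_n); blockIdx.y -> i_batch; blockIdx.z -> i_nhead
    const int num_tile_n1 = hdim_v / kN1;
    for(int i_block = 0; i_block < 3; ++i_block) // print the first few mappings
    {
        const int i_tile_m = i_block / num_tile_n1; // quotient
        const int i_tile_n = i_block % num_tile_n1; // remainder
        std::printf("block %d -> (i_tile_m=%d, i_tile_n=%d)\n", i_block, i_tile_m, i_tile_n);
    }
    return 0;
}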
@@ -6,10 +6,19 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/library/utility/host_tensor.hpp"

-template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
+template <typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename AElementOp,
+          typename BElementOp,
+          typename ACCElementOp>
 void reference_batched_gemm(const Tensor<ADataType>& a_b_m_k,
                             const Tensor<BDataType>& b_b_n_k,
-                            Tensor<CDataType>& c_b_m_n)
+                            Tensor<CDataType>& c_b_m_n,
+                            const AElementOp& a_element_op,
+                            const BElementOp& b_element_op,
+                            const ACCElementOp& acc_element_op)
 {
     const int N = b_b_n_k.mDesc.GetLengths()[1];
     const int K = b_b_n_k.mDesc.GetLengths()[2];
@@ -21,16 +30,30 @@ void reference_batched_gemm(const Tensor<ADataType>& a_b_m_k,
             for(int k = 0; k < K; ++k)
             {
-                ADataType v_a = a_b_m_k(batch, m, k);
-                BDataType v_b = b_b_n_k(batch, n, k);
+                ADataType v_a = a_element_op(a_b_m_k(batch, m, k));
+                BDataType v_b = b_element_op(b_b_n_k(batch, n, k));

                 v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
             }

-            c_b_m_n(batch, m, n) = ck::type_convert<CDataType>(v_acc);
+            c_b_m_n(batch, m, n) = ck::type_convert<CDataType>(acc_element_op(v_acc));
         }
     };

     make_ParallelTensorFunctor(f, c_b_m_n.mDesc.GetLengths()[0], c_b_m_n.mDesc.GetLengths()[1])(
         std::thread::hardware_concurrency());
 }
+
+template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
+void reference_batched_gemm(const Tensor<ADataType>& a_b_m_k,
+                            const Tensor<BDataType>& b_b_n_k,
+                            Tensor<CDataType>& c_b_m_n)
+{
+    reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
+        a_b_m_k,
+        b_b_n_k,
+        c_b_m_n,
+        [](const ADataType& x) { return x; },
+        [](const BDataType& x) { return x; },
+        [](const AccDataType& x) { return x; });
+}
@@ -63,6 +63,7 @@ struct BlockFmhaPipelineQKVS
                const KElementFunction& k_element_func,
                const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
                const VElementFunction& v_element_func,
+               float scale,
                index_t num_total_loop,
                index_t num_sub_loop_qk,
                void* smem_ptr) const
@@ -215,6 +216,8 @@ struct BlockFmhaPipelineQKVS
             }

             // STAGE 2, scale softmax
+            tile_elementwise_inout([&scale](auto& x) { x = x * scale; }, s_acc);
+
             const auto s =
                 tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc); // S{j}
             auto m_local = block_tile_reduce<SMPLComputeDataType>(
@@ -245,11 +248,12 @@ struct BlockFmhaPipelineQKVS
             block_tile_reduce_sync(rowsum_p, f_sum);

             // l{j}, Oacc{j}
-            sweep_tile_span(p_spans[Number<0>{}], [&](auto idx0) {
+            constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();
+            sweep_tile_span(o_spans[Number<0>{}], [&](auto idx0) {
                 constexpr auto i_idx = make_tuple(idx0);
                 const auto tmp       = math::exp(m_old[i_idx] - m[i_idx]);
                 l(i_idx)             = tmp * l[i_idx] + rowsum_p[i_idx];
-                sweep_tile_span(p_spans[Number<1>{}], [&](auto idx1) {
+                sweep_tile_span(o_spans[Number<1>{}], [&](auto idx1) {
                     constexpr auto i_j_idx = make_tuple(idx0, idx1);
                     // FIXME: this use different equation from FA v2 paper,
                     // but produce correc result.
@@ -319,6 +323,7 @@ struct BlockFmhaPipelineQKVS
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
                const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
+               float scale,
                index_t num_total_loop,
                index_t num_sub_loop_qk,
                void* smem_ptr) const
@@ -330,6 +335,7 @@ struct BlockFmhaPipelineQKVS
                           [](const KDataType& x) { return x; },
                           v_dram_block_window_tmp,
                           [](const VDataType& x) { return x; },
+                          scale,
                           num_total_loop,
                           num_sub_loop_qk,
                           smem_ptr);
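The l and m bookkeeping visible in the hunk above is the running-softmax update of a flash-attention-style pipeline: m holds the running row maximum, l the running row sum, and when a new key/value tile raises the maximum the old sum is rescaled before the new tile's probabilities are added. Restated for a query row i after the j-th key/value tile (the max update is only partly visible in this excerpt, so read it as the usual formulation rather than a literal transcription):

\[
m_i^{(j)} = \max\!\bigl(m_i^{(j-1)},\ \operatorname{rowmax}_i S^{(j)}\bigr), \qquad
l_i^{(j)} = e^{\,m_i^{(j-1)} - m_i^{(j)}}\, l_i^{(j-1)} + \operatorname{rowsum}_i P^{(j)}.
\]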
@@ -183,6 +183,37 @@ inline __host__ double exp<double>(double x)
     return std::exp(x);
 }

+// prevent implicit type casting
+template <typename T>
+__host__ T exp2(T x);
+template <typename T>
+__device__ T exp2(T x);
+
+template <>
+inline __device__ float exp2<float>(float x)
+{
+    return exp2f(x);
+}
+
+template <>
+inline __device__ double exp2<double>(double x)
+{
+    return exp2(x);
+}
+
+template <>
+inline __host__ float exp2<float>(float x)
+{
+    return std::exp2f(x);
+}
+
+template <>
+inline __host__ double exp2<double>(double x)
+{
+    return std::exp2l(x); // TODO: std does not have exp2 for double till c++23
+}
+
 // greatest common divisor, aka highest common factor
 __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
 {
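These exp2 overloads pair naturally with the C_LOG2E constant added in fmha_fwd_kernel.hpp: the standard rewrite exp(x) = exp2(x * log2 e) lets a softmax lean on the hardware's fast exp2 path. The commit itself does not show where the two are wired together, so the snippet below is only an illustration of the identity:

// Illustration only (not from this commit): exp(x) == exp2(x * log2(e)).
#include <cmath>
#include <cstdio>

#define C_LOG2E 1.44269504088896340736 // log2(e), same constant as in fmha_fwd_kernel.hpp

int main()
{
    const double x        = -1.75;
    const double via_exp  = std::exp(x);
    const double via_exp2 = std::exp2(x * C_LOG2E);
    std::printf("exp(x) = %.17g, exp2(x * log2e) = %.17g\n", via_exp, via_exp2);
    return 0;
}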