Unverified Commit 7b1a0b7f authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

slice kv, and use 3d padding LDS layout (#15)

* slice kv, and use 3d padding LDS layout

* add missing sync

* put sync to another place

* move sync place

* revert to normal
parent 6491acda
......@@ -101,6 +101,7 @@ int main(int argc, char* argv[])
constexpr ck::index_t kN0PerBlock = 128;
constexpr ck::index_t kK0PerBlock = 32;
constexpr ck::index_t kN1PerBlock = 128;
constexpr ck::index_t kK1PerBlock = 32;
constexpr ck::index_t kBlockSize = 256;
ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
......@@ -125,7 +126,8 @@ int main(int argc, char* argv[])
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock>{},
kN1PerBlock,
kK1PerBlock>{},
kGridSize,
kBlockSize,
0,
......
......@@ -34,7 +34,8 @@ template <typename QDataType,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock>
ck::index_t kN1PerBlock,
ck::index_t kK1PerBlock>
struct BatchedGemmSoftmaxGemm
{
__device__ void operator()(const QDataType* q_ptr,
......@@ -89,7 +90,8 @@ struct BatchedGemmSoftmaxGemm
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock>{};
kN1PerBlock,
kK1PerBlock>{};
kernel_impl(q_ptr + iBatch * BatchStrideQ,
k_ptr + iBatch * BatchStrideK,
......
......@@ -11,6 +11,7 @@
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/slice_tile.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
......@@ -32,7 +33,8 @@ template <typename QDataType,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock>
ck::index_t kN1PerBlock,
ck::index_t kK1PerBlock>
struct GemmSoftmaxGemmImpl
{
// block gemm0 pipeline
......@@ -52,7 +54,7 @@ struct GemmSoftmaxGemmImpl
VDataType,
OaccDataType,
kBlockSize,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kK1PerBlock>>,
ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;
#if 0
......@@ -69,7 +71,7 @@ struct GemmSoftmaxGemmImpl
return b_lds_desc;
}
#else
#elif 0
// fake XOR
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
......@@ -101,6 +103,34 @@ struct GemmSoftmaxGemmImpl
return b_lds_desc_n_k;
}
#else
// 3d, with padding
// Builds the compile-time LDS block descriptor for the V tile.
//
// The storage is described as a padded 3-D layout (K0, N, K1), where the
// per-block K extent (kK1PerBlock) is split as K = K0 * K1 with K1 = kK1,
// and is then re-exposed to callers as a 2-D (N, K) view.
//
// Returns: the transformed 2-D descriptor (dimension 0 = N, dimension 1 = K).
//
// NOTE(review): the padding is presumably there to reduce LDS bank conflicts
// when the tile is read back for the second GEMM — confirm.
// NOTE(review): kKPerBlock / kK1 is an integer division; this assumes
// kK1PerBlock is a multiple of kK1 (8). Nothing in this function checks it.
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
using namespace ck;
// using BDataType = B1DataType;
// N and K extents of one V tile held in LDS.
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
// Number of extra N-rows (each kK1 elements wide) appended to every K0 slab.
constexpr index_t kPad = 1;
// Length of the innermost (contiguous) K1 dimension.
constexpr index_t kK1 = 8;
// Raw 3-D layout: lengths (K0, N, K1) with strides ((N + kPad) * K1, K1, 1).
// The K0 stride is larger than N * K1, so each K0 slab is followed by
// kPad * kK1 unused elements.
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / kK1>{}, Number<kNPerBlock>{}, Number<kK1>{}),
make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number<kK1>{}, Number<1>{}),
// NOTE(review): the two trailing arguments are presumably the guaranteed
// vector length and stride of the last dimension — confirm against the
// make_naive_tensor_descriptor overload in CK.
Number<kK1>{},
Number<1>{});
// Reshape (K0, N, K1) -> (N, K):
//   source dim 1 (N) is passed through unchanged and becomes output dim 0;
//   source dims 0 and 2 (K0, K1) are merged into a single K and become output dim 1.
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(Number<kKPerBlock / kK1>{}, Number<kK1>{}))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return b_lds_block_desc;
}
#endif
__device__ static constexpr auto MakeVDramTileDistribution()
......@@ -111,7 +141,7 @@ struct GemmSoftmaxGemmImpl
using BDataType = VDataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
......@@ -181,7 +211,7 @@ struct GemmSoftmaxGemmImpl
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}),
make_tuple(Number<kN1PerBlock>{}, Number<kK1PerBlock>{}),
{iN1, 0},
MakeVDramTileDistribution());
......@@ -191,7 +221,7 @@ struct GemmSoftmaxGemmImpl
MakeVLdsBlockDescriptor());
auto v_lds_window = make_tile_window(
v_lds, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {0, 0});
v_lds, make_tuple(Number<kN1PerBlock>{}, Number<kK1PerBlock>{}), {0, 0});
// Block GEMM0 pipeline and Block GEMM1
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
......@@ -208,13 +238,16 @@ struct GemmSoftmaxGemmImpl
using SBlockTileType = decltype(tile_elementwise_in(
type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));
using PBlockTileType = decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>,
SaccBlockTileType{}));
using PBlockTileType = decltype(
tile_elementwise_in(type_convert<PDataType, SaccDataType>, SaccBlockTileType{}));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm1(PBlockTileType{}, v_dram_window));
using OaccBlockTileType = decltype(
gemm1(get_slice_tile(
PBlockTileType{}, Sequence<0, 0>{}, Sequence<kM0PerBlock, kK1PerBlock>{}),
v_dram_window));
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
......@@ -239,6 +272,9 @@ struct GemmSoftmaxGemmImpl
const auto s =
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
// prefetch load v tile
const auto v_prefetch = load_tile(v_dram_window);
// m_local = rowmax(S{j})
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s, Sequence<1>{}, f_max, NumericLimits<SMPLComputeDataType>::Lowest());
......@@ -292,34 +328,43 @@ struct GemmSoftmaxGemmImpl
});
});
block_sync_lds();
store_tile(v_lds_window, v_prefetch);
move_tile_window(v_dram_window, {0, kK1PerBlock});
// type cast Pcompute{j} into P{j}
const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Block GEMM1: Oacc{j} += P{j} * V{j}
{
// load V{j}
const auto v = load_tile(v_dram_window);
// wait for gemm0 pipeline to finish
block_sync_lds();
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
store_tile(v_lds_window, v);
// wait for store_tile to finish
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
gemm1(o_acc,
get_slice_tile(p,
Sequence<0, i_k1 * kK1PerBlock>{},
Sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
v_lds_window);
block_sync_lds();
store_tile(v_lds_window, v);
move_tile_window(v_dram_window, {0, kK1PerBlock});
});
}
// tail
{
block_sync_lds();
// Oacc{j} += P{j} * V{j}
gemm1(o_acc, p, v_lds_window);
// wait for gemm1 to finish
gemm1(o_acc,
get_slice_tile(p,
Sequence<0, (k1_loops - 1) * kK1PerBlock>{},
Sequence<kM0PerBlock, kN0PerBlock>{}),
v_lds_window);
block_sync_lds();
}
// move tile windows
move_tile_window(k_dram_window, {kN0PerBlock, 0});
move_tile_window(v_dram_window, {0, kN0PerBlock});
iN0 += kN0PerBlock;
} while(iN0 < N0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment