Unverified Commit 7b1a0b7f authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

slice kv, and use 3d padding LDS layout (#15)

* slice kv, and use 3d padding LDS layout

* add missing sync

* put sync to another place

* move sync place

* revert to normal
parent 6491acda
...@@ -101,6 +101,7 @@ int main(int argc, char* argv[]) ...@@ -101,6 +101,7 @@ int main(int argc, char* argv[])
constexpr ck::index_t kN0PerBlock = 128; constexpr ck::index_t kN0PerBlock = 128;
constexpr ck::index_t kK0PerBlock = 32; constexpr ck::index_t kK0PerBlock = 32;
constexpr ck::index_t kN1PerBlock = 128; constexpr ck::index_t kN1PerBlock = 128;
constexpr ck::index_t kK1PerBlock = 32;
constexpr ck::index_t kBlockSize = 256; constexpr ck::index_t kBlockSize = 256;
ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock); ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
...@@ -125,7 +126,8 @@ int main(int argc, char* argv[]) ...@@ -125,7 +126,8 @@ int main(int argc, char* argv[])
kM0PerBlock, kM0PerBlock,
kN0PerBlock, kN0PerBlock,
kK0PerBlock, kK0PerBlock,
kN1PerBlock>{}, kN1PerBlock,
kK1PerBlock>{},
kGridSize, kGridSize,
kBlockSize, kBlockSize,
0, 0,
......
...@@ -34,7 +34,8 @@ template <typename QDataType, ...@@ -34,7 +34,8 @@ template <typename QDataType,
ck::index_t kM0PerBlock, ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock, ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock, ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock> ck::index_t kN1PerBlock,
ck::index_t kK1PerBlock>
struct BatchedGemmSoftmaxGemm struct BatchedGemmSoftmaxGemm
{ {
__device__ void operator()(const QDataType* q_ptr, __device__ void operator()(const QDataType* q_ptr,
...@@ -89,7 +90,8 @@ struct BatchedGemmSoftmaxGemm ...@@ -89,7 +90,8 @@ struct BatchedGemmSoftmaxGemm
kM0PerBlock, kM0PerBlock,
kN0PerBlock, kN0PerBlock,
kK0PerBlock, kK0PerBlock,
kN1PerBlock>{}; kN1PerBlock,
kK1PerBlock>{};
kernel_impl(q_ptr + iBatch * BatchStrideQ, kernel_impl(q_ptr + iBatch * BatchStrideQ,
k_ptr + iBatch * BatchStrideK, k_ptr + iBatch * BatchStrideK,
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "ck/tile_program/tile/tile_distribution.hpp" #include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp" #include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp" #include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/slice_tile.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp" #include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp" #include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
...@@ -32,7 +33,8 @@ template <typename QDataType, ...@@ -32,7 +33,8 @@ template <typename QDataType,
ck::index_t kM0PerBlock, ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock, ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock, ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock> ck::index_t kN1PerBlock,
ck::index_t kK1PerBlock>
struct GemmSoftmaxGemmImpl struct GemmSoftmaxGemmImpl
{ {
// block gemm0 pipeline // block gemm0 pipeline
...@@ -52,7 +54,7 @@ struct GemmSoftmaxGemmImpl ...@@ -52,7 +54,7 @@ struct GemmSoftmaxGemmImpl
VDataType, VDataType,
OaccDataType, OaccDataType,
kBlockSize, kBlockSize,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>, ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kK1PerBlock>>,
ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>; ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;
#if 0 #if 0
...@@ -69,7 +71,7 @@ struct GemmSoftmaxGemmImpl ...@@ -69,7 +71,7 @@ struct GemmSoftmaxGemmImpl
return b_lds_desc; return b_lds_desc;
} }
#else #elif 0
// fake XOR // fake XOR
__device__ static constexpr auto MakeVLdsBlockDescriptor() __device__ static constexpr auto MakeVLdsBlockDescriptor()
{ {
...@@ -101,6 +103,34 @@ struct GemmSoftmaxGemmImpl ...@@ -101,6 +103,34 @@ struct GemmSoftmaxGemmImpl
return b_lds_desc_n_k; return b_lds_desc_n_k;
} }
#else
// 3d, with padding
// Builds the LDS block descriptor used for the V tile.
//
// The underlying storage is described as a 3-D tensor with
//   lengths (kKPerBlock / kK1, kNPerBlock, kK1)
//   strides ((kNPerBlock + kPad) * kK1, kK1, 1)
// so every outer-K slab is followed by kPad * kK1 unused elements. The 3-D
// descriptor is then re-exposed as a 2-D (N, K) view: N is passed through and
// the (outer-K, inner-K) pair is merged back into a single K dimension.
//
// NOTE(review): the padding is presumably there to avoid LDS bank conflicts --
// confirm with the layout's author.
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
    using namespace ck;

    constexpr index_t kNPerBlock = kN1PerBlock;
    constexpr index_t kKPerBlock = kK1PerBlock;
    constexpr index_t kPad       = 1;
    // NOTE(review): hard-coded inner-K length; MakeVDramTileDistribution derives
    // its K1 as 16 / sizeof(BDataType) -- confirm the two are meant to agree for
    // every supported VDataType.
    constexpr index_t kK1 = 8;

    // kKPerBlock / kK1 below is an integer division: without this check a
    // non-multiple K tile would silently describe fewer K elements than the
    // (kN1PerBlock, kK1PerBlock) tile windows that use this descriptor.
    static_assert(kKPerBlock % kK1 == 0,
                  "kK1PerBlock must be a multiple of the inner K length (kK1)");

    constexpr index_t kK0 = kKPerBlock / kK1;

    // Padded 3-D layout (K0, N, K1).
    // NOTE(review): the two trailing Number<> arguments are presumably the
    // guaranteed vector length / stride of the last dimension -- confirm against
    // make_naive_tensor_descriptor.
    constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
        make_tuple(Number<kK0>{}, Number<kNPerBlock>{}, Number<kK1>{}),
        make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number<kK1>{}, Number<1>{}),
        Number<kK1>{},
        Number<1>{});

    // 2-D view: source dim 1 (N) -> output dim 0, source dims {0, 2} (K0, K1)
    // merged -> output dim 1 (K).
    constexpr auto b_lds_block_desc = transform_tensor_descriptor(
        b_lds_block_desc_0,
        make_tuple(make_pass_through_transform(kNPerBlock),
                   make_merge_transform(make_tuple(Number<kK0>{}, Number<kK1>{}))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return b_lds_block_desc;
}
#endif #endif
__device__ static constexpr auto MakeVDramTileDistribution() __device__ static constexpr auto MakeVDramTileDistribution()
...@@ -111,7 +141,7 @@ struct GemmSoftmaxGemmImpl ...@@ -111,7 +141,7 @@ struct GemmSoftmaxGemmImpl
using BDataType = VDataType; using BDataType = VDataType;
constexpr index_t kNPerBlock = kN1PerBlock; constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock; constexpr index_t kKPerBlock = kK1PerBlock;
constexpr index_t K1 = 16 / sizeof(BDataType); constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -181,7 +211,7 @@ struct GemmSoftmaxGemmImpl ...@@ -181,7 +211,7 @@ struct GemmSoftmaxGemmImpl
auto v_dram_window = auto v_dram_window =
make_tile_window(v_dram, make_tile_window(v_dram,
make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), make_tuple(Number<kN1PerBlock>{}, Number<kK1PerBlock>{}),
{iN1, 0}, {iN1, 0},
MakeVDramTileDistribution()); MakeVDramTileDistribution());
...@@ -191,7 +221,7 @@ struct GemmSoftmaxGemmImpl ...@@ -191,7 +221,7 @@ struct GemmSoftmaxGemmImpl
MakeVLdsBlockDescriptor()); MakeVLdsBlockDescriptor());
auto v_lds_window = make_tile_window( auto v_lds_window = make_tile_window(
v_lds, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {0, 0}); v_lds, make_tuple(Number<kN1PerBlock>{}, Number<kK1PerBlock>{}), {0, 0});
// Block GEMM0 pipeline and Block GEMM1 // Block GEMM0 pipeline and Block GEMM1
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{}; constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
...@@ -208,13 +238,16 @@ struct GemmSoftmaxGemmImpl ...@@ -208,13 +238,16 @@ struct GemmSoftmaxGemmImpl
using SBlockTileType = decltype(tile_elementwise_in( using SBlockTileType = decltype(tile_elementwise_in(
type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{})); type_convert<SMPLComputeDataType, SaccDataType>, SaccBlockTileType{}));
using PBlockTileType = decltype(tile_elementwise_in(type_convert<PDataType, SaccDataType>, using PBlockTileType = decltype(
SaccBlockTileType{})); tile_elementwise_in(type_convert<PDataType, SaccDataType>, SaccBlockTileType{}));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>( using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0})); SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm1(PBlockTileType{}, v_dram_window)); using OaccBlockTileType = decltype(
gemm1(get_slice_tile(
PBlockTileType{}, Sequence<0, 0>{}, Sequence<kM0PerBlock, kK1PerBlock>{}),
v_dram_window));
// init Oacc, M, L // init Oacc, M, L
auto o_acc = OaccBlockTileType{}; auto o_acc = OaccBlockTileType{};
...@@ -239,6 +272,9 @@ struct GemmSoftmaxGemmImpl ...@@ -239,6 +272,9 @@ struct GemmSoftmaxGemmImpl
const auto s = const auto s =
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc); tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
// prefetch load v tile
const auto v_prefetch = load_tile(v_dram_window);
// m_local = rowmax(S{j}) // m_local = rowmax(S{j})
auto m_local = block_tile_reduce<SMPLComputeDataType>( auto m_local = block_tile_reduce<SMPLComputeDataType>(
s, Sequence<1>{}, f_max, NumericLimits<SMPLComputeDataType>::Lowest()); s, Sequence<1>{}, f_max, NumericLimits<SMPLComputeDataType>::Lowest());
...@@ -292,34 +328,43 @@ struct GemmSoftmaxGemmImpl ...@@ -292,34 +328,43 @@ struct GemmSoftmaxGemmImpl
}); });
}); });
block_sync_lds();
store_tile(v_lds_window, v_prefetch);
move_tile_window(v_dram_window, {0, kK1PerBlock});
// type cast Pcompute{j} into P{j} // type cast Pcompute{j} into P{j}
const auto p = const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute); tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Block GEMM1: Oacc{j} += P{j} * V{j} constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
{
// load V{j}
const auto v = load_tile(v_dram_window);
// wait for gemm0 pipeline to finish
block_sync_lds();
store_tile(v_lds_window, v); if constexpr(k1_loops > 1)
{
// wait for store_tile to finish static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
gemm1(o_acc,
get_slice_tile(p,
Sequence<0, i_k1 * kK1PerBlock>{},
Sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
v_lds_window);
block_sync_lds();
store_tile(v_lds_window, v);
move_tile_window(v_dram_window, {0, kK1PerBlock});
});
}
// tail
{
block_sync_lds(); block_sync_lds();
gemm1(o_acc,
// Oacc{j} += P{j} * V{j} get_slice_tile(p,
gemm1(o_acc, p, v_lds_window); Sequence<0, (k1_loops - 1) * kK1PerBlock>{},
Sequence<kM0PerBlock, kN0PerBlock>{}),
// wait for gemm1 to finish v_lds_window);
block_sync_lds(); block_sync_lds();
} }
// move tile windows // move tile windows
move_tile_window(k_dram_window, {kN0PerBlock, 0}); move_tile_window(k_dram_window, {kN0PerBlock, 0});
move_tile_window(v_dram_window, {0, kN0PerBlock});
iN0 += kN0PerBlock; iN0 += kN0PerBlock;
} while(iN0 < N0); } while(iN0 < N0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment