Unverified Commit 2dfbfbbc authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Revert "slice kv, and use 3d padding LDS layout (#15)" (#18)

This reverts commit 7b1a0b7f.
parent 9f36ac7c
......@@ -101,7 +101,6 @@ int main(int argc, char* argv[])
constexpr ck::index_t kN0PerBlock = 128;
constexpr ck::index_t kK0PerBlock = 32;
constexpr ck::index_t kN1PerBlock = 128;
constexpr ck::index_t kK1PerBlock = 32;
constexpr ck::index_t kBlockSize = 256;
ck::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
......@@ -126,8 +125,7 @@ int main(int argc, char* argv[])
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>{},
kN1PerBlock>{},
kGridSize,
kBlockSize,
0,
......
......@@ -34,8 +34,7 @@ template <typename QDataType,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock,
ck::index_t kK1PerBlock>
ck::index_t kN1PerBlock>
struct BatchedGemmSoftmaxGemm
{
__device__ void operator()(const QDataType* q_ptr,
......@@ -90,8 +89,7 @@ struct BatchedGemmSoftmaxGemm
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>{};
kN1PerBlock>{};
kernel_impl(q_ptr + iBatch * BatchStrideQ,
k_ptr + iBatch * BatchStrideK,
......
......@@ -11,7 +11,6 @@
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/slice_tile.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
......@@ -33,8 +32,7 @@ template <typename QDataType,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock,
ck::index_t kK1PerBlock>
ck::index_t kN1PerBlock>
struct GemmSoftmaxGemmImpl
{
// block gemm0 pipeline
......@@ -54,7 +52,7 @@ struct GemmSoftmaxGemmImpl
VDataType,
OaccDataType,
kBlockSize,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kK1PerBlock>>,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>,
ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;
#if 0
......@@ -71,7 +69,7 @@ struct GemmSoftmaxGemmImpl
return b_lds_desc;
}
#elif 0
#else
// fake XOR
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
......@@ -103,34 +101,6 @@ struct GemmSoftmaxGemmImpl
return b_lds_desc_n_k;
}
#else
// 3d, with padding
// Builds the compile-time block descriptor used for the V tile in LDS, using a
// 3-D layout with padding.
//
// The underlying 3-D descriptor has lengths (kKPerBlock / kK1, kNPerBlock, kK1) and
// strides ((kNPerBlock + kPad) * kK1, kK1, 1), so each outer K slice is followed by
// kPad * kK1 unused elements before the next slice begins.
//
// Returns a 2-D view of that layout: dimension 0 is N (length kNPerBlock, passed
// through unchanged) and dimension 1 is the merge of the two K sub-dimensions.
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
using namespace ck;
// using BDataType = B1DataType;
// Tile extents are taken from the enclosing struct's template parameters.
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
// NOTE(review): the padding is presumably there to avoid LDS bank conflicts — confirm.
constexpr index_t kPad = 1;
// Innermost contiguous K chunk; hard-coded here.
// NOTE(review): kKPerBlock / kK1 assumes kKPerBlock is a multiple of kK1 — confirm.
constexpr index_t kK1 = 8;
// 3-D descriptor: (K0 = kKPerBlock / kK1, N, K1) with a padded outermost stride.
// NOTE(review): the trailing Number<kK1>{} / Number<1>{} arguments are presumably
// vector-access hints for the last dimension — verify against
// make_naive_tensor_descriptor.
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / kK1>{}, Number<kNPerBlock>{}, Number<kK1>{}),
make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number<kK1>{}, Number<1>{}),
Number<kK1>{},
Number<1>{});
// Re-view as 2-D: input dim 1 (N) -> output dim 0; input dims 0 and 2 (K0, K1)
// are merged into output dim 1.
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(Number<kKPerBlock / kK1>{}, Number<kK1>{}))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return b_lds_block_desc;
}
#endif
__device__ static constexpr auto MakeVDramTileDistribution()
......@@ -141,7 +111,7 @@ struct GemmSoftmaxGemmImpl
using BDataType = VDataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
......@@ -211,7 +181,7 @@ struct GemmSoftmaxGemmImpl
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(Number<kN1PerBlock>{}, Number<kK1PerBlock>{}),
make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}),
{iN1, 0},
MakeVDramTileDistribution());
......@@ -221,7 +191,7 @@ struct GemmSoftmaxGemmImpl
MakeVLdsBlockDescriptor());
auto v_lds_window = make_tile_window(
v_lds, make_tuple(Number<kN1PerBlock>{}, Number<kK1PerBlock>{}), {0, 0});
v_lds, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {0, 0});
// Block GEMM0 pipeline and Block GEMM1
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
......@@ -244,10 +214,7 @@ struct GemmSoftmaxGemmImpl
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm1(
get_slice_tile(
PBlockTileType{}, Sequence<0, 0>{}, Sequence<kM0PerBlock, kK1PerBlock>{}),
v_dram_window));
using OaccBlockTileType = decltype(gemm1(PBlockTileType{}, v_dram_window));
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
......@@ -272,9 +239,6 @@ struct GemmSoftmaxGemmImpl
const auto s =
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
// prefetch load v tile
const auto v_prefetch = load_tile(v_dram_window);
// m_local = rowmax(S{j})
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s, Sequence<1>{}, f_max, NumericLimits<SMPLComputeDataType>::Lowest());
......@@ -322,55 +286,45 @@ struct GemmSoftmaxGemmImpl
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correct result.
// but produce correct result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
block_sync_lds();
store_tile(v_lds_window, v_prefetch);
move_tile_window(v_dram_window, {0, kK1PerBlock});
// type cast Pcompute{j} into P{j}
const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Oacc{j}
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
if constexpr(k1_loops > 1)
// Block GEMM1: Oacc{j} += P{j} * V{j}
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
block_sync_lds();
gemm1(o_acc,
get_slice_tile(p,
Sequence<0, i_k1 * kK1PerBlock>{},
Sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
v_lds_window);
// load V{j}
const auto v = load_tile(v_dram_window);
// wait for gemm0 pipeline to finish
block_sync_lds();
store_tile(v_lds_window, v);
move_tile_window(v_dram_window, {0, kK1PerBlock});
});
}
// tail
{
// wait for store_tile to finish
block_sync_lds();
gemm1(o_acc,
get_slice_tile(p,
Sequence<0, (k1_loops - 1) * kK1PerBlock>{},
Sequence<kM0PerBlock, kN0PerBlock>{}),
v_lds_window);
// Oacc{j} += P{j} * V{j}
gemm1(o_acc, p, v_lds_window);
// wait for gemm1 to finish
block_sync_lds();
}
// move tile windows
move_tile_window(k_dram_window, {kN0PerBlock, 0});
move_tile_window(v_dram_window, {0, kN0PerBlock});
iN0 += kN0PerBlock;
} while(iN0 < N0);
// Oacc
// O
constexpr auto o_spans = decltype(o_acc)::GetDistributedSpans();
sweep_tile_span(o_spans[I0], [&](auto idx0) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment