Unverified Commit 7ccf0bb5 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

refactor gemm+softmax+gemm (#19)

* refactor gemm+softmax+gemm using block-gemm

* reorg files

* clean
parent 2dfbfbbc
......@@ -9,12 +9,14 @@
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
......@@ -46,95 +48,19 @@ struct GemmSoftmaxGemmImpl
ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>;
// block gemm1
using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1<
ck::tile_program::block::BlockGemmARegBSmemCRegV1Problem<
using BlockGemm1 = ck::tile_program::block::BlockGemmARegBGmemCRegV1<
ck::tile_program::block::BlockGemmARegBGmemCRegProblem<
PDataType,
VDataType,
OaccDataType,
kBlockSize,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>,
ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;
#if 0
// 2d
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr auto b_lds_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});
return b_lds_desc;
}
#else
// fake XOR
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
using namespace ck;
using BDataType = VDataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr auto b_lds_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(kNPerBlock / 2, 2, kKPerBlock), Number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(BDataType);
constexpr auto b_lds_desc_d4_d5_d6 = transform_tensor_descriptor(
b_lds_desc_d1_d2_d3,
make_tuple(make_xor_transform(make_tuple(kNPerBlock / 2, kKPerBlock), kK1),
make_pass_through_transform(2)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto b_lds_desc_n_k = transform_tensor_descriptor(
b_lds_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(kNPerBlock / 2, 2)),
make_pass_through_transform(kKPerBlock)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return b_lds_desc_n_k;
}
#endif
__device__ static constexpr auto MakeVDramTileDistribution()
{
using namespace ck;
using namespace ck::tile_program;
using BDataType = VDataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
StaticTileDistributionEncoding<Sequence<1>,
Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
Tuple<Sequence<1>, Sequence<1, 2>>,
Tuple<Sequence<1>, Sequence<2, 0>>,
Sequence<1, 2>,
Sequence<0, 1>>{});
}
ck::tile_program::block::BlockGemmARegBGmemCRegV1DefaultPolicy>;
__device__ static constexpr ck::index_t GetStaticLdsSize()
{
using namespace ck;
return math::max(BlockGemm0Pipeline::GetStaticLdsSize(),
static_cast<index_t>(MakeVLdsBlockDescriptor().GetElementSpaceSize() *
sizeof(VDataType)));
return ck::math::max(BlockGemm0Pipeline::GetStaticLdsSize(),
BlockGemm1::GetStaticLdsSize());
}
__device__ void operator()(const QDataType* q_ptr,
......@@ -162,7 +88,7 @@ struct GemmSoftmaxGemmImpl
// allocate LDS
__shared__ char smem_ptr[GetStaticLdsSize()];
// Q/K/V DRAM and DRAM window
// Q/K/V DRAM
// FIXME: assume layout Q[M0, K0], K[N0, K0], V[N1, N0], O[M0, N1]
const auto q_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), Number<32>{}, Number<1>{});
......@@ -173,25 +99,15 @@ struct GemmSoftmaxGemmImpl
const auto v_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), Number<32>{}, Number<1>{});
// Q/K/V DRAM window
auto q_dram_window = make_tile_window(
q_dram, make_tuple(Number<kM0PerBlock>{}, Number<kK0PerBlock>{}), {iM0, 0});
auto k_dram_window = make_tile_window(
k_dram, make_tuple(Number<kN0PerBlock>{}, Number<kK0PerBlock>{}), {0, 0});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}),
{iN1, 0},
MakeVDramTileDistribution());
// V LDS and LDS window
// V LDS occupies the same LDS allocation Q/K LDS
auto v_lds = make_tensor_view<AddressSpaceEnum::Lds>(reinterpret_cast<VDataType*>(smem_ptr),
MakeVLdsBlockDescriptor());
auto v_lds_window = make_tile_window(
v_lds, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {0, 0});
auto v_dram_window = make_tile_window(
v_dram, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {iN1, 0});
// Block GEMM0 pipeline and Block GEMM1
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
......@@ -214,7 +130,7 @@ struct GemmSoftmaxGemmImpl
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, Sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm1(PBlockTileType{}, v_dram_window));
using OaccBlockTileType = decltype(gemm1(PBlockTileType{}, v_dram_window, smem_ptr));
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
......@@ -286,7 +202,7 @@ struct GemmSoftmaxGemmImpl
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// but produce correct result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
......@@ -296,30 +212,19 @@ struct GemmSoftmaxGemmImpl
const auto p =
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
// Block GEMM1: Oacc{j} += P{j} * V{j}
{
// load V{j}
const auto v = load_tile(v_dram_window);
// wait for gemm0 pipeline to finish
block_sync_lds();
store_tile(v_lds_window, v);
// wait for store_tile to finish
// wait for gemm0 pipeline to finish reading Lds
block_sync_lds();
// Oacc{j} += P{j} * V{j}
gemm1(o_acc, p, v_lds_window);
// wait for gemm1 to finish
block_sync_lds();
}
// Block GEMM1: Oacc{j} += P{j} * V{j}
gemm1(o_acc, p, v_dram_window, smem_ptr);
// move tile windows
// move K/V tile windows for next iteration (J loop)
move_tile_window(k_dram_window, {kN0PerBlock, 0});
move_tile_window(v_dram_window, {0, kN0PerBlock});
// wait for gemm1 to finish reading Lds, before next iteration (J loop)
block_sync_lds();
iN0 += kN0PerBlock;
} while(iN0 < N0);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Problem description for BlockGemmARegBGmemCReg.
// A compile-time bundle of the operand element types, the block-GEMM tile shape
// and the block size. It carries no data and no behavior; block GEMMs and their
// policies read these members.
// Every type argument is stored with cv-qualifiers and references stripped.
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          index_t kBlockSize_,
          typename BlockGemmShape_>
struct BlockGemmARegBGmemCRegProblem
{
    // block size, forwarded unchanged from the template argument
    static constexpr index_t kBlockSize = kBlockSize_;

    // tile shape of the block GEMM
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    // element types of the C, B and A operands
    using CDataType = remove_cvref_t<CDataType_>;
    using BDataType = remove_cvref_t<BDataType_>;
    using ADataType = remove_cvref_t<ADataType_>;
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/static_tile_distribution_helper.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
namespace ck {
namespace tile_program {
namespace block {
// A is block distributed tensor
// B is block window on global memory
// C is block distributed tensor
// This will:
//   1. Load B from global memory into shared memory and then
//   2. Call BlockGemmARegBSmemCRegV1
//
// Shared memory: smem_ptr must point to at least GetStaticLdsSize() bytes.
// Synchronization: this struct issues exactly one block_sync_lds(), between storing B into
// shared memory and the block GEMM that reads it. Any barrier needed before smem_ptr is
// overwritten, or after the GEMM has consumed it, is not issued here and is left to the caller.
template <typename Problem, typename Policy = BlockGemmARegBGmemCRegV1DefaultPolicy>
struct BlockGemmARegBGmemCRegV1
{
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
    using BlockGemmARegBSmemCRegImpl =
        BlockGemmARegBSmemCRegV1<BlockGemmARegBSmemCRegProblem<ADataType,
                                                               BDataType,
                                                               CDataType,
                                                               kBlockSize,
                                                               BlockGemmShape>,
                                 BlockGemmARegBSmemCRegV1DefaultPolicy>;

    // Bytes of shared memory needed to hold one B tile in the layout produced by
    // Policy::MakeBSmemBlockDescriptor.
    __host__ __device__ static constexpr ck::index_t GetStaticLdsSize()
    {
        return sizeof(BDataType) *
               Policy::template MakeBSmemBlockDescriptor<Problem>().GetElementSpaceSize();
    }

    private:
    // Copy the B tile from global memory into shared memory and return a window on that copy.
    // The returned window has lengths [BlockGemmShape::kN, BlockGemmShape::kK] and origin {0, 0}.
    // Ends with a block_sync_lds() so the stored tile is visible before it is read.
    template <typename BBlockGmemWindowTmp>
    __device__ static auto StageBTileToSmem(const BBlockGmemWindowTmp& b_block_gmem_window_tmp,
                                            void* smem_ptr)
    {
        // callers have already static_assert-ed that the incoming tensors match these lengths
        constexpr index_t NPerBlock = BlockGemmShape::kN;
        constexpr index_t KPerBlock = BlockGemmShape::kK;

        // rebuild the caller's window with the tile distribution chosen by the policy
        const auto b_block_gmem_window =
            make_tile_window(b_block_gmem_window_tmp.GetBottomTensorView(),
                             make_tuple(Number<NPerBlock>{}, Number<KPerBlock>{}),
                             b_block_gmem_window_tmp.GetWindowOrigin(),
                             Policy::template MakeBGmemTileDistribution<Problem>());

        // B LDS and LDS window
        auto b_block_smem = make_tensor_view<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(smem_ptr),
            Policy::template MakeBSmemBlockDescriptor<Problem>());

        // B is [N, K]: the window's first length must be NPerBlock (not MPerBlock), otherwise
        // the window only matches the descriptor when kM happens to equal kN
        auto b_block_smem_window = make_tile_window(
            b_block_smem, make_tuple(Number<NPerBlock>{}, Number<KPerBlock>{}), {0, 0});

        // load B tile from global mem
        const auto b_block_tile = load_tile(b_block_gmem_window);

        // store B tile into shared mem
        store_tile(b_block_smem_window, b_block_tile);

        // wait for store_tile to finish
        block_sync_lds();

        return b_block_smem_window;
    }

    public:
    // C += A * B
    // c_block_tensor:          accumulator, updated in place
    // a_block_tensor:          A tile, lengths [kM, kK]
    // b_block_gmem_window_tmp: window on B in global memory, lengths [kN, kK]
    // smem_ptr:                scratch buffer of at least GetStaticLdsSize() bytes
    template <typename CBlockTensor, typename ABlockTensor, typename BBlockGmemWindowTmp>
    __device__ void operator()(CBlockTensor& c_block_tensor,
                               const ABlockTensor& a_block_tensor,
                               const BBlockGmemWindowTmp& b_block_gmem_window_tmp,
                               void* smem_ptr) const
    {
        static_assert(
            is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
                is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>> &&
                is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
            "wrong!");

        constexpr index_t MPerBlock = ABlockTensor{}.GetLengths()[Number<0>{}];
        constexpr index_t NPerBlock = BBlockGmemWindowTmp{}.GetWindowLengths()[Number<0>{}];
        constexpr index_t KPerBlock = ABlockTensor{}.GetLengths()[Number<1>{}];

        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");

        // global mem -> shared mem (includes the barrier after the store)
        auto b_block_smem_window = StageBTileToSmem(b_block_gmem_window_tmp, smem_ptr);

        // block GEMM
        BlockGemmARegBSmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
    }

    // C = A * B
    // Same as the overload above, but returns whatever the underlying block GEMM returns
    // for its two-argument form instead of accumulating into an existing tensor.
    template <typename ABlockTensor, typename BBlockGmemWindowTmp>
    __device__ auto operator()(const ABlockTensor& a_block_tensor,
                               const BBlockGmemWindowTmp& b_block_gmem_window_tmp,
                               void* smem_ptr) const
    {
        static_assert(is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
                          is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>>,
                      "wrong!");

        constexpr index_t MPerBlock = ABlockTensor{}.GetLengths()[Number<0>{}];
        constexpr index_t NPerBlock = BBlockGmemWindowTmp{}.GetWindowLengths()[Number<0>{}];
        constexpr index_t KPerBlock = ABlockTensor{}.GetLengths()[Number<1>{}];

        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");

        // global mem -> shared mem (includes the barrier after the store)
        auto b_block_smem_window = StageBTileToSmem(b_block_gmem_window_tmp, smem_ptr);

        // block GEMM
        return BlockGemmARegBSmemCRegImpl{}(a_block_tensor, b_block_smem_window);
    }
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Default policy for BlockGemmARegBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
//
// Provides the two layout decisions BlockGemmARegBGmemCRegV1 asks its policy for:
//   - MakeBGmemTileDistribution: the tile distribution used when loading the B tile
//   - MakeBSmemBlockDescriptor:  the [kN, kK] descriptor of the B tile in shared memory
// Problem must expose BDataType, kBlockSize and BlockGemmShape::{kN, kK}.
struct BlockGemmARegBGmemCRegV1DefaultPolicy
{
    // Tile distribution for the [kN, kK] B tile.
    // K is split as K0 * K1 and N as N0 * N1 * N2, where
    //   K1 = number of BDataType elements in 16 bytes
    //   K0 = kK / K1
    //   N2 = get_warp_size() / K0
    //   N1 = kBlockSize / get_warp_size()
    //   N0 = kN / (N1 * N2)
    // All of these are integer divisions, so each one is checked to be exact; a shape that
    // does not divide evenly would otherwise yield a distribution that does not cover the tile.
    template <typename Problem>
    __host__ __device__ static constexpr auto MakeBGmemTileDistribution()
    {
        using namespace ck;
        using namespace ck::tile_program;

        using BDataType = remove_cvref_t<typename Problem::BDataType>;

        constexpr index_t kBlockSize = Problem::kBlockSize;

        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr index_t K1 = 16 / sizeof(BDataType);
        static_assert(K1 > 0 && kKPerBlock % K1 == 0,
                      "kK must be a positive multiple of 16 / sizeof(BDataType)");

        constexpr index_t K0 = kKPerBlock / K1;
        static_assert(K0 > 0 && get_warp_size() % K0 == 0,
                      "kK / K1 must evenly divide the warp size");

        constexpr index_t N2 = get_warp_size() / K0;
        static_assert(kBlockSize % get_warp_size() == 0,
                      "kBlockSize must be a multiple of the warp size");

        constexpr index_t N1 = kBlockSize / get_warp_size();
        static_assert(N1 > 0 && kNPerBlock % (N2 * N1) == 0,
                      "kN must be a multiple of N1 * N2");

        constexpr index_t N0 = kNPerBlock / (N2 * N1);

        return make_static_tile_distribution(
            StaticTileDistributionEncoding<Sequence<1>,
                                           Tuple<Sequence<N0, N1, N2>, Sequence<K0, K1>>,
                                           Tuple<Sequence<1>, Sequence<1, 2>>,
                                           Tuple<Sequence<1>, Sequence<2, 0>>,
                                           Sequence<1, 2>,
                                           Sequence<0, 1>>{});
    }

    // Exactly one of the following MakeBSmemBlockDescriptor variants is compiled in.
#if 0
    // 2d
    // (disabled) plain packed [kN, kK] descriptor
    template <typename Problem>
    __host__ __device__ static constexpr auto MakeBSmemBlockDescriptor()
    {
        using namespace ck;

        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr auto b_lds_block_desc =
            make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});

        return b_lds_block_desc;
    }
#elif 0
    // 3d + padding
    // (disabled) [kK / 8, kN, 8] layout whose outermost stride is (kN + 1) * 8, merged back
    // into a [kN, kK] view
    template <typename Problem>
    __host__ __device__ static constexpr auto MakeBSmemBlockDescriptor()
    {
        using namespace ck;

        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(Number<kKPerBlock / 8>{}, Number<kNPerBlock>{}, Number<8>{}),
            make_tuple(Number<(kNPerBlock + 1) * 8>{}, Number<8>{}, Number<1>{}),
            Number<8>{},
            Number<1>{});

        constexpr auto b_lds_block_desc = transform_tensor_descriptor(
            b_lds_block_desc_0,
            make_tuple(make_pass_through_transform(kNPerBlock),
                       make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return b_lds_block_desc;
    }
#elif 1
    // fake XOR
    // Packed [kN / 2, 2, kK] descriptor, with an XOR transform applied across the
    // (kN / 2, kK) dimensions, then (kN / 2, 2) merged back so the result is a [kN, kK] view.
    template <typename Problem>
    __host__ __device__ static constexpr auto MakeBSmemBlockDescriptor()
    {
        using namespace ck;

        using BDataType = remove_cvref_t<typename Problem::BDataType>;

        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        // kN is split as (kN / 2, 2); an odd kN would silently drop the last row
        static_assert(kNPerBlock % 2 == 0, "kN must be even for the (kN / 2, 2) split");

        constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<kNPerBlock / 2>{}, Number<2>{}, Number<kKPerBlock>{}),
            Number<kKPerBlock>{});

        // number of BDataType elements in 16 bytes
        constexpr index_t kK1 = 16 / sizeof(BDataType);

        constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
            b_lds_block_desc_d1_d2_d3,
            make_tuple(
                make_xor_transform(make_tuple(Number<kNPerBlock / 2>{}, Number<kKPerBlock>{}), kK1),
                make_pass_through_transform(2)),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

        constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
            b_lds_block_desc_d4_d5_d6,
            make_tuple(make_merge_transform(make_tuple(Number<kNPerBlock / 2>{}, Number<2>{})),
                       make_pass_through_transform(kKPerBlock)),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return b_lds_block_desc_n_k;
    }
#endif
};
} // namespace block
} // namespace tile_program
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Problem description for BlockGemmARegBSmemCReg.
// A compile-time bundle of the operand element types, the block-GEMM tile shape
// and the block size. It carries no data and no behavior; block GEMMs and their
// policies read these members.
// Every type argument is stored with cv-qualifiers and references stripped.
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          index_t kBlockSize_,
          typename BlockGemmShape_>
struct BlockGemmARegBSmemCRegProblem
{
    // block size, forwarded unchanged from the template argument
    static constexpr index_t kBlockSize = kBlockSize_;

    // tile shape of the block GEMM
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    // element types of the C, B and A operands
    using CDataType = remove_cvref_t<CDataType_>;
    using BDataType = remove_cvref_t<BDataType_>;
    using ADataType = remove_cvref_t<ADataType_>;
};
} // namespace block
} // namespace tile_program
} // namespace ck
......@@ -13,28 +13,13 @@
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Problem Description for BlockGemmARegBSmemCRegV1
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmARegBSmemCRegV1Problem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Problem description for BlockGemmASmemBSmemCReg.
// A compile-time bundle of the operand element types, the block-GEMM tile shape
// and the block size. It carries no data and no behavior; block GEMMs and their
// policies read these members.
// Every type argument is stored with cv-qualifiers and references stripped.
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          index_t kBlockSize_,
          typename BlockGemmShape_>
struct BlockGemmASmemBSmemCRegProblem
{
    // block size, forwarded unchanged from the template argument
    static constexpr index_t kBlockSize = kBlockSize_;

    // tile shape of the block GEMM
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    // element types of the C, B and A operands
    using CDataType = remove_cvref_t<CDataType_>;
    using BDataType = remove_cvref_t<BDataType_>;
    using ADataType = remove_cvref_t<ADataType_>;
};
} // namespace block
} // namespace tile_program
} // namespace ck
......@@ -14,28 +14,13 @@
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
namespace ck {
namespace tile_program {
namespace block {
// Problem Description for BlockGemmASmemBSmemCRegV1
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
struct BlockGemmASmemBSmemCRegV1Problem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
};
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
......
......@@ -20,7 +20,7 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord>
__device__ auto load_tile(TileWindowWithStaticDistribution<BottomTensorView_,
__device__ auto load_tile(const TileWindowWithStaticDistribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment