Commit 9f65d608 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

[CK TILE] Gemm pk_int4_t permute B

parent 0328b06e
...@@ -35,6 +35,71 @@ ...@@ -35,6 +35,71 @@
#error "unsupported CK_TILE_PIPELINE_DEFAULT value" #error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif #endif
// Compile-time tuning knobs for the basic GEMM example.
//
// The block tile / warp partition / warp tile shape is selected by the
// pipeline chosen at build time through CK_TILE_PIPELINE_DEFAULT. The
// branches below form a single mutually exclusive chain; an unsupported
// value is rejected with #error (mirroring the guard earlier in this file).
struct GemmBasicConfig
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
    // Memory friendly for Interwave scheduler
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 32;
    static constexpr ck_tile::index_t K_Tile = 64;

    static constexpr ck_tile::index_t M_Warp = 4;
    static constexpr ck_tile::index_t N_Warp = 1;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 8;

    static constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
    // Compute friendly for Intrawave scheduler
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 64;

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;

    static constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
    // Compute friendly for Intrawave scheduler
    // Using the ping pong reader in the lds level
    static constexpr ck_tile::index_t M_Tile = 256;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 32;

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 16;

    static constexpr bool DoubleSmemBuffer = true;
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif

    // Padding flags for M/N/K dimensions (disabled: shapes are tile-aligned).
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    // Host-side pre-permutation of the A/B operands for the device layout.
    static constexpr bool PermuteA = false;
    static constexpr bool PermuteB = false;

    static constexpr bool TransposeC = false;

    // Occupancy hint: thread blocks resident per compute unit.
    static constexpr int kBlockPerCu = 1;

    // Spatially-local tile partitioner parameters.
    // NOTE: "Paritioner" spelling is kept as-is — these names are referenced
    // by the launcher code elsewhere in this change.
    static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    static constexpr ck_tile::index_t TileParitionerM01      = 4;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType> template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmBasicTypeConfig; struct GemmBasicTypeConfig;
......
...@@ -29,8 +29,32 @@ auto calculate_rtol_atol(const ck_tile::index_t K, ...@@ -29,8 +29,32 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
// Use higher threshold // Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
} }
template <typename Tensor> template <typename Tensor>
void permute_tensor_b(Tensor& tensor) void permute_tensor_b(Tensor& tensor)
{
const ck_tile::index_t K = tensor.get_length(0);
const ck_tile::index_t N = tensor.get_length(1);
const ck_tile::index_t K1 = GemmBasicConfig::K_Tile;
const ck_tile::index_t K0 = K / GemmBasicConfig::K_Tile;
Tensor tensor_copy = tensor;
// int K0, N, K1
for(int j = 0; j < K0; j++)
{
for(int i = 0; i < N; i++)
{
for(int jj = 0; jj < K1; jj++)
{
tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj));
}
}
}
}
template <typename Tensor>
void permute_vectors_i4x4_b(Tensor& tensor)
{ {
const ck_tile::index_t K = tensor.get_length(0); const ck_tile::index_t K = tensor.get_length(0);
const ck_tile::index_t N = tensor.get_length(1); const ck_tile::index_t N = tensor.get_length(1);
...@@ -183,8 +207,8 @@ int run_gemm_example_with_layouts(int argc, ...@@ -183,8 +207,8 @@ int run_gemm_example_with_layouts(int argc,
if(init_method == 0) if(init_method == 0)
{ {
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k); ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n); ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
} }
else if(init_method == 1) else if(init_method == 1)
{ {
...@@ -206,18 +230,29 @@ int run_gemm_example_with_layouts(int argc, ...@@ -206,18 +230,29 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data()); static_assert(!GemmBasicConfig::PermuteA, "Not implemented");
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>) if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{ {
// Permute data for device implementation // Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n; ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
permute_tensor_b(b_k_n_dev); if constexpr(GemmBasicConfig::PermuteB)
{
permute_tensor_b(b_k_n_dev);
}
permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
} }
else else
{ {
if constexpr(GemmBasicConfig::PermuteB)
{
std::cout << "Permute for this DataType is not implemented." << std::endl;
return false;
}
b_k_n_dev_buf.ToDevice(b_k_n.data()); b_k_n_dev_buf.ToDevice(b_k_n.data());
} }
a_m_k_dev_buf.ToDevice(a_m_k.data());
c_m_n_dev_buf.SetZero(); c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero(); c_m_n_dev_result.SetZero();
......
...@@ -21,90 +21,42 @@ template <typename ADataType, ...@@ -21,90 +21,42 @@ template <typename ADataType,
typename CLayout> typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) using GemmShape = ck_tile::TileGemmShape<
// Memory friendly for Interwave scheduler ck_tile::
constexpr ck_tile::index_t M_Tile = 128; sequence<GemmBasicConfig::M_Tile, GemmBasicConfig::N_Tile, GemmBasicConfig::K_Tile>,
constexpr ck_tile::index_t N_Tile = 32; ck_tile::
constexpr ck_tile::index_t K_Tile = 64; sequence<GemmBasicConfig::M_Warp, GemmBasicConfig::N_Warp, GemmBasicConfig::K_Warp>,
ck_tile::sequence<GemmBasicConfig::M_Warp_Tile,
constexpr ck_tile::index_t M_Warp = 4; GemmBasicConfig::N_Warp_Tile,
constexpr ck_tile::index_t N_Warp = 1; GemmBasicConfig::K_Warp_Tile>,
constexpr ck_tile::index_t K_Warp = 1; GemmBasicConfig::PermuteA,
GemmBasicConfig::PermuteB>;
constexpr ck_tile::index_t M_Warp_Tile = 32; using TilePartitioner =
constexpr ck_tile::index_t N_Warp_Tile = 32; ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
constexpr ck_tile::index_t K_Warp_Tile = 8; GemmBasicConfig::TileParitionerGroupNum,
GemmBasicConfig::TileParitionerM01>;
constexpr bool DoubleSmemBuffer = false;
#endif using Traits = ck_tile::TileGemmTraits<GemmBasicConfig::kPadM,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) GemmBasicConfig::kPadN,
// Compute friendly for Intrawave scheduler GemmBasicConfig::kPadK,
constexpr ck_tile::index_t M_Tile = 256; ALayout,
constexpr ck_tile::index_t N_Tile = 256; BLayout,
constexpr ck_tile::index_t K_Tile = 64; CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmBasicConfig::kPadM,
constexpr ck_tile::index_t M_Warp = 2; GemmBasicConfig::kPadN,
constexpr ck_tile::index_t N_Warp = 2; GemmBasicConfig::kPadK,
constexpr ck_tile::index_t K_Warp = 1; GemmBasicConfig::DoubleSmemBuffer,
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
// ===============================================
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout, ALayout,
BLayout, BLayout,
CLayout, CLayout,
TransposeC>; GemmBasicConfig::TransposeC>;
using GemmPipelineProblem = using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>; using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * GemmBasicConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmBasicConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
...@@ -133,11 +85,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -133,11 +85,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
GemmPipelineProblem::kBlockSize, GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock, TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock, TilePartitioner::NPerBlock,
M_Warp, GemmBasicConfig::M_Warp,
N_Warp, GemmBasicConfig::N_Warp,
M_Warp_Tile, GemmBasicConfig::M_Warp_Tile,
N_Warp_Tile, GemmBasicConfig::N_Warp_Tile,
K_Warp_Tile, GemmBasicConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC>>; UniversalGemmProblem::TransposeC>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
...@@ -158,8 +110,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -158,8 +110,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
<< std::endl; << std::endl;
} }
ave_time = ck_tile::launch_kernel( ave_time =
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs)); ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmBasicConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
return ave_time; return ave_time;
}; };
......
...@@ -279,6 +279,7 @@ struct GemmKernel ...@@ -279,6 +279,7 @@ struct GemmKernel
const GemmKernelArgs& kargs, const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset) const SplitKBatchOffset& splitk_batch_offset)
{ {
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() { const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{ {
...@@ -303,21 +304,63 @@ struct GemmKernel ...@@ -303,21 +304,63 @@ struct GemmKernel
const auto& b_tensor_view = [&]() { const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{ {
return make_naive_tensor_view<address_space_enum::global>( if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
b_ptr, {
make_tuple(splitk_batch_offset.splitted_k, kargs.N), const index_t K1 = TilePartitioner::BlockGemmShape::kK;
make_tuple(kargs.stride_B, 1), const index_t K0 =
number<GemmPipeline::GetVectorSizeB()>{}, splitk_batch_offset.splitted_k / TilePartitioner::BlockGemmShape::kK;
number<1>{}); const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
} }
else else
{ {
return make_naive_tensor_view<address_space_enum::global>( if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
b_ptr, {
make_tuple(kargs.N, splitk_batch_offset.splitted_k), const index_t K1 = TilePartitioner::BlockGemmShape::kK;
make_tuple(kargs.stride_B, 1), const index_t K0 =
number<GemmPipeline::GetVectorSizeB()>{}, splitk_batch_offset.splitted_k / TilePartitioner::BlockGemmShape::kK;
number<1>{}); const auto b_k0_n_k1_desc =
make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
make_tuple(kargs.N * K1, K1, I1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
const auto b_n_k_desc = transform_tensor_descriptor(
b_k0_n_k1_desc,
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(kargs.N)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
} }
}(); }();
......
...@@ -8,7 +8,11 @@ ...@@ -8,7 +8,11 @@
namespace ck_tile { namespace ck_tile {
template <typename BlockTile_, typename BlockWarps_, typename WarpTile_> template <typename BlockTile_,
typename BlockWarps_,
typename WarpTile_,
bool PermuteA_ = false,
bool PermuteB_ = false>
struct TileGemmShape struct TileGemmShape
{ {
using BlockTile = remove_cvref_t<BlockTile_>; using BlockTile = remove_cvref_t<BlockTile_>;
...@@ -21,6 +25,9 @@ struct TileGemmShape ...@@ -21,6 +25,9 @@ struct TileGemmShape
static constexpr index_t kN = BlockTile::at(number<1>{}); static constexpr index_t kN = BlockTile::at(number<1>{});
static constexpr index_t kK = BlockTile::at(number<2>{}); static constexpr index_t kK = BlockTile::at(number<2>{});
static constexpr bool PermuteA = PermuteA_;
static constexpr bool PermuteB = PermuteB_;
CK_TILE_HOST static std::string GetName() CK_TILE_HOST static std::string GetName()
{ {
// clang-format off // clang-format off
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment