Commit c5b5d2e4 authored by Chao Liu's avatar Chao Liu
Browse files

clean up

parent 9685fed2
#pragma once
#include "tuple.hpp"
#include "tensor_adaptor.hpp"
#include "multi_index_transform_helper.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct BatchedGemmUtil
{
template <index_t MPerBlock, index_t NPerBlock>
static constexpr auto
MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01 = 1, index_t N01 = 1)
{
constexpr auto M1 = MPerBlock;
constexpr auto N1 = NPerBlock;
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(batch_count),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto globalblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return globalblockid_to_m0_n0_block_cluster_adaptor;
}
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -369,9 +369,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......
......@@ -187,9 +187,6 @@ struct DeviceGemmXdl
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
// AGridDesc_K0_M_K1,
// BGridDesc_K0_N_K1,
// CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......
......@@ -299,14 +299,9 @@ struct DeviceGemmXdlSplitK
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
static constexpr auto MakeBlock2CTileMap(index_t batch_count,
const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01)
static constexpr auto
MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01, index_t N01)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
......@@ -363,7 +358,6 @@ struct DeviceGemmXdlSplitK
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
// index_t BatchStrideC_; // always zero
};
using GridwiseGemm =
......@@ -408,7 +402,7 @@ struct DeviceGemmXdlSplitK
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, 1, 1, 1, 1));
// Argument
struct Argument : public BaseArgument
......
......@@ -13,7 +13,6 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "gemm_specialization.hpp"
#include "batched_gemm_util.hpp"
namespace ck {
namespace tensor_operation {
......@@ -370,6 +369,39 @@ struct DeviceGemmXdlSplitKCShuffle
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
static constexpr auto
MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01, index_t N01)
{
constexpr auto M1 = MPerBlock;
constexpr auto N1 = NPerBlock;
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto M00 = M0 / M01;
const auto N00 = N0 / N01;
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(batch_count),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto globalblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return globalblockid_to_m0_n0_block_cluster_adaptor;
}
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(const index_t BatchStrideA, const index_t BatchStrideB)
......@@ -443,8 +475,7 @@ struct DeviceGemmXdlSplitKCShuffle
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap =
decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, 1, 1, 1, 1));
struct Argument : public BaseArgument
{
......@@ -540,8 +571,11 @@ struct DeviceGemmXdlSplitKCShuffle
compute_ptr_offset_of_batch_ =
ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};
block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(
BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1));
block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount_,
c_grid_desc_m_n_.GetLength(I0),
c_grid_desc_m_n_.GetLength(I1),
1,
1);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment