Commit b27909a0 authored by Jing Zhang

clean

parent e542dfc4
@@ -23,169 +23,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
#if 0
template <typename GridwiseGemm,
typename GemmDesc,
typename GemmSharedArgs,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const GemmSharedArgs gemm_shared_args)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = block_id / gemm_shared_args.block_size;
#if 1
// const auto M = gemm_shared_args.M;
// const auto N = gemm_shared_args.N;
// const auto K = gemm_shared_args.K;
// const auto StrideA = gemm_shared_args.StrideA;
// const auto StrideB = gemm_shared_args.StrideB;
// const auto StrideC = gemm_shared_args.StrideC;
// const auto MPadded = gemm_shared_args.MPadded;
// const auto NPadded = gemm_shared_args.NPadded;
// const auto KPadded = gemm_shared_args.KPadded;
// const auto K0 = gemm_shared_args.K0;
// const auto k_batch = gemm_shared_args.k_batch;
const auto M = 2;
const auto N = 768;
const auto K = 4608;
const auto StrideA = 4608;
const auto StrideB = 4608;
const auto StrideC = 768;
const auto MPadded = 32;
const auto NPadded = 768;
const auto KPadded = 4608;
const auto K0 = 576;
const auto k_batch = 1;
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
static constexpr index_t B2E_M01 = 8;
const index_t block_start = gemm_shared_args.block_size * group_id;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
auto grouped_block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
const auto block_2_ctile_map = grouped_block_2_ctile_map;
#endif
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
gemm_desc_ptr[group_id].karg_.p_b_grid,
gemm_desc_ptr[group_id].karg_.p_c_grid,
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
K0,
k_batch,
static_cast<void*>(p_shared),
block_2_ctile_map);
#else
ignore = gemm_descs_const;
ignore = gemm_shared_args;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
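// Note on the disabled (#if 0) variant above: it assumes every group occupies the
// same number of workgroups, so the owning group is found with one division,
//   group_id = block_id / gemm_shared_args.block_size,
// and the per-group tile map is shifted by block_start = block_size * group_id.
// The hard-coded M/N/K/stride values are debug constants for a single problem shape
// (the commented-out reads of gemm_shared_args show the intended source), so treat
// this as a sketch of the idea rather than a working path.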
#elif 1
template <typename GridwiseGemm,
typename GemmDesc,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_size)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
for(index_t group_id = 0; group_id < group_size; group_id++)
{
const auto M = gemm_desc_ptr[group_id].karg_.M;
const auto N = gemm_desc_ptr[group_id].karg_.N;
const auto K = gemm_desc_ptr[group_id].karg_.K;
const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
const auto K0 = gemm_desc_ptr[group_id].karg_.K0;
const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
static constexpr index_t B2E_M01 = 8;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
const auto block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, 0);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
gemm_desc_ptr[group_id].karg_.p_b_grid,
gemm_desc_ptr[group_id].karg_.p_c_grid,
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
K0,
k_batch,
static_cast<void*>(p_shared),
block_2_ctile_map);
}
#else
ignore = gemm_descs_const;
ignore = group_size;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
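// Note on the variant above: every workgroup iterates over all `group_size` GEMMs,
// rebuilding the C-grid descriptor and a tile map with block offset 0 for each one,
// so the launch grid presumably only has to cover the largest single group, with
// out-of-range workgroups expected to skip work inside GridwiseGemm::Run (an
// assumption about GridwiseGemm, not something verified here).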
#elif 0
template <typename GridwiseGemm,
typename GemmDesc,
@@ -196,16 +33,7 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
#if 0
const index_t N,
const index_t K,
const index_t StrideB,
const index_t NPadded,
const index_t KPadded,
const index_t K0,
const index_t k_batch,
#endif
-const index_t block_size)
+const index_t group_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
@@ -216,7 +44,6 @@ __global__ void
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
#if 0
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
@@ -234,13 +61,7 @@ __global__ void
}
group_id = index_t((left + right) / 2);
}
#else
const index_t group_id = block_id / block_size;
#endif
#if 1
#if 1
const auto M = gemm_desc_ptr[group_id].karg_.M;
const auto N = gemm_desc_ptr[group_id].karg_.N;
const auto K = gemm_desc_ptr[group_id].karg_.K;
@@ -252,21 +73,11 @@ __global__ void
const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
const auto K0 = gemm_desc_ptr[group_id].karg_.K0;
const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
#else
const auto M = gemm_desc_ptr[group_id].karg_.M;
const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
#endif
// const auto block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_;
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
static constexpr index_t B2E_M01 = 8;
const index_t block_start = block_size * group_id;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
@@ -274,7 +85,7 @@ __global__ void
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
-const auto block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
+const auto block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, gemm_desc_ptr[group_id].block_start_);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
@@ -293,20 +104,11 @@ __global__ void
k_batch,
static_cast<void*>(p_shared),
block_2_ctile_map);
#else
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_,
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_);
#endif
#else
ignore = gemm_descs_const;
ignore = group_count;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
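// After the clean-up, the surviving kernel locates the owning group by binary search
// over the per-group [block_start_, block_end_) ranges carried in each GemmTransKernelArg,
// rather than dividing block_id by a uniform per-group block count. The middle of the
// loop, not shown in this hunk, presumably narrows [left, right) roughly as follows
// (sketch only):
//   while(!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
//           block_id <  gemm_desc_ptr[group_id].block_end_))
//   {
//       if(block_id < gemm_desc_ptr[group_id].block_start_) { right = group_id; }
//       else                                                { left  = group_id; }
//       group_id = index_t((left + right) / 2);
//   }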
#endif
template <typename ALayout,
typename BLayout,
@@ -432,16 +234,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
struct GemmTransKernelArg
{
KernelArgument karg_;
// GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_;
GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg,
// GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start,
index_t block_end)
: karg_{karg},
// block_2_ctile_map_{b2c_map},
block_start_{block_start},
block_end_{block_end}
{
@@ -511,10 +310,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
-const index_t block_start = 0;
-const index_t block_end = grid_size_grp;
+const index_t block_start = grid_size_;
+const index_t block_end = grid_size_ + grid_size_grp;
-grid_size_ = grid_size_grp;
+grid_size_ += grid_size_grp;
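// The cumulative bookkeeping partitions one flat launch grid across all groups:
// each group owns the half-open workgroup range [block_start, block_end) and
// grid_size_ accumulates the running total. Illustrative numbers only: three groups
// with per-group grids of 4, 6 and 2 workgroups get [0,4), [4,10), [10,12) and
// grid_size_ ends up as 12. OffsettedBlockToCTileMap is then expected to rebase the
// global workgroup index by block_start so each group sees local tile indices.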
// block-to-e-tile map
auto grouped_block_2_ctile_map =
@@ -567,10 +366,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
-const index_t block_start = 0;
-const index_t block_end = grid_size_grp;
+const index_t block_start = grid_size_;
+const index_t block_end = grid_size_ + grid_size_grp;
-grid_size_ = grid_size_grp;
+grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
@@ -645,103 +444,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
}
struct ArgumentMsN1K1
{
const ADataType* p_a_grid;
const BDataType* p_b_grid;
EDataType* p_c_grid;
// index_t M;
// index_t StrideA;
// index_t StrideC;
// index_t MPadded;
// GroupedGemmBlock2ETileMap block_2_ctile_map;
};
struct GemmTransKernelArgMsN1K1
{
ArgumentMsN1K1 karg_;
};
#if 1
std::vector<GemmTransKernelArgMsN1K1> gemm_kernel_args_msn1k1_;
// index_t all_gemm_block_size =
// arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
auto karg = ArgumentMsN1K1{
trans_arg.karg_.p_a_grid, trans_arg.karg_.p_b_grid, trans_arg.karg_.p_c_grid};
auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
<< " trans_arg.block_end_: " << trans_arg.block_end_
<< " block_size: " << block_size << std::endl;
gemm_kernel_args_msn1k1_.push_back({karg});
}
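// The experimental block above repackages only the A/B/C pointers of each group into
// the slimmer ArgumentMsN1K1 struct and logs every group's [block_start_, block_end_)
// range; the shapes were apparently meant to come from a single shared-argument
// struct (see GemmSharedArgs below) rather than from per-group descriptors.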
#endif
#if 1
hip_check_error(hipMemcpy(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice));
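// The per-group kernel arguments are staged host-to-device into the caller-provided
// workspace, and the kernel reads them back through the constant-address-space
// pointer passed at launch. This implies p_workspace_ must hold at least
// gemm_kernel_args_.size() * sizeof(GemmTransKernelArg) bytes, which is presumably
// what the device op reports as its required workspace size.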
#else
struct GemmSharedArgs
{
index_t block_size;
// index_t M;
// index_t N;
// index_t K;
// index_t StrideA;
// index_t StrideB;
// index_t StrideC;
// index_t MPadded;
// index_t NPadded;
// index_t KPadded;
// index_t K0;
// index_t k_batch;
// GroupedGemmBlock2ETileMap block_2_ctile_map;
#if 0
void print()
{
std::cout << "block_size = " << block_size << " M = " << M << " N = " << N
<< " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC
<< " MPadded = " << MPadded << " NPadded = " << NPadded
<< " KPadded = " << KPadded << " K0 = " << K0
<< " k_batch = " << k_batch << std::endl;
}
#endif
};
auto shared_karg = GemmSharedArgs{
all_gemm_block_size,
// arg.gemm_kernel_args_[0].karg_.M,
// arg.gemm_kernel_args_[0].karg_.N,
// arg.gemm_kernel_args_[0].karg_.K,
// arg.gemm_kernel_args_[0].karg_.StrideA,
// arg.gemm_kernel_args_[0].karg_.StrideB,
// arg.gemm_kernel_args_[0].karg_.StrideC,
// arg.gemm_kernel_args_[0].karg_.MPadded,
// arg.gemm_kernel_args_[0].karg_.NPadded,
// arg.gemm_kernel_args_[0].karg_.KPadded,
// arg.gemm_kernel_args_[0].karg_.K0,
// arg.gemm_kernel_args_[0].karg_.k_batch,
// arg.gemm_kernel_args_[0].block_2_ctile_map_,
};
// shared_karg.print();
hip_check_error(
hipMemcpy(arg.p_workspace_,
gemm_kernel_args_msn1k1_.data(),
gemm_kernel_args_msn1k1_.size() * sizeof(GemmTransKernelArgMsN1K1),
hipMemcpyHostToDevice));
#endif
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
@@ -762,25 +468,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
#if 0
arg.gemm_kernel_args_[0].karg_.N,
arg.gemm_kernel_args_[0].karg_.K,
arg.gemm_kernel_args_[0].karg_.StrideB,
arg.gemm_kernel_args_[0].karg_.NPadded,
arg.gemm_kernel_args_[0].karg_.KPadded,
arg.gemm_kernel_args_[0].karg_.K0,
arg.gemm_kernel_args_[0].karg_.k_batch,
#elif 0
all_gemm_block_size
#elif 1
arg.gemm_kernel_args_.size()
#endif
);
};
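// With the #if/#elif selections above resolved, the launch forwards just two kernel
// arguments: the constant-address-space pointer to the staged GemmTransKernelArg
// array and the group count, arg.gemm_kernel_args_.size(). The grid dimension is
// presumably grid_size_, i.e. the sum of the per-group grids accumulated earlier.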
std::cout << "all_have_main_k0_block_loop: " << all_have_main_k0_block_loop
<< " all_have_kbatch_gt_one: " << all_have_kbatch_gt_one << std::endl;
#if 1
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)
@@ -827,58 +518,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Run(kernel);
}
}
#else
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
true,
InMemoryDataOperationEnum::AtomicAdd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
true,
InMemoryDataOperationEnum::Set>;
Run(kernel);
}
}
else
{
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
false,
InMemoryDataOperationEnum::AtomicAdd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
false,
InMemoryDataOperationEnum::Set>;
Run(kernel);
}
}
#endif
return ave_time;
}