Commit b27909a0 authored by Jing Zhang

clean

Remove the experimental #if 0 / #elif kernel variants, the commented-out fields, and the debug prints from the grouped GEMM split-K device op. The remaining kernel_grouped_gemm_xdl_splitk takes the group count, resolves each workgroup's group id by binary search over the per-group block ranges, and builds its tile map from gemm_desc_ptr[group_id].block_start_; on the host side, grid_size_ now accumulates across groups so every group gets its own contiguous block range.

parent e542dfc4
@@ -23,98 +23,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
#if 0
template <typename GridwiseGemm,
typename GemmDesc,
typename GemmSharedArgs,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const GemmSharedArgs gemm_shared_args)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = block_id / gemm_shared_args.block_size;
#if 1
// const auto M = gemm_shared_args.M;
// const auto N = gemm_shared_args.N;
// const auto K = gemm_shared_args.K;
// const auto StrideA = gemm_shared_args.StrideA;
// const auto StrideB = gemm_shared_args.StrideB;
// const auto StrideC = gemm_shared_args.StrideC;
// const auto MPadded = gemm_shared_args.MPadded;
// const auto NPadded = gemm_shared_args.NPadded;
// const auto KPadded = gemm_shared_args.KPadded;
// const auto K0 = gemm_shared_args.KPadded;
// const auto k_batch = gemm_shared_args.k_batch;
const auto M = 2;
const auto N = 768;
const auto K = 4608;
const auto StrideA = 4608;
const auto StrideB = 4608;
const auto StrideC = 768;
const auto MPadded = 32;
const auto NPadded = 768;
const auto KPadded = 4608;
const auto K0 = 576;
const auto k_batch = 1;
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
static constexpr index_t B2E_M01 = 8;
const index_t block_start = gemm_shared_args.block_size * group_id;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
auto grouped_block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
const auto block_2_ctile_map = grouped_block_2_ctile_map;
#endif
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
gemm_desc_ptr[group_id].karg_.p_b_grid,
gemm_desc_ptr[group_id].karg_.p_c_grid,
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
K0,
k_batch,
static_cast<void*>(p_shared),
block_2_ctile_map);
#else
ignore = gemm_descs_const;
ignore = all_gemm_block_size;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
#elif 1
template <typename GridwiseGemm,
typename GemmDesc,
bool HasMainKBlockLoop,
@@ -124,88 +33,7 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
- const index_t group_size)
+ const index_t group_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
for(index_t group_id = 0; group_id < group_size; group_id++)
{
const auto M = gemm_desc_ptr[group_id].karg_.M;
const auto N = gemm_desc_ptr[group_id].karg_.N;
const auto K = gemm_desc_ptr[group_id].karg_.K;
const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
const auto K0 = gemm_desc_ptr[group_id].karg_.K0;
const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
static constexpr index_t B2E_M01 = 8;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
const auto block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, 0);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
gemm_desc_ptr[group_id].karg_.p_b_grid,
gemm_desc_ptr[group_id].karg_.p_c_grid,
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
K0,
k_batch,
static_cast<void*>(p_shared),
block_2_ctile_map);
}
#else
ignore = gemm_descs_const;
ignore = group_count;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
#elif 0
template <typename GridwiseGemm,
typename GemmDesc,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
#if 0
const index_t N,
const index_t K,
const index_t StrideB,
const index_t NPadded,
const index_t KPadded,
const index_t K0,
const index_t k_batch,
#endif
const index_t block_size)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
@@ -216,7 +44,6 @@ __global__ void
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
#if 0
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
@@ -234,13 +61,7 @@ __global__ void
}
group_id = index_t((left + right) / 2);
}
#else
const index_t group_id = block_id / block_size;
#endif
#if 1
#if 1
const auto M = gemm_desc_ptr[group_id].karg_.M;
const auto N = gemm_desc_ptr[group_id].karg_.N;
const auto K = gemm_desc_ptr[group_id].karg_.K;
@@ -252,21 +73,11 @@ __global__ void
const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
const auto K0 = gemm_desc_ptr[group_id].karg_.K0;
const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
#else
const auto M = gemm_desc_ptr[group_id].karg_.M;
const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
#endif
// const auto block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_;
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
static constexpr index_t B2E_M01 = 8;
const index_t block_start = block_size * group_id;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
@@ -274,7 +85,7 @@ __global__ void
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
- const auto block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
+ const auto block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, gemm_desc_ptr[group_id].block_start_);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
@@ -293,20 +104,11 @@ __global__ void
k_batch,
static_cast<void*>(p_shared),
block_2_ctile_map);
#else
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_,
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_);
#endif
#else
ignore = gemm_descs_const;
ignore = group_count;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
#endif
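
The surviving kernel above resolves which GEMM group a workgroup serves by binary-searching the per-group block ranges instead of dividing block_id by a shared block_size, so groups with different grid sizes are handled correctly. Below is a minimal host-side sketch of that lookup, assuming each descriptor carries a half-open block range [block_start_, block_end_) and the ranges are sorted and contiguous; DescRange and FindGroupId are illustrative names, not part of the CK API.

// Minimal sketch of the block-id -> group-id lookup (illustrative, not the CK code).
#include <cstdint>
#include <iostream>
#include <vector>

struct DescRange
{
    int32_t block_start_; // first block id owned by the group
    int32_t block_end_;   // one past the last block id owned by the group
};

// Binary search over sorted, contiguous ranges: returns the index of the
// descriptor whose range contains block_id.
int32_t FindGroupId(const std::vector<DescRange>& descs, const int32_t block_id)
{
    int32_t left     = 0;
    int32_t right    = static_cast<int32_t>(descs.size());
    int32_t group_id = (left + right) / 2;

    while(!(block_id >= descs[group_id].block_start_ && block_id < descs[group_id].block_end_))
    {
        if(block_id < descs[group_id].block_start_)
        {
            right = group_id;
        }
        else
        {
            left = group_id + 1;
        }
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    // Three groups covering blocks [0, 4), [4, 10) and [10, 16).
    const std::vector<DescRange> descs{{0, 4}, {4, 10}, {10, 16}};
    for(const int32_t block_id : {0, 5, 12})
    {
        std::cout << "block " << block_id << " -> group " << FindGroupId(descs, block_id) << '\n';
    }
    return 0;
}

On the device the same loop runs once per workgroup, with block_id = get_block_1d_id() as the query and group_count as the search bound.
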
template <typename ALayout,
typename BLayout,
@@ -432,16 +234,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
struct GemmTransKernelArg
{
KernelArgument karg_;
// GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_;
GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg,
// GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start,
index_t block_end)
: karg_{karg},
// block_2_ctile_map_{b2c_map},
block_start_{block_start},
block_end_{block_end}
{
@@ -511,10 +310,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
- const index_t block_start = 0;
- const index_t block_end = grid_size_grp;
- grid_size_ = grid_size_grp;
+ const index_t block_start = grid_size_;
+ const index_t block_end = grid_size_ + grid_size_grp;
+ grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
@@ -567,10 +366,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
- const index_t block_start = 0;
- const index_t block_end = grid_size_grp;
- grid_size_ = grid_size_grp;
+ const index_t block_start = grid_size_;
+ const index_t block_end = grid_size_ + grid_size_grp;
+ grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
...@@ -645,103 +444,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -645,103 +444,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
}
struct ArgumentMsN1K1
{
const ADataType* p_a_grid;
const BDataType* p_b_grid;
EDataType* p_c_grid;
// index_t M;
// index_t StrideA;
// index_t StrideC;
// index_t MPadded;
// GroupedGemmBlock2ETileMap block_2_ctile_map;
};
struct GemmTransKernelArgMsN1K1
{
ArgumentMsN1K1 karg_;
};
#if 1
std::vector<GemmTransKernelArgMsN1K1> gemm_kernel_args_msn1k1_;
// index_t all_gemm_block_size =
// arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
auto karg = ArgumentMsN1K1{
trans_arg.karg_.p_a_grid, trans_arg.karg_.p_b_grid, trans_arg.karg_.p_c_grid};
auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
<< " trans_arg.block_end_: " << trans_arg.block_end_
<< " block_size: " << block_size << std::endl;
gemm_kernel_args_msn1k1_.push_back({karg});
}
#endif
#if 1
hip_check_error(hipMemcpy(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice));
#else
struct GemmSharedArgs
{
index_t block_size;
// index_t M;
// index_t N;
// index_t K;
// index_t StrideA;
// index_t StrideB;
// index_t StrideC;
// index_t MPadded;
// index_t NPadded;
// index_t KPadded;
// index_t K0;
// index_t k_batch;
// GroupedGemmBlock2ETileMap block_2_ctile_map;
#if 0
void print()
{
std::cout << "block_size = " << block_size << " M = " << M << " N = " << N
<< " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC
<< " MPadded = " << MPadded << " NPadded = " << NPadded
<< " KPadded = " << KPadded << " K0 = " << K0
<< " k_batch = " << k_batch << std::endl;
}
#endif
};
auto shared_karg = GemmSharedArgs{
all_gemm_block_size,
// arg.gemm_kernel_args_[0].karg_.M,
// arg.gemm_kernel_args_[0].karg_.N,
// arg.gemm_kernel_args_[0].karg_.K,
// arg.gemm_kernel_args_[0].karg_.StrideA,
// arg.gemm_kernel_args_[0].karg_.StrideB,
// arg.gemm_kernel_args_[0].karg_.StrideC,
// arg.gemm_kernel_args_[0].karg_.MPadded,
// arg.gemm_kernel_args_[0].karg_.NPadded,
// arg.gemm_kernel_args_[0].karg_.KPadded,
// arg.gemm_kernel_args_[0].karg_.K0,
// arg.gemm_kernel_args_[0].karg_.k_batch,
// arg.gemm_kernel_args_[0].block_2_ctile_map_,
};
// shared_karg.print();
hip_check_error(
hipMemcpy(arg.p_workspace_,
gemm_kernel_args_msn1k1_.data(),
gemm_kernel_args_msn1k1_.size() * sizeof(GemmTransKernelArgMsN1K1),
hipMemcpyHostToDevice));
#endif
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
@@ -762,25 +468,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
#if 0
arg.gemm_kernel_args_[0].karg_.N,
arg.gemm_kernel_args_[0].karg_.K,
arg.gemm_kernel_args_[0].karg_.StrideB,
arg.gemm_kernel_args_[0].karg_.NPadded,
arg.gemm_kernel_args_[0].karg_.KPadded,
arg.gemm_kernel_args_[0].karg_.K0,
arg.gemm_kernel_args_[0].karg_.k_batch,
#elif 0
all_gemm_block_size
#elif 1
arg.gemm_kernel_args_.size()
#endif
);
};
std::cout << "all_have_main_k0_block_loop: " << all_have_main_k0_block_loop
<< " all_have_kbatch_gt_one: " << all_have_kbatch_gt_one << std::endl;
#if 1
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)
@@ -827,58 +518,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Run(kernel);
}
}
#else
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
true,
InMemoryDataOperationEnum::AtomicAdd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
true,
InMemoryDataOperationEnum::Set>;
Run(kernel);
}
}
else
{
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
false,
InMemoryDataOperationEnum::AtomicAdd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArgMsN1K1,
GemmSharedArgs,
false,
InMemoryDataOperationEnum::Set>;
Run(kernel);
}
}
#endif
return ave_time;
}
...
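
The block_start / block_end / grid_size_ changes in the hunks above switch the host from sizing the grid for a single group to accumulating one grid over all groups, so each group owns a contiguous block range whose block_start_ later feeds the offsetted block-to-C-tile map in the kernel. Below is a minimal sketch of that accumulation with made-up per-group grid sizes; GroupBlockRange is an illustrative type, not the CK API.

// Minimal sketch of assigning contiguous block ranges to GEMM groups
// (illustrative, not the CK implementation).
#include <cstdint>
#include <iostream>
#include <vector>

struct GroupBlockRange
{
    int32_t block_start_; // first block id of the group
    int32_t block_end_;   // one past the last block id of the group
};

int main()
{
    // Per-group grid sizes as CalculateGridSize() would report them (made-up numbers).
    const std::vector<int32_t> grid_size_grp{4, 6, 6};

    std::vector<GroupBlockRange> ranges;
    int32_t grid_size = 0; // running total; becomes the size of the single launch grid

    for(const int32_t g : grid_size_grp)
    {
        // Mirrors: block_start = grid_size_; block_end = grid_size_ + grid_size_grp;
        //          grid_size_ += grid_size_grp;
        ranges.push_back({grid_size, grid_size + g});
        grid_size += g;
    }

    int32_t group = 0;
    for(const GroupBlockRange& r : ranges)
    {
        std::cout << "group " << group++ << ": blocks [" << r.block_start_ << ", "
                  << r.block_end_ << ")\n";
    }
    std::cout << "total grid size: " << grid_size << '\n';
    return 0;
}

A single kernel launch then covers the whole accumulated grid, and every workgroup recovers its group id with the binary search sketched after the kernel.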