Commit e542dfc4 authored by Jing Zhang, committed by root

test

parent 8bfacf9f
@@ -23,7 +23,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-#if 1
+#if 0
 template <typename GridwiseGemm,
           typename GemmDesc,
           typename GemmSharedArgs,
@@ -114,7 +114,78 @@ __global__ void
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
+#elif 1
+template <typename GridwiseGemm,
+          typename GemmDesc,
+          bool HasMainKBlockLoop,
+          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                   const index_t group_size)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
+    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
+    __shared__ uint8_t p_shared[shared_size];
+    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
+        cast_pointer_to_generic_address_space(gemm_descs_const));
+    for(index_t group_id = 0; group_id < group_size; group_id++)
+    {
+        const auto M       = gemm_desc_ptr[group_id].karg_.M;
+        const auto N       = gemm_desc_ptr[group_id].karg_.N;
+        const auto K       = gemm_desc_ptr[group_id].karg_.K;
+        const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
+        const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
+        const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
+        const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
+        const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
+        const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
+        const auto K0      = gemm_desc_ptr[group_id].karg_.K0;
+        const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
+        static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
+        static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
+        static constexpr index_t B2E_M01   = 8;
+        using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
+        using Block2ETileMapKSplit =
+            BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
+        using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
+        const auto c_grid_desc_m_n    = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
+        const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
+        const auto block_2_ctile_map  = GroupedGemmBlock2ETileMap(local_b2c_tile_map, 0);
+        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
+            gemm_desc_ptr[group_id].karg_.p_a_grid,
+            gemm_desc_ptr[group_id].karg_.p_b_grid,
+            gemm_desc_ptr[group_id].karg_.p_c_grid,
+            M,
+            N,
+            K,
+            StrideA,
+            StrideB,
+            StrideC,
+            MPadded,
+            NPadded,
+            KPadded,
+            K0,
+            k_batch,
+            static_cast<void*>(p_shared),
+            block_2_ctile_map);
+    }
 #else
+    ignore = gemm_descs_const;
+    ignore = group_count;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+#elif 0
 template <typename GridwiseGemm,
           typename GemmDesc,
@@ -125,7 +196,15 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
     kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                   // const index_t group_count
+#if 0
+                                   const index_t N,
+                                   const index_t K,
+                                   const index_t StrideB,
+                                   const index_t NPadded,
+                                   const index_t KPadded,
+                                   const index_t K0,
+                                   const index_t k_batch,
+#endif
                                    const index_t block_size)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
@@ -159,27 +238,55 @@ __global__ void
     const index_t group_id = block_id / block_size;
 #endif
-#if 0
-    const auto N       = gemm_desc_ptr[0].karg_.N;
-    const auto K       = gemm_desc_ptr[0].karg_.K;
-    const auto StrideB = gemm_desc_ptr[0].karg_.StrideB;
-    const auto NPadded = gemm_desc_ptr[0].karg_.NPadded;
-    const auto KPadded = gemm_desc_ptr[0].karg_.KPadded;
-    const auto K0      = gemm_desc_ptr[0].karg_.KPadded;
-    const auto k_batch = gemm_desc_ptr[0].karg_.k_batch;
-    const auto block_2_ctile_map = gemm_desc_ptr[0].block_2_ctile_map_;
+#if 1
+#if 1
+    const auto M       = gemm_desc_ptr[group_id].karg_.M;
+    const auto N       = gemm_desc_ptr[group_id].karg_.N;
+    const auto K       = gemm_desc_ptr[group_id].karg_.K;
+    const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
+    const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
+    const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
+    const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
+    const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
+    const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
+    const auto K0      = gemm_desc_ptr[group_id].karg_.K0;
+    const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
+#else
+    const auto M       = gemm_desc_ptr[group_id].karg_.M;
+    const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
+    const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
+    const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
+#endif
+    // const auto block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_;
+    static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
+    static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
+    static constexpr index_t B2E_M01   = 8;
+    const index_t block_start = block_size * group_id;
+    using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
+    using Block2ETileMapKSplit =
+        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
+    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
+    const auto c_grid_desc_m_n    = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
+    const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
+    const auto block_2_ctile_map  = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
         gemm_desc_ptr[group_id].karg_.p_a_grid,
         gemm_desc_ptr[group_id].karg_.p_b_grid,
         gemm_desc_ptr[group_id].karg_.p_c_grid,
-        gemm_desc_ptr[group_id].karg_.M,
+        M,
         N,
         K,
-        gemm_desc_ptr[group_id].karg_.StrideA,
+        StrideA,
         StrideB,
-        gemm_desc_ptr[group_id].karg_.StrideC,
-        gemm_desc_ptr[group_id].karg_.MPadded,
+        StrideC,
+        MPadded,
         NPadded,
         KPadded,
         K0,
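
In the reworked device path above, each workgroup recovers its group index from its block id and a fixed per-group block count, then rebuilds the block-to-C-tile map locally instead of reading a precomputed block_2_ctile_map_ from the descriptor. Below is a minimal, self-contained sketch of that indexing idea in plain C++; it assumes every group is assigned the same number of blocks, and the names TileRange, locate_block, and blocks_per_group are illustrative, not part of the CK API.

#include <cassert>

struct TileRange
{
    int group_id;    // which GEMM problem this block serves
    int local_block; // block index inside that problem's own tile grid
};

// Mirrors "group_id = block_id / block_size" from the hunk above, assuming
// every group occupies the same number of blocks.
inline TileRange locate_block(int block_id, int blocks_per_group)
{
    assert(blocks_per_group > 0);
    return TileRange{block_id / blocks_per_group, block_id % blocks_per_group};
}

int main()
{
    // Example: 8 blocks per group, global block 19 -> group 2, local block 3.
    const TileRange r = locate_block(19, 8);
    assert(r.group_id == 2 && r.local_block == 3);
    return 0;
}
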
@@ -325,16 +432,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     struct GemmTransKernelArg
     {
         KernelArgument karg_;
-        GroupedGemmBlock2ETileMap block_2_ctile_map_;
+        // GroupedGemmBlock2ETileMap block_2_ctile_map_;
         index_t block_start_, block_end_;
         GemmTransKernelArg() = default;
         GemmTransKernelArg(KernelArgument&& karg,
-                           GroupedGemmBlock2ETileMap&& b2c_map,
+                           // GroupedGemmBlock2ETileMap&& b2c_map,
                            index_t block_start,
                            index_t block_end)
             : karg_{karg},
-              block_2_ctile_map_{b2c_map},
+              // block_2_ctile_map_{b2c_map},
               block_start_{block_start},
               block_end_{block_end}
         {
@@ -404,10 +511,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
             const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
-            const index_t block_start = grid_size_;
-            const index_t block_end = grid_size_ + grid_size_grp;
-            grid_size_ += grid_size_grp;
+            const index_t block_start = 0;
+            const index_t block_end = grid_size_grp;
+            grid_size_ = grid_size_grp;
             // block-to-e-tile map
             auto grouped_block_2_ctile_map =
@@ -428,8 +535,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 k0,
                 K_BATCH};
-            gemm_kernel_args_.emplace_back(
-                std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
+            gemm_kernel_args_.emplace_back(std::move(karg),
+                                           // std::move(grouped_block_2_ctile_map),
+                                           block_start,
+                                           block_end);
         }
     }
@@ -458,21 +567,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
             const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
-            const index_t block_start = grid_size_;
-            const index_t block_end = grid_size_ + grid_size_grp;
-            grid_size_ += grid_size_grp;
+            const index_t block_start = 0;
+            const index_t block_end = grid_size_grp;
+            grid_size_ = grid_size_grp;
             // block-to-e-tile map
             auto grouped_block_2_ctile_map =
                 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
             karg.KPadded = k_padded;
             karg.K0 = k0;
             karg.k_batch = K_BATCH;
-            gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
+            // gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
             gemm_kernel_args_[i].block_start_ = block_start;
             gemm_kernel_args_[i].block_end_ = block_end;
         }
     }
@@ -556,24 +665,24 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
 #if 1
             std::vector<GemmTransKernelArgMsN1K1> gemm_kernel_args_msn1k1_;
-            index_t all_gemm_block_size =
-                arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;
+            // index_t all_gemm_block_size =
+            //     arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;
             for(const auto& trans_arg : arg.gemm_kernel_args_)
             {
                 auto karg = ArgumentMsN1K1{
                     trans_arg.karg_.p_a_grid, trans_arg.karg_.p_b_grid, trans_arg.karg_.p_c_grid};
-                // auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
-                // std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
-                //           << " trans_arg.block_end_: " << trans_arg.block_end_
-                //           << " block_size: " << block_size << std::endl;
+                auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
+                std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
+                          << " trans_arg.block_end_: " << trans_arg.block_end_
+                          << " block_size: " << block_size << std::endl;
                 gemm_kernel_args_msn1k1_.push_back({karg});
             }
 #endif
-#if 0
+#if 1
             hip_check_error(hipMemcpy(arg.p_workspace_,
                                       arg.gemm_kernel_args_.data(),
                                       arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
@@ -653,14 +762,25 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                     dim3(BlockSize),
                     0,
                     cast_pointer_to_constant_address_space(arg.p_workspace_),
-                    shared_karg
-                    // all_gemm_block_size
+#if 0
+                    arg.gemm_kernel_args_[0].karg_.N,
+                    arg.gemm_kernel_args_[0].karg_.K,
+                    arg.gemm_kernel_args_[0].karg_.StrideB,
+                    arg.gemm_kernel_args_[0].karg_.NPadded,
+                    arg.gemm_kernel_args_[0].karg_.KPadded,
+                    arg.gemm_kernel_args_[0].karg_.K0,
+                    arg.gemm_kernel_args_[0].karg_.k_batch,
+#elif 0
+                    all_gemm_block_size
+#elif 1
+                    arg.gemm_kernel_args_.size()
+#endif
                     );
             };
             std::cout << "all_have_main_k0_block_loop: " << all_have_main_k0_block_loop
                       << " all_have_kbatch_gt_one: " << all_have_kbatch_gt_one << std::endl;
-#if 0
+#if 1
             if(all_have_main_k0_block_loop)
             {
                 if(all_have_kbatch_gt_one)
...
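
On the host side, the commit comments the GroupedGemmBlock2ETileMap member out of GemmTransKernelArg, so the per-group payload copied into arg.p_workspace_ with hipMemcpy is just the plain kernel argument plus a block range. Below is a hedged, self-contained sketch of that layout in plain C++; KernelArgument here is a stand-in aggregate whose field list only mirrors what the diff reads from karg_, and it is not the actual CK type.

#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>

using index_t = std::int32_t;

struct KernelArgument // stand-in for the GridwiseGemm argument held in karg_
{
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;
    index_t M, N, K, StrideA, StrideB, StrideC;
    index_t MPadded, NPadded, KPadded, K0, k_batch;
};

struct GemmTransKernelArg // matches the diff: no block_2_ctile_map_ member any more
{
    KernelArgument karg_;
    index_t block_start_, block_end_;
};

// A trivially copyable layout is what makes the byte-wise copy into the
// device workspace (the hipMemcpy in the diff) well defined.
static_assert(std::is_trivially_copyable<GemmTransKernelArg>::value,
              "required for the raw copy into the workspace buffer");

int main()
{
    // One entry per group, mirrored into a byte buffer the way the device
    // workspace would receive it.
    std::vector<GemmTransKernelArg> args(4);
    std::vector<std::uint8_t> workspace(args.size() * sizeof(GemmTransKernelArg));
    std::memcpy(workspace.data(), args.data(), workspace.size());
    return 0;
}
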