Commit e542dfc4 authored by Jing Zhang, committed by root

test

parent 8bfacf9f
@@ -23,7 +23,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
#if 1
#if 0
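// NOTE: the #if / #elif chain below switches between alternative kernel
// variants kept in the source for experimentation; this commit changes
// which variant is compiled.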
template <typename GridwiseGemm,
typename GemmDesc,
typename GemmSharedArgs,
@@ -114,7 +114,78 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
#elif 1
template <typename GridwiseGemm,
typename GemmDesc,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_size)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
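// The host copies one kernel-argument struct per group into the workspace;
// read them back here through a generic-address-space pointer.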
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
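// In this variant the launch grid covers a single group's tiles and every
// workgroup loops over all groups, rebuilding the tile map per group.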
for(index_t group_id = 0; group_id < group_size; group_id++)
{
const auto M = gemm_desc_ptr[group_id].karg_.M;
const auto N = gemm_desc_ptr[group_id].karg_.N;
const auto K = gemm_desc_ptr[group_id].karg_.K;
const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
const auto K0 = gemm_desc_ptr[group_id].karg_.K0;
const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
static constexpr index_t B2E_M01 = 8;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
const auto block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, 0);
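// The zero block offset above means each group's C-tile grid is mapped from
// block 0, matching the host-side change that sets block_start = 0.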
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
gemm_desc_ptr[group_id].karg_.p_b_grid,
gemm_desc_ptr[group_id].karg_.p_c_grid,
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
K0,
k_batch,
static_cast<void*>(p_shared),
block_2_ctile_map);
}
#else
ignore = gemm_descs_const;
ignore = group_size;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__))
}
#elif 0
template <typename GridwiseGemm,
typename GemmDesc,
@@ -125,7 +196,15 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
// const index_t group_count
#if 0
const index_t N,
const index_t K,
const index_t StrideB,
const index_t NPadded,
const index_t KPadded,
const index_t K0,
const index_t k_batch,
#endif
const index_t block_size)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
@@ -159,27 +238,55 @@ __global__ void
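// This variant derives the group from the flat block id, which assumes
// every group occupies the same number of blocks (block_size).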
const index_t group_id = block_id / block_size;
#endif
#if 0
const auto N = gemm_desc_ptr[0].karg_.N;
const auto K = gemm_desc_ptr[0].karg_.K;
const auto StrideB = gemm_desc_ptr[0].karg_.StrideB;
const auto NPadded = gemm_desc_ptr[0].karg_.NPadded;
const auto KPadded = gemm_desc_ptr[0].karg_.KPadded;
const auto K0 = gemm_desc_ptr[0].karg_.K0;
const auto k_batch = gemm_desc_ptr[0].karg_.k_batch;
const auto block_2_ctile_map = gemm_desc_ptr[0].block_2_ctile_map_;
#if 1
#if 1
const auto M = gemm_desc_ptr[group_id].karg_.M;
const auto N = gemm_desc_ptr[group_id].karg_.N;
const auto K = gemm_desc_ptr[group_id].karg_.K;
const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
const auto K0 = gemm_desc_ptr[group_id].karg_.K0;
const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
#else
const auto M = gemm_desc_ptr[group_id].karg_.M;
const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
#endif
// const auto block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_;
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
static constexpr index_t B2E_M01 = 8;
const index_t block_start = block_size * group_id;
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
const auto block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
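// The offsetted map shifts this group's tile indices by block_start so all
// groups can share one flat launch grid.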
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
gemm_desc_ptr[group_id].karg_.p_b_grid,
gemm_desc_ptr[group_id].karg_.p_c_grid,
gemm_desc_ptr[group_id].karg_.M,
M,
N,
K,
gemm_desc_ptr[group_id].karg_.StrideA,
StrideA,
StrideB,
gemm_desc_ptr[group_id].karg_.StrideC,
gemm_desc_ptr[group_id].karg_.MPadded,
StrideC,
MPadded,
NPadded,
KPadded,
K0,
@@ -325,16 +432,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
struct GemmTransKernelArg
{
KernelArgument karg_;
GroupedGemmBlock2ETileMap block_2_ctile_map_;
// GroupedGemmBlock2ETileMap block_2_ctile_map_;
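// The stored per-group tile map is dropped; the kernel now rebuilds it from
// karg_, presumably so the argument struct stays plain data for the
// hipMemcpy into the device workspace.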
index_t block_start_, block_end_;
GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg,
GroupedGemmBlock2ETileMap&& b2c_map,
// GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start,
index_t block_end)
: karg_{karg},
block_2_ctile_map_{b2c_map},
// block_2_ctile_map_{b2c_map},
block_start_{block_start},
block_end_{block_end}
{
@@ -404,10 +511,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
const index_t block_start = 0;
const index_t block_end = grid_size_grp;
grid_size_ += grid_size_grp;
grid_size_ = grid_size_grp;
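// grid_size_ is now overwritten rather than accumulated: with the in-kernel
// group loop, the launch only needs to cover one group's tile grid (this
// appears to assume all groups share the same grid size).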
// block-to-e-tile map
auto grouped_block_2_ctile_map =
@@ -428,8 +535,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
k0,
K_BATCH};
gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
gemm_kernel_args_.emplace_back(std::move(karg),
// std::move(grouped_block_2_ctile_map),
block_start,
block_end);
}
}
@@ -458,21 +567,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
const index_t block_start = 0;
const index_t block_end = grid_size_grp;
grid_size_ += grid_size_grp;
grid_size_ = grid_size_grp;
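// Same change as in the argument constructor: per-group block ranges
// collapse to [0, grid_size_grp).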
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KPadded = k_padded;
karg.K0 = k0;
karg.k_batch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
karg.KPadded = k_padded;
karg.K0 = k0;
karg.k_batch = K_BATCH;
// gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
}
}
@@ -556,24 +665,24 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
#if 1
std::vector<GemmTransKernelArgMsN1K1> gemm_kernel_args_msn1k1_;
index_t all_gemm_block_size =
arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;
// index_t all_gemm_block_size =
// arg.gemm_kernel_args_[0].block_end_ - arg.gemm_kernel_args_[0].block_start_;
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
auto karg = ArgumentMsN1K1{
trans_arg.karg_.p_a_grid, trans_arg.karg_.p_b_grid, trans_arg.karg_.p_c_grid};
// auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
// std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
// << " trans_arg.block_end_: " << trans_arg.block_end_
// << " block_size: " << block_size << std::endl;
auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
<< " trans_arg.block_end_: " << trans_arg.block_end_
<< " block_size: " << block_size << std::endl;
gemm_kernel_args_msn1k1_.push_back({karg});
}
#endif
#if 0
#if 1
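// Upload the per-group kernel arguments so the kernel can read them from
// the constant address space (see gemm_descs_const above).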
hip_check_error(hipMemcpy(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
@@ -653,14 +762,25 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
shared_karg
// all_gemm_block_size
#if 0
arg.gemm_kernel_args_[0].karg_.N,
arg.gemm_kernel_args_[0].karg_.K,
arg.gemm_kernel_args_[0].karg_.StrideB,
arg.gemm_kernel_args_[0].karg_.NPadded,
arg.gemm_kernel_args_[0].karg_.KPadded,
arg.gemm_kernel_args_[0].karg_.K0,
arg.gemm_kernel_args_[0].karg_.k_batch,
#elif 0
all_gemm_block_size
#elif 1
arg.gemm_kernel_args_.size()
#endif
);
};
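// The launch now passes the number of groups (gemm_kernel_args_.size());
// the active kernel variant loops over the groups internally.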
std::cout << "all_have_main_k0_block_loop: " << all_have_main_k0_block_loop
<< " all_have_kbatch_gt_one: " << all_have_kbatch_gt_one << std::endl;
#if 0
#if 1
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)