Unverified Commit 60914583 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Add examples of batched/grouped/SplitK Gemm for int8/bfp16/fp16/fp32 (#361)



* add examples into grouped/batched_gemm

* adding splitK examples

* fixed splitK

* add bfp16 int8 example into splitK

* formatting

* use static_cast

* added common for batched_gemm

* add commons for examples of splitK/batched/grouped_gemm

* return true

* adjust splitK check tol

* update example
Co-authored-by: default avatarChao Liu <lc.roy86@gmail.com>
parent 2327f1a6
...@@ -95,7 +95,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -95,7 +95,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto a_grid_desc_m_kpad = transform_tensor_descriptor( const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k, a_grid_desc_m_k,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)), make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
......
...@@ -53,7 +53,7 @@ __global__ void ...@@ -53,7 +53,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, static_cast<void*>(p_shared_block),
a_b_k0_m_k1_grid_desc, a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc, b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -270,7 +270,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -270,7 +270,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, void* __restrict__ p_shared_block,
const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
...@@ -463,8 +463,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -463,8 +463,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block; FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
FloatAB* p_b_block = p_shared_block + a_block_space_size; FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
...@@ -547,11 +547,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -547,11 +547,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
static_cast<FloatC*>(p_shared_block), static_cast<FloatC*>(p_shared_block),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
static_assert(M1 == MWave, "");
static_assert(N1 == NWave, "");
static_assert(M2 * M3 * M4 == MPerXDL, "");
static_assert(N2 == NPerXDL, "");
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mperblock_nblock_nperblock, c_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple( make_tuple(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment