Commit 10fdada7 authored by Jing Zhang

rename e0_e1

parent 95228cd7
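
(No description accompanies this commit; the summary below is inferred from the diff.) The change renames the tensor-descriptor template parameters and variables so the names spell out the E0/E1 (and E2) dimension split: for example, AEKGridDesc / a_e_k_grid_desc become AGridDesc_E0_E1_K / a_e0_e1_k_grid_desc, the LDS descriptor a_e_k_desc becomes a_e1_k_block_desc, and b_e_n_ho_wo_block_desc becomes b_e2_n_ho_wo_block_desc. In BlockwiseGemmDlops_km_kn_m0m1n0n1_v3, KPerThread, HPerThread and WPerThread are now derived from ThreadMatrixC rather than passed as template parameters, and MoveASliceWindow(const BlockMatrixA&, step) is replaced by MoveABlockSliceWindow(step).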
@@ -12,14 +12,16 @@ template <index_t BlockSize,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
struct MatrixIndex
{
index_t k;
@@ -27,6 +29,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t w;
};
static constexpr auto KPerThread = ThreadMatrixC{}.GetLength(I0);
static constexpr auto HPerThread = ThreadMatrixC{}.GetLength(I2);
static constexpr auto WPerThread = ThreadMatrixC{}.GetLength(I3);
// HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4;
@@ -39,16 +45,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
@@ -58,11 +54,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
@@ -88,11 +79,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{});
constexpr index_t HPerBlock = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t WPerBlock = BlockMatrixB{}.GetLength(Number<3>{});
constexpr auto num_w_threads = W / WPerThread;
constexpr auto num_h_threads = H / HPerThread;
constexpr auto num_w_threads = WPerBlock / WPerThread;
constexpr auto num_h_threads = HPerBlock / HPerThread;
constexpr auto num_hw_threads = num_w_threads * num_h_threads;
index_t k_thread_id = thread_id / num_hw_threads;
@@ -115,8 +106,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
@@ -166,8 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
}
template <typename ABlockSliceMoveStepIdx>
__device__ void MoveASliceWindow(const BlockMatrixA&,
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
}
@@ -175,6 +163,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
private:
MatrixIndex c_thread_begin_mtx_idx_;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
AThreadCopy a_thread_copy_;
};
......
@@ -15,24 +15,24 @@ namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AEKGridDesc,
typename BENHoWoGridDesc,
typename CKNHoWoGridDesc,
typename CBlockIdToKNHoWoBlockClusterAdaptor,
typename AGridDesc_E0_E1_K,
typename BGridDesc_E_N_Ho_Wo,
typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v2(
const FloatAB* __restrict__ p_a_grid,
kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AEKGridDesc a_e_k_grid_desc,
const BENHoWoGridDesc b_e_n_ho_wo_grid_desc,
const CKNHoWoGridDesc c_k_n_ho_wo_grid_desc,
const CBlockIdToKNHoWoBlockClusterAdaptor c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
const AGridDesc_E0_E1_K a_e0_e1_k_grid_desc,
const BGridDesc_E_N_Ho_Wo b_e0_e1_n_ho_wo_grid_desc,
const CGridDesc_K_N_Ho_Wo c_k_n_ho_wo_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -43,8 +43,8 @@ __global__ void
p_b_grid,
p_c_grid,
p_shared_block,
a_e_k_grid_desc,
b_e_n_ho_wo_grid_desc,
a_e0_e1_k_grid_desc,
b_e0_e1_n_ho_wo_grid_desc,
c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -56,10 +56,10 @@ __global__ void
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AEKGridDesc,
typename BENHoWoGridDesc,
typename CKNHoWoGridDesc,
typename CBlockIdToKNHoWoBlockClusterAdaptor,
typename AGridDesc_E0_E1_K,
typename BGridDesc_E_N_Ho_Wo,
typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
@@ -69,19 +69,19 @@ __global__ void
kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_e_k_grid_desc,
const void CONSTANT* p_b_e_n_ho_wo_grid_desc,
const void CONSTANT* p_a_e0_e1_k_grid_desc,
const void CONSTANT* p_b_e0_e1_n_ho_wo_grid_desc,
const void CONSTANT* p_c_k_n_ho_wo_grid_desc,
const void CONSTANT* p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
{
// first cast void CONSTANT* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_e_k_grid_desc = *reinterpret_cast<const AEKGridDesc*>(
cast_pointer_to_generic_address_space(p_a_e_k_grid_desc));
const auto b_e_n_ho_wo_grid_desc = *reinterpret_cast<const BENHoWoGridDesc*>(
cast_pointer_to_generic_address_space(p_b_e_n_ho_wo_grid_desc));
const auto c_k_n_ho_wo_grid_desc = *reinterpret_cast<const CKNHoWoGridDesc*>(
const auto a_e0_e1_k_grid_desc = *reinterpret_cast<const AGridDesc_E0_E1_K*>(
cast_pointer_to_generic_address_space(p_a_e0_e1_k_grid_desc));
const auto b_e0_e1_n_ho_wo_grid_desc = *reinterpret_cast<const BGridDesc_E_N_Ho_Wo*>(
cast_pointer_to_generic_address_space(p_b_e0_e1_n_ho_wo_grid_desc));
const auto c_k_n_ho_wo_grid_desc = *reinterpret_cast<const CGridDesc_K_N_Ho_Wo*>(
cast_pointer_to_generic_address_space(p_c_k_n_ho_wo_grid_desc));
constexpr index_t shared_block_size =
@@ -93,8 +93,8 @@ __global__ void
p_b_grid,
p_c_grid,
p_shared_block,
a_e_k_grid_desc,
b_e_n_ho_wo_grid_desc,
a_e0_e1_k_grid_desc,
b_e0_e1_n_ho_wo_grid_desc,
c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -106,9 +106,9 @@ template <index_t BlockSize,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename AGlobalDesc_E0_E1_K,
typename BGlobalDesc_E0_E1_N_Ho_Wo,
typename CGlobalDesc_K_N_Ho_Wo,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
@@ -148,12 +148,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
constexpr auto a_e1_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(a_e1_k_block_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB);
}
@@ -163,9 +163,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
const AGlobalDesc& a_e_k_global_desc,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
const AGlobalDesc_E0_E1_K& a_e0_e1_k_global_desc,
const BGlobalDesc_E0_E1_N_Ho_Wo& b_e0_e1_n_ho_wo_global_desc,
const CGlobalDesc_K_N_Ho_Wo& c_k_n_ho_wo_global_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
@@ -175,18 +175,18 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
p_a_global, a_e0_e1_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
p_b_global, b_e0_e1_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
// const auto E = a_e_k_global_desc.GetLength(I0);
// const auto K = a_e_k_global_desc.GetLength(I1);
// const auto E = a_e0_e1_k_global_desc.GetLength(I0);
// const auto K = a_e0_e1_k_global_desc.GetLength(I1);
// const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// const auto N = b_e0_e1_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e0_e1_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e0_e1_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [M, N]
#if 0
@@ -220,15 +220,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
constexpr auto a_e1_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e2_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
constexpr auto b_e2_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess
@@ -240,12 +240,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAcc,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(a_e2_k_block_desc),
decltype(b_e2_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};
@@ -275,8 +272,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e_k_global_desc),
decltype(a_e_k_desc),
decltype(a_e0_e1_k_global_desc),
decltype(a_e1_k_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
@@ -286,30 +283,30 @@ struct GridwiseGemmDlops_km_kn_mn_v3
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_e_k_global_desc,
true>(a_e0_e1_k_global_desc,
make_multi_index(0, k_block_data_on_global),
a_e_k_desc,
a_e1_k_block_desc,
make_multi_index(0, 0));
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
constexpr auto b_e2_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
decltype(b_e0_e1_n_ho_wo_global_desc),
decltype(b_e2_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
1,
true>(
b_e_n_ho_wo_global_desc,
b_e0_e1_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize());
p_shared_block, a_e1_k_block_desc.GetElementSpaceSize());
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
@@ -327,34 +324,29 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
// constexpr auto a_e_k_global_move_slice_window_step_hack =
// AGlobalMoveSliceWindowStepHacks{}; constexpr auto
// b_e_n_ho_wo_global_move_slice_window_step_hack = BGlobalMoveSliceWindowStepHacks{};
constexpr auto a_e0_e1_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e0_e1_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
b_e2_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
a_blockwise_copy.RunRead(
a_e0_e1_k_global_desc, a_global_buf, a_e0_e1_k_global_step_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
b_e2_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks);
b_e0_e1_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
a_blockwise_copy.RunWrite(a_e1_k_block_desc, a_block_buf);
}
__syncthreads();
@@ -368,36 +360,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
b_e2_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks);
b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
b_e2_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks);
b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
e_block_data_begin += 2 * EPerBlock;
@@ -407,20 +399,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
b_e2_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks);
b_e0_e1_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
......
@@ -26,14 +26,6 @@ template <typename FloatA,
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
__device__ ThreadwiseGemmDlops_km_kn_mn_v3()
{
static_assert(AThreadDesc_E_K::IsKnownAtCompileTime() &&
BThreadDesc_E_N_Ho_Wo::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
......