Commit 4511f877 authored by Chao Liu's avatar Chao Liu
Browse files

refactor profiler

parent 519b6aaf
...@@ -156,6 +156,8 @@ template <typename ADataType, ...@@ -156,6 +156,8 @@ template <typename ADataType,
struct DeviceGemmXdlSplitK struct DeviceGemmXdlSplitK
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
using DeviceOp = DeviceGemmXdlSplitK;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -426,7 +428,6 @@ struct DeviceGemmXdlSplitK ...@@ -426,7 +428,6 @@ struct DeviceGemmXdlSplitK
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
BatchCount_(k_batch), BatchCount_(k_batch),
has_tail_(false),
compute_ptr_offset_of_batch_{0, 0}, compute_ptr_offset_of_batch_{0, 0},
M01_{M01}, M01_{M01},
N01_{N01}, N01_{N01},
...@@ -444,27 +445,15 @@ struct DeviceGemmXdlSplitK ...@@ -444,27 +445,15 @@ struct DeviceGemmXdlSplitK
DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1(KSplitted, N, StrideB); DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1(KSplitted, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC); c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC);
bool is_valid = GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_);
if(K != KSplitted * BatchCount_) if(K != KSplitted * BatchCount_)
{ {
has_tail_ = true;
const auto KTail = K - KSplitted * (BatchCount_ - 1); const auto KTail = K - KSplitted * (BatchCount_ - 1);
a_grid_desc_k0_m_k1_tail_ = a_grid_desc_k0_m_k1_tail_ =
DeviceGemmXdlSplitK::MakeAGridDescriptor_K0_M_K1(M, KTail, StrideA); DeviceGemmXdlSplitK::MakeAGridDescriptor_K0_M_K1(M, KTail, StrideA);
b_grid_desc_k0_n_k1_tail_ = b_grid_desc_k0_n_k1_tail_ =
DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1(KTail, N, StrideB); DeviceGemmXdlSplitK::MakeBGridDescriptor_K0_N_K1(KTail, N, StrideB);
is_valid &= GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_tail_,
b_grid_desc_k0_n_k1_tail_,
c_grid_desc_m_n_,
M01_,
N01_);
} }
if(is_valid)
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
...@@ -494,8 +483,12 @@ struct DeviceGemmXdlSplitK ...@@ -494,8 +483,12 @@ struct DeviceGemmXdlSplitK
compute_ptr_offset_of_batch_ = compute_ptr_offset_of_batch_ =
ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride}; ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};
block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount_, c_grid_desc_m_n_, M01, N01);
} block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount_,
c_grid_desc_m_n_.GetLength(I0),
c_grid_desc_m_n_.GetLength(I1),
M01,
N01);
} }
// private: // private:
...@@ -503,7 +496,6 @@ struct DeviceGemmXdlSplitK ...@@ -503,7 +496,6 @@ struct DeviceGemmXdlSplitK
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
index_t BatchCount_; index_t BatchCount_;
bool has_tail_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_tail_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_tail_;
...@@ -526,6 +518,11 @@ struct DeviceGemmXdlSplitK ...@@ -526,6 +518,11 @@ struct DeviceGemmXdlSplitK
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! not supported");
}
{ {
std::cout << "k_batch = " << arg.BatchCount_ << "\n"; std::cout << "k_batch = " << arg.BatchCount_ << "\n";
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
...@@ -539,8 +536,6 @@ struct DeviceGemmXdlSplitK ...@@ -539,8 +536,6 @@ struct DeviceGemmXdlSplitK
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
if(arg.has_tail_)
{
std::cout << "arg.a_grid_desc_k0_m_k1_tail_{" std::cout << "arg.a_grid_desc_k0_m_k1_tail_{"
<< arg.a_grid_desc_k0_m_k1_tail_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_tail_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_tail_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_tail_.GetLength(I1) << ", "
...@@ -551,35 +546,12 @@ struct DeviceGemmXdlSplitK ...@@ -551,35 +546,12 @@ struct DeviceGemmXdlSplitK
<< arg.b_grid_desc_k0_n_k1_tail_.GetLength(I1) << ", " << arg.b_grid_desc_k0_n_k1_tail_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_tail_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_k0_n_k1_tail_.GetLength(I2) << "}" << std::endl;
} }
}
bool is_valid = GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
if(arg.has_tail_)
{
is_valid &= GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
}
if(!is_valid)
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
float ave_time = 0; float ave_time = 0;
if(arg.has_tail_)
{
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
const auto K0_tail = arg.a_grid_desc_k0_m_k1_.GetLength(I0); const auto K0_tail = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
...@@ -684,81 +656,6 @@ struct DeviceGemmXdlSplitK ...@@ -684,81 +656,6 @@ struct DeviceGemmXdlSplitK
ave_time = Run(kernel); ave_time = Run(kernel);
} }
}
else
{
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
if(has_main_k0_block_loop)
{
const auto kernel = ck::kernel_gemm_xdl_splitk<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = ck::kernel_gemm_xdl_splitk<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
}
return ave_time; return ave_time;
} }
...@@ -782,6 +679,11 @@ struct DeviceGemmXdlSplitK ...@@ -782,6 +679,11 @@ struct DeviceGemmXdlSplitK
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.M01_, arg.M01_,
arg.N01_) &&
GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_tail_,
arg.b_grid_desc_k0_n_k1_tail_,
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_); arg.N01_);
} }
......
...@@ -167,15 +167,12 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -167,15 +167,12 @@ struct DeviceGemmXdlSplitKCShuffle
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
template <index_t K1> static auto GetActualBatchAndKSplitted(index_t KRaw, index_t KBatch)
static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
{ {
const index_t K0PerBlock = KPerBlock / K1; const index_t KSplitted = math::integer_divide_ceil(KRaw, KPerBlock * KBatch) * KPerBlock;
const index_t K0 = math::integer_divide_ceil(K, KPerBlock * KBatch) * K0PerBlock; const index_t actual_k_batch = math::integer_divide_ceil(KRaw, KSplitted);
const index_t KSplitted = K0 * K1;
const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
return std::make_pair(actual_batch, KSplitted); return std::make_pair(actual_k_batch, KSplitted);
} }
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
...@@ -426,7 +423,7 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -426,7 +423,7 @@ struct DeviceGemmXdlSplitKCShuffle
} }
private: private:
// TODO: should they be long_index_t? // TODO: should we use long_index_t?
index_t BatchStrideA_; index_t BatchStrideA_;
index_t BatchStrideB_; index_t BatchStrideB_;
}; };
...@@ -502,81 +499,61 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -502,81 +499,61 @@ struct DeviceGemmXdlSplitKCShuffle
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
const auto actual_batch_and_ksplitted_A = const auto actual_batch_and_ksplitted = GetActualBatchAndKSplitted(KRaw, k_batch);
GetActualBatchAndKSplitted<AK1>(KRaw, k_batch);
const auto actual_batch_and_ksplitted_B =
GetActualBatchAndKSplitted<BK1>(KRaw, k_batch);
assert(actual_batch_and_ksplitted_A.first == actual_batch_and_ksplitted_B.first); BatchCount_ = actual_batch_and_ksplitted.first;
BatchCount_ = actual_batch_and_ksplitted_A.first; const auto KSplitted = actual_batch_and_ksplitted.second;
const auto AKSplitted = actual_batch_and_ksplitted_A.second;
const auto BKSplitted = actual_batch_and_ksplitted_B.second;
a_grid_desc_ak0_m_ak1_ = a_grid_desc_ak0_m_ak1_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, AKSplitted, StrideA); DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KSplitted, StrideA);
b_grid_desc_bk0_n_bk1_ = b_grid_desc_bk0_n_bk1_ =
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BKSplitted, NRaw, StrideB); DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KSplitted, NRaw, StrideB);
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC); c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);
is_valid_ = GridwiseGemm::CheckValidity( if(KRaw != KSplitted * BatchCount_ || KRaw != KSplitted * BatchCount_)
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_);
if(KRaw != AKSplitted * BatchCount_ || KRaw != BKSplitted * BatchCount_)
{ {
has_tail_ = true; const auto KTail = KRaw - KSplitted * (BatchCount_ - 1);
const auto AKTail = KRaw - AKSplitted * (BatchCount_ - 1);
const auto BKTail = KRaw - BKSplitted * (BatchCount_ - 1);
a_grid_desc_ak0_m_ak1_tail_ = a_grid_desc_ak0_m_ak1_tail_ =
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, AKTail, StrideA); DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KTail, StrideA);
b_grid_desc_bk0_n_bk1_tail_ = b_grid_desc_bk0_n_bk1_tail_ =
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BKTail, NRaw, StrideB); DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KTail, NRaw, StrideB);
is_valid_ &= GridwiseGemm::CheckValidity(
a_grid_desc_ak0_m_ak1_tail_, b_grid_desc_bk0_n_bk1_tail_, c_grid_desc_m_n_);
} }
if(is_valid_)
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ = c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
const index_t a_batch_stride = [AKSplitted, StrideA]() { const index_t a_batch_stride = [KSplitted, StrideA]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
ignore = StrideA; ignore = StrideA;
return AKSplitted; return KSplitted;
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{ {
return AKSplitted * StrideA; return KSplitted * StrideA;
} }
}(); }();
const index_t b_batch_stride = [BKSplitted, StrideB]() { const index_t b_batch_stride = [KSplitted, StrideB]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
return BKSplitted * StrideB; return KSplitted * StrideB;
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
ignore = StrideB; ignore = StrideB;
return BKSplitted; return KSplitted;
} }
}(); }();
compute_ptr_offset_of_batch_ = compute_ptr_offset_of_batch_ =
ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride}; ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};
block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount_, block_2_ctile_map_ = MakeBlock2CTileMap(
c_grid_desc_m_n_.GetLength(I0), BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), 1, 1);
c_grid_desc_m_n_.GetLength(I1),
1,
1);
}
} }
// private: // private:
...@@ -584,8 +561,6 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -584,8 +561,6 @@ struct DeviceGemmXdlSplitKCShuffle
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
index_t BatchCount_; index_t BatchCount_;
bool has_tail_;
bool is_valid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_tail_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_tail_;
...@@ -607,6 +582,11 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -607,6 +582,11 @@ struct DeviceGemmXdlSplitKCShuffle
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! not supported");
}
{ {
std::cout << "k_batch = " << arg.BatchCount_ << "\n"; std::cout << "k_batch = " << arg.BatchCount_ << "\n";
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
...@@ -622,9 +602,6 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -622,9 +602,6 @@ struct DeviceGemmXdlSplitKCShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
if(arg.has_tail_)
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_tail_{" std::cout << "arg.a_grid_desc_ak0_m_ak1_tail_{"
<< arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0) << ", " << arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I1) << ", " << arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I1) << ", "
...@@ -635,13 +612,6 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -635,13 +612,6 @@ struct DeviceGemmXdlSplitKCShuffle
<< arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I1) << ", " << arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I2) << "}" << std::endl;
} }
}
if(!arg.is_valid_)
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const index_t grid_size = const index_t grid_size =
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
...@@ -651,39 +621,12 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -651,39 +621,12 @@ struct DeviceGemmXdlSplitKCShuffle
float ave_time = 0; float ave_time = 0;
if(arg.has_tail_)
{
const auto K0_tail = arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0); const auto K0_tail = arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0);
const bool tail_has_main_k0_block_loop = const bool tail_has_main_k0_block_loop =
GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail); GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto& kernel) {
if(nrepeat == 0) return launch_and_time_kernel(kernel,
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_tail_,
arg.b_grid_desc_bk0_n_bk1_tail_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
return 0.0f;
}
else
{
return launch_and_time_kernel(
kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
...@@ -702,7 +645,6 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -702,7 +645,6 @@ struct DeviceGemmXdlSplitKCShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_, arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
}
}; };
if(has_main_k0_block_loop && tail_has_main_k0_block_loop) if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
...@@ -781,90 +723,7 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -781,90 +723,7 @@ struct DeviceGemmXdlSplitKCShuffle
ave_time = Run(kernel); ave_time = Run(kernel);
} }
}
else
{
const auto Run = [&](const auto& kernel) {
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
return 0.0f;
}
else
{
return launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.BatchCount_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
}
};
if(has_main_k0_block_loop)
{
const auto kernel = ck::kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
true>;
ave_time = Run(kernel);
}
else
{
const auto kernel = ck::kernel_batched_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
false>;
ave_time = Run(kernel);
}
}
return ave_time; return ave_time;
} }
...@@ -884,7 +743,10 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -884,7 +743,10 @@ struct DeviceGemmXdlSplitKCShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_) &&
GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_tail_,
arg.b_grid_desc_bk0_n_bk1_tail_,
arg.c_grid_desc_m_n_);
} }
// polymorphic // polymorphic
......
...@@ -62,8 +62,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -62,8 +62,8 @@ struct ReferenceGemm : public device::BaseOperator
float v_a; float v_a;
float v_b; float v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k))); arg.a_element_op_(v_a, ck::type_convert<const float>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n))); arg.b_element_op_(v_b, ck::type_convert<const float>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
...@@ -72,7 +72,7 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -72,7 +72,7 @@ struct ReferenceGemm : public device::BaseOperator
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = v_c; arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
}; };
make_ParallelTensorFunctor( make_ParallelTensorFunctor(
......
...@@ -24,6 +24,7 @@ function(add_instance_library INSTANCE_NAME) ...@@ -24,6 +24,7 @@ function(add_instance_library INSTANCE_NAME)
endfunction(add_instance_library INSTANCE_NAME) endfunction(add_instance_library INSTANCE_NAME)
add_subdirectory(gemm) add_subdirectory(gemm)
add_subdirectory(gemm_splitk)
add_subdirectory(gemm_bias2d) add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu) add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add) add_subdirectory(gemm_bias_relu_add)
...@@ -34,7 +35,6 @@ add_subdirectory(conv2d_fwd) ...@@ -34,7 +35,6 @@ add_subdirectory(conv2d_fwd)
add_subdirectory(conv3d_fwd) add_subdirectory(conv3d_fwd)
add_subdirectory(conv2d_fwd_bias_relu) add_subdirectory(conv2d_fwd_bias_relu)
add_subdirectory(conv2d_fwd_bias_relu_add) add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory(conv2d_fwd_bias_relu_atomic_add)
add_subdirectory(conv2d_bwd_data) add_subdirectory(conv2d_bwd_data)
add_subdirectory(reduce) add_subdirectory(reduce)
add_subdirectory(convnd_bwd_data) add_subdirectory(convnd_bwd_data)
......
...@@ -12,10 +12,10 @@ set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE ...@@ -12,10 +12,10 @@ set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp; device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp; device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp; device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp; device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp; device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp; device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp; device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instance.cpp;
) )
add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE})
......
...@@ -23,7 +23,7 @@ using AccData = int32_t; ...@@ -23,7 +23,7 @@ using AccData = int32_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple< using device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instances = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
...@@ -53,11 +53,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple< ...@@ -53,11 +53,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
// clang-format on // clang-format on
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( void add_device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{}); device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace device_batched_gemm_instance
......
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using AData = int8_t;
using BData = int8_t;
using CData = int8_t;
using AccData = int32_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instances = std::tuple<
// clang-format off
//##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
// clang-format on
>;
void add_device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -23,7 +23,7 @@ using AccData = int32_t; ...@@ -23,7 +23,7 @@ using AccData = int32_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple< using device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instances = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
...@@ -53,11 +53,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple< ...@@ -53,11 +53,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple<
// clang-format on // clang-format on
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( void add_device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{}); device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace device_batched_gemm_instance
......
...@@ -23,7 +23,7 @@ using AccData = int32_t; ...@@ -23,7 +23,7 @@ using AccData = int32_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple< using device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instances = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
...@@ -45,11 +45,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple< ...@@ -45,11 +45,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple<
// clang-format on // clang-format on
>; >;
void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( void add_device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{}); device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instances{});
} }
} // namespace device_batched_gemm_instance } // namespace device_batched_gemm_instance
......
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_batched_gemm_instance {
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using AData = int8_t;
using BData = int8_t;
using CData = int8_t;
using AccData = int32_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
// clang-format on
>;
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{});
}
} // namespace device_batched_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
# device_conv2d_fwd_bias_relu_atomic_add_instance
set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp;
)
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC)
set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv2d_fwd_bias_relu_atomic_add_instance)
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_bias_activation_atomic_add_instance {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
static constexpr auto InMemoryAtomicAdd = ck::InMemoryDataOperationEnum::AtomicAdd;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances = std::tuple<
// clang-format off
//##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasActivationPtr<PassThrough, PassThrough, AddRelu>>&
instance_container)
{
using Instances =
device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances;
const auto instances = Instances{};
ck::static_for<0, std::tuple_size_v<Instances>, 1>{}([&](auto i) {
using Instance = remove_cvref_t<decltype(std::get<i>(instances))>;
auto instance = Instance{};
instance_container.push_back(std::make_unique<Instance>(instance));
});
}
} // namespace device_conv2d_fwd_bias_activation_atomic_add_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -8,10 +8,10 @@ set(DEVICE_GEMM_INSTANCE_SOURCE ...@@ -8,10 +8,10 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp;
...@@ -25,14 +25,6 @@ set(DEVICE_GEMM_INSTANCE_SOURCE ...@@ -25,14 +25,6 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
) )
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
......
...@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances = using device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances = ...@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances =
// clang-format on // clang-format on
>; >;
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances{}); device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
...@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n] // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances = using device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances = ...@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances =
// clang-format on // clang-format on
>; >;
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances{}); device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
...@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n] // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances = using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances = ...@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances =
// clang-format on // clang-format on
>; >;
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances{}); device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
...@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances = using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances =
std::tuple< std::tuple<
// clang-format off // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...@@ -45,11 +45,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances = ...@@ -45,11 +45,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances =
// clang-format on // clang-format on
>; >;
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances{}); device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
# device_gemm_instance
set(DEVICE_GEMM_SPLITK_INSTANCE_SOURCE
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
)
add_library(device_gemm_splitk_instance SHARED ${DEVICE_GEMM_SPLITK_INSTANCE_SOURCE})
target_compile_features(device_gemm_splitk_instance PUBLIC)
set_target_properties(device_gemm_splitk_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_gemm_splitk_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_splitk_instance)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment