Commit 857010cc authored by Jing Zhang's avatar Jing Zhang
Browse files

clean

parent 6af8e20b
...@@ -83,15 +83,11 @@ int main(int argc, char* argv[]) ...@@ -83,15 +83,11 @@ int main(int argc, char* argv[])
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
// int M = 256 + 256 * i; int M = 256 + 256 * i;
// int N = 128 + 128 * i; int N = 128 + 128 * i;
// int K = 64 + 64 * i; int K = 64 + 64 * i;
int M = 3840; gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});
int N = 1024;
int K = 4096;
gemm_shapes.push_back({M, N, K, K, N, N, nullptr, nullptr, nullptr});
} }
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
......
...@@ -223,7 +223,7 @@ struct DeviceGroupedGemmXdl ...@@ -223,7 +223,7 @@ struct DeviceGroupedGemmXdl
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
NumPrefetch>; NumPrefetch>;
struct GemmDesc struct GemmDescKernelArg
{ {
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
...@@ -231,11 +231,13 @@ struct DeviceGroupedGemmXdl ...@@ -231,11 +231,13 @@ struct DeviceGroupedGemmXdl
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
const ADataType* a_ptr; const ADataType* a_ptr;
const BDataType* b_ptr; const BDataType* b_ptr;
CDataType* c_ptr; CDataType* c_ptr;
ck::index_t BlockStart, BlockEnd; ck::index_t BlockStart, BlockEnd;
}; };
...@@ -255,9 +257,11 @@ struct DeviceGroupedGemmXdl ...@@ -255,9 +257,11 @@ struct DeviceGroupedGemmXdl
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
grid_size = 0; grid_size = 0;
group_count_ = 0;
for(index_t i = 0; i < gemm_shapes.size(); i++) for(index_t i = 0; i < gemm_shapes.size(); i++)
{ {
group_count_++;
const index_t M = gemm_shapes[i].M; const index_t M = gemm_shapes[i].M;
const index_t N = gemm_shapes[i].N; const index_t N = gemm_shapes[i].N;
const index_t K = gemm_shapes[i].K; const index_t K = gemm_shapes[i].K;
...@@ -289,7 +293,8 @@ struct DeviceGroupedGemmXdl ...@@ -289,7 +293,8 @@ struct DeviceGroupedGemmXdl
const auto block_2_ctile_map_ = const auto block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
GemmShape_.push_back(GemmDesc{a_grid_desc_k0_m_k1_, gemm_desc_kernel_arg_.push_back(
GemmDescKernelArg{a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_, b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_, c_grid_desc_m_n_,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
...@@ -306,11 +311,12 @@ struct DeviceGroupedGemmXdl ...@@ -306,11 +311,12 @@ struct DeviceGroupedGemmXdl
// private: // private:
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
index_t group_count_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
std::vector<GemmDesc> GemmShape_; std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
index_t grid_size; index_t grid_size;
}; };
...@@ -322,33 +328,40 @@ struct DeviceGroupedGemmXdl ...@@ -322,33 +328,40 @@ struct DeviceGroupedGemmXdl
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
StaticallyIndexedArray<GemmDesc, MaxGroupCount> GemmShape_arg; StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_arg_arg;
bool has_main_k0_block_loop = true; bool has_main_k0_block_loop = true;
static_for<0, MaxGroupCount, 1>{}([&](auto i) { static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < arg.GemmShape_.size()) if(i < arg.gemm_desc_kernel_arg_.size())
{ {
GemmShape_arg(i) = arg.GemmShape_[i]; gemm_desc_kernel_arg_arg(i) = arg.gemm_desc_kernel_arg_[i];
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0)
<< GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << ", "
<< GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"; << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1)
<< ", "
<< gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2)
<< "}";
std::cout << ", arg.b_grid_desc_k0_n_k1_{" std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< GemmShape_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0)
<< GemmShape_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " << ", "
<< GemmShape_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"; << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1)
<< ", "
<< gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2)
<< "}";
std::cout << ", arg.c_grid_desc_m_n_{ " std::cout << ", arg.c_grid_desc_m_n_{ "
<< GemmShape_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", " << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< GemmShape_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}" << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl; << std::endl;
if(!GridwiseGemm::CheckValidity(GemmShape_arg[i].a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(
GemmShape_arg[i].b_grid_desc_k0_n_k1_, gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_,
GemmShape_arg[i].c_grid_desc_m_n_, gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_,
gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_,
arg.M01_, arg.M01_,
arg.N01_)) arg.N01_))
{ {
...@@ -356,7 +369,7 @@ struct DeviceGroupedGemmXdl ...@@ -356,7 +369,7 @@ struct DeviceGroupedGemmXdl
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
} }
const auto K0 = GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0); const auto K0 = gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0);
if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop) if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
{ {
...@@ -373,7 +386,7 @@ struct DeviceGroupedGemmXdl ...@@ -373,7 +386,7 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm, kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<GemmDesc>, remove_reference_t<GemmDescKernelArg>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -385,8 +398,8 @@ struct DeviceGroupedGemmXdl ...@@ -385,8 +398,8 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size), dim3(arg.grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
GemmShape_arg, gemm_desc_kernel_arg_arg,
arg.GemmShape_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_); arg.c_element_op_);
...@@ -397,7 +410,7 @@ struct DeviceGroupedGemmXdl ...@@ -397,7 +410,7 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm, kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<GemmDesc>, remove_reference_t<GemmDescKernelArg>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -409,8 +422,8 @@ struct DeviceGroupedGemmXdl ...@@ -409,8 +422,8 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size), dim3(arg.grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
GemmShape_arg, gemm_desc_kernel_arg_arg,
arg.GemmShape_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_); arg.c_element_op_);
...@@ -434,17 +447,10 @@ struct DeviceGroupedGemmXdl ...@@ -434,17 +447,10 @@ struct DeviceGroupedGemmXdl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
bool isValid = true; if(arg.gemm_desc_kernel_arg_.size() != arg.group_count_)
for(int i = 0; i < arg.GemmShape_.size(); i++) return false;
{ else
return true;
isValid &= GridwiseGemm::CheckValidity(arg.GemmShape_[i].a_grid_desc_k0_m_k1_,
arg.GemmShape_[i].b_grid_desc_k0_n_k1_,
arg.GemmShape_[i].c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
}
return isValid;
} }
// polymorphic // polymorphic
......
...@@ -29,24 +29,23 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< ...@@ -29,24 +29,23 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>, DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1> DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>
DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on // clang-format on
>; >;
......
...@@ -69,6 +69,12 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -69,6 +69,12 @@ void profile_grouped_gemm_impl(int do_verification,
} }
}; };
if(!(Ms.size() == Ns.size() && Ns.size() == Ks.size() && Ks.size() == StrideAs.size() &&
StrideAs.size() == StrideBs.size() && StrideBs.size() == StrideCs.size()))
{
throw std::runtime_error("wrong! inconsistent Ms, Ns, Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k; std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n; std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_device_results; std::vector<Tensor<CDataType>> c_m_n_device_results;
...@@ -83,11 +89,9 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -83,11 +89,9 @@ void profile_grouped_gemm_impl(int do_verification,
c_m_n_device_results.push_back( c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
std::cout << "a_m_k[" << i << "]:" << a_m_k[i].mDesc << std::endl; std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
std::cout << "b_k_n[" << i << "]:" << b_k_n[i].mDesc << std::endl; << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
std::cout << "c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc
<< std::endl;
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
switch(init_method) switch(init_method)
...@@ -129,6 +133,7 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -129,6 +133,7 @@ void profile_grouped_gemm_impl(int do_verification,
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSize())); std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSize()));
b_device_buf.push_back( b_device_buf.push_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSize())); std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSize()));
c_device_buf.push_back(std::make_unique<DeviceMem>( c_device_buf.push_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSize())); sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSize()));
...@@ -254,10 +259,9 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -254,10 +259,9 @@ void profile_grouped_gemm_impl(int do_verification,
std::size_t flop = 0, num_btype = 0; std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < gemm_shapes.size(); i++) for(int i = 0; i < gemm_shapes.size(); i++)
{ {
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ms[i] + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(CDataType) * Ms[i] * Ns[i]; sizeof(CDataType) * Ms[i] * Ns[i];
} }
......
...@@ -75,11 +75,6 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -75,11 +75,6 @@ int profile_grouped_gemm(int argc, char* argv[])
const auto StrideBs = stringToArray(argv[12]); const auto StrideBs = stringToArray(argv[12]);
const auto StrideCs = stringToArray(argv[13]); const auto StrideCs = stringToArray(argv[13]);
for(int i = 0; i < Ms.size(); i++)
{
std::cout << "M: " << Ms[i] << " N: " << Ns[i] << " K: " << Ks[i] << std::endl;
}
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_impl<ck::half_t, ck::profiler::profile_grouped_gemm_impl<ck::half_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment