Commit 3c7e8da2 authored by j4yan
Browse files

rename variables and functions in gridwise_gemm_dlops_v1r3

parent 2faeaece
...@@ -228,13 +228,13 @@ struct DeviceGemmDlops ...@@ -228,13 +228,13 @@ struct DeviceGemmDlops
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
using AGridDesc_K0_M0_M1_K1 = using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAK0M0M1K1GridDescriptor(AGridDesc_K0_M_K1{})); decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 = using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBK0N0N1K1GridDescriptor(BGridDesc_K0_N_K1{})); decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 = using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(CGridDesc_M_N{})); decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(CGridDesc_M_N{})); decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -261,10 +261,10 @@ struct DeviceGemmDlops ...@@ -261,10 +261,10 @@ struct DeviceGemmDlops
c_grid_desc_m0_m10_m11_n0_n10_n11_{}, c_grid_desc_m0_m10_m11_n0_n10_n11_{},
block_2_ctile_map_{}, block_2_ctile_map_{},
M01_{M01}, M01_{M01},
N01_{N01} N01_{N01},
// a_element_op_{a_element_op}, a_element_op_{a_element_op},
// b_element_op_{b_element_op}, b_element_op_{b_element_op},
// c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmDlops::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ = DeviceGemmDlops::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmDlops::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); b_grid_desc_k0_n_k1_ = DeviceGemmDlops::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
...@@ -274,14 +274,14 @@ struct DeviceGemmDlops ...@@ -274,14 +274,14 @@ struct DeviceGemmDlops
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_)) a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
{ {
a_grid_desc_k0_m0_m1_k1_ = a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_grid_desc_k0_m_k1_); GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ = b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_grid_desc_k0_n_k1_); GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ = c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_grid_desc_m_n_); GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_);
} }
} }
...@@ -300,12 +300,14 @@ struct DeviceGemmDlops ...@@ -300,12 +300,14 @@ struct DeviceGemmDlops
DefaultBlock2CTileMap block_2_ctile_map_; DefaultBlock2CTileMap block_2_ctile_map_;
// TODO: unused, but may be useful in future.
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
// AElementwiseOperation a_element_op_; // TODO: unused since gridwise_gemm_dlops_v1r3 does NOT support prologue for the time being.
// BElementwiseOperation b_element_op_; AElementwiseOperation a_element_op_;
// CElementwiseOperation c_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
}; };
// Invoker // Invoker
...@@ -317,14 +319,14 @@ struct DeviceGemmDlops ...@@ -317,14 +319,14 @@ struct DeviceGemmDlops
{ {
{ {
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
<< arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I0) << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I1) << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment