Commit 06af86fb authored by Andriy Roshchenko
Browse files

WIP: Enabling gfx90a build

parent 9da21f99
...@@ -27,19 +27,21 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -27,19 +27,21 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ck::type_convert;
struct ExecutionConfig final struct ExecutionConfig final
{ {
int do_verification = 1; // (0=no, 1=CPU) int do_verification = 1; // (0=no, 1=CPU)
int init_method = 2; // (0=no init, 1=integer value, 2=decimal value) int init_method = 10; // (0=no init, 1=integer value, 2=decimal value)
bool time_kernel = false; // (0=no, 1=yes) bool time_kernel = false; // (0=no, 1=yes)
int verbosity = 0; // (0=no info, 1=verbose info) int verbosity = 1; // (0=no info, 1=verbose info)
}; };
struct ProblemSize final struct ProblemSize final
{ {
ck::index_t M = 3840; ck::index_t M = 256;
ck::index_t N = 4096; ck::index_t N = 256;
ck::index_t K = 4096; ck::index_t K = 384;
ck::index_t StrideA = -1; ck::index_t StrideA = -1;
ck::index_t StrideB = -1; ck::index_t StrideB = -1;
...@@ -139,11 +141,11 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -139,11 +141,11 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
// ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB| // ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB| LDSTypeA| LDSTypeB|
// ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | | // ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | |
// ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | | // ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>; < ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, ADataType, BDataType>;
// clang-format on // clang-format on
auto M = problem_size.M; auto M = problem_size.M;
...@@ -225,19 +227,18 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -225,19 +227,18 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std::cout << "NOTE: No input data initialization." << std::endl; std::cout << "NOTE: No input data initialization." << std::endl;
} }
break; break;
case 1: case 10: // Initializations for development and debugging
case 2:
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k); ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(a_m_k_scale); ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.25f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.0f)}(b_k_n); ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(0.25f)}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(b_k_n_scale); ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(b_k_n_scale);
if(config.verbosity > 0) if(config.verbosity > 0)
{ {
std::cout << "Init A = {1}" << std::endl; std::cout << "Init A = {1}" << std::endl;
std::cout << "Init A scale = {0.5}" << std::endl; std::cout << "Init A scale = {0.25}" << std::endl;
std::cout << "Init B = {1}" << std::endl; std::cout << "Init B = {0.25}" << std::endl;
std::cout << "Init B scale = {2.0}" << std::endl; std::cout << "Init B scale = {2.0}" << std::endl;
std::cout << "Expect C = {K}" << std::endl; std::cout << "Expect C = {K*(0.25*0.5)}" << std::endl;
} }
break; break;
...@@ -343,12 +344,15 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ...@@ -343,12 +344,15 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std::cout << "Comparing results..." << std::endl; std::cout << "Comparing results..." << std::endl;
} }
if(config.init_method == 1) if(config.init_method == 10)
{ {
res_verified = auto expected = static_cast<float>(K) * (0.25f * 0.5f);
res_verified && std::abs(static_cast<float>(K) - c_m_n_device_result(0, 0)) <= 0.0f; auto computed = type_convert<float>(c_m_n_device_result(1, 12));
std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0)
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
<< std::endl;
} }
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
......
...@@ -13,7 +13,7 @@ using XDataType = ck::e8m0_bexp_t; ...@@ -13,7 +13,7 @@ using XDataType = ck::e8m0_bexp_t;
#endif #endif
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = float; using CShuffleDataType = float;
#if 1 #if 0
using CDataType = ck::half_t; using CDataType = ck::half_t;
#else #else
using CDataType = float; using CDataType = float;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment