Commit d6ea89ec authored by Mirza Halilcevic
Browse files

Add descriptor and RTC workarounds for batched_gemm_multiple_d_gemm_multiple_d.

parent d20c20a6
...@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations( ...@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
prob.N, prob.N,
prob.K, prob.K,
prob.O, prob.O,
x.tile_desc.gemm0_m_per_block, x.tile_desc.gemm01_m_per_block,
x.tile_desc.gemm0_n_per_block, x.tile_desc.gemm0_n_per_block,
x.tile_desc.gemm0_k_per_block, x.tile_desc.gemm0_k_per_block,
x.tile_desc.gemm1_n_per_block, x.tile_desc.gemm1_n_per_block,
...@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
std::unordered_map<std::string, std::string> values = { std::unordered_map<std::string, std::string> values = {
{"name", {"name",
std::to_string(this->tile_desc.block_size) + "_" + std::to_string(this->tile_desc.block_size) + "_" +
std::to_string(this->tile_desc.gemm0_m_per_block) + "_" + std::to_string(this->tile_desc.gemm01_m_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_n_per_block) + "_" + std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_k_per_block) + "_" + std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_n_per_block) + "_" + std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_k_per_block) + "_" + std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
std::to_string(this->tile_desc.a0k1) + "_" + std::to_string(this->tile_desc.b0k1) + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
"_" + std::to_string(this->tile_desc.b1k1) + "_" + std::to_string(this->tile_desc.b1k1) + "_" +
std::to_string(this->tile_desc.m_per_XDL) + "_" + std::to_string(this->tile_desc.m_per_XDL) + "_" +
std::to_string(this->tile_desc.n_per_XDL) + "_" + std::to_string(this->tile_desc.n_per_XDL) + "_" +
std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" + std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
...@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
MakeTuple(Transform(this->D1s, [](auto tensor) { return ToString(tensor.layout); }))}, MakeTuple(Transform(this->D1s, [](auto tensor) { return ToString(tensor.layout); }))},
{"E1Layout", ToString(this->E1.layout)}, {"E1Layout", ToString(this->E1.layout)},
{"ADataType", ToString(this->A0.element)}, {"A0DataType", ToString(this->A0.element)},
{"B0DataType", ToString(this->B0.element)}, {"B0DataType", ToString(this->B0.element)},
{"Acc0DataType", ToString(this->acc_type)}, {"Acc0DataType", ToString(this->acc_type)},
{"D0sDataType", {"D0sDataType",
...@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const ...@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{"PadGemm1N", std::to_string(this->padding_desc.pad_gemm1_n)}, {"PadGemm1N", std::to_string(this->padding_desc.pad_gemm1_n)},
{"PadGemm1K", std::to_string(this->padding_desc.pad_gemm1_k)}, {"PadGemm1K", std::to_string(this->padding_desc.pad_gemm1_k)},
{"NumGemm0KPrefetchStage", std::to_string(this->tile_desc.num_gemm0k_prefetch_stage)}, {"NumGemm0KPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
{"BlockSize", std::to_string(this->tile_desc.block_size)}, {"BlockSize", std::to_string(this->tile_desc.block_size)},
{"Gemm0MPerBlock", std::to_string(this->tile_desc.gemm0_m_per_block)}, {"Gemm0MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)},
{"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)}, {"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
{"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)}, {"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
{"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)}, {"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
{"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)}, {"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
{"A0K1", std::to_string(this->tile_desc.a0k1)}, {"A0K1", std::to_string(this->tile_desc.ak1)},
{"B0K1", std::to_string(this->tile_desc.b0k1)}, {"B0K1", std::to_string(this->tile_desc.bk1)},
{"B1K1", std::to_string(this->tile_desc.b1k1)}, {"B1K1", std::to_string(this->tile_desc.b1k1)},
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
......
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#endif
#include "device_base.hpp" #include "device_base.hpp"
...@@ -31,6 +33,7 @@ template <typename A0Layout, ...@@ -31,6 +33,7 @@ template <typename A0Layout,
typename CDE1ElementwiseOperation> typename CDE1ElementwiseOperation>
struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
{ {
#ifndef __HIPCC_RTC__
static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size();
...@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator ...@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
CDE1ElementwiseOperation cde1_element_op) = 0; CDE1ElementwiseOperation cde1_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
}; };
} // namespace device } // namespace device
......
...@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return false; return false;
} }
if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n)) // if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
{ // {
return false; // return false;
} // }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true; return true;
...@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle ...@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
else else
{ {
static_for<0, acc0_thread_buf.Size(), 1>{}( static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); }); [&](auto i) { cde0_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
} }
// gemm1 // gemm1
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment