Commit d6ea89ec authored by Mirza Halilcevic
Browse files

Add descriptor and RTC workarounds for batched_gemm_multiple_d_gemm_multiple_d.

parent d20c20a6
......@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
prob.N,
prob.K,
prob.O,
x.tile_desc.gemm0_m_per_block,
x.tile_desc.gemm01_m_per_block,
x.tile_desc.gemm0_n_per_block,
x.tile_desc.gemm0_k_per_block,
x.tile_desc.gemm1_n_per_block,
......@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
std::unordered_map<std::string, std::string> values = {
{"name",
std::to_string(this->tile_desc.block_size) + "_" +
std::to_string(this->tile_desc.gemm0_m_per_block) + "_" +
std::to_string(this->tile_desc.gemm01_m_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
std::to_string(this->tile_desc.a0k1) + "_" + std::to_string(this->tile_desc.b0k1) +
"_" + std::to_string(this->tile_desc.b1k1) + "_" +
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
std::to_string(this->tile_desc.b1k1) + "_" +
std::to_string(this->tile_desc.m_per_XDL) + "_" +
std::to_string(this->tile_desc.n_per_XDL) + "_" +
std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
......@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
MakeTuple(Transform(this->D1s, [](auto tensor) { return ToString(tensor.layout); }))},
{"E1Layout", ToString(this->E1.layout)},
{"ADataType", ToString(this->A0.element)},
{"A0DataType", ToString(this->A0.element)},
{"B0DataType", ToString(this->B0.element)},
{"Acc0DataType", ToString(this->acc_type)},
{"D0sDataType",
......@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{"PadGemm1N", std::to_string(this->padding_desc.pad_gemm1_n)},
{"PadGemm1K", std::to_string(this->padding_desc.pad_gemm1_k)},
{"NumGemm0KPrefetchStage", std::to_string(this->tile_desc.num_gemm0k_prefetch_stage)},
{"NumGemm0KPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
{"BlockSize", std::to_string(this->tile_desc.block_size)},
{"Gemm0MPerBlock", std::to_string(this->tile_desc.gemm0_m_per_block)},
{"Gemm0MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)},
{"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
{"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
{"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
{"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
{"A0K1", std::to_string(this->tile_desc.a0k1)},
{"B0K1", std::to_string(this->tile_desc.b0k1)},
{"A0K1", std::to_string(this->tile_desc.ak1)},
{"B0K1", std::to_string(this->tile_desc.bk1)},
{"B1K1", std::to_string(this->tile_desc.b1k1)},
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
......
......@@ -3,8 +3,10 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <vector>
#endif
#include "device_base.hpp"
......@@ -31,6 +33,7 @@ template <typename A0Layout,
typename CDE1ElementwiseOperation>
struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
{
#ifndef __HIPCC_RTC__
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
......@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
CDE1ElementwiseOperation cde1_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
} // namespace device
......
......@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return false;
}
if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
{
return false;
}
// if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
// {
// return false;
// }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
......@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
else
{
static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); });
[&](auto i) { cde0_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
}
// gemm1
{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment