"docs/vscode:/vscode.git/clone" did not exist on "637628a70fc708057cfd6dfe8717ca9035553bc8"
Commit 873d0958 authored by ltqin
Browse files

fix M N PerXdlops

parent 10a2ae2f
......@@ -50,7 +50,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......@@ -198,6 +198,13 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
if(0)
{
LogRangeAsType<double>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<double>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<double>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
}
......
......@@ -387,17 +387,17 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
template <>
struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 1;
static constexpr bool is_k_reduction = true;
static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 1;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
......@@ -413,7 +413,7 @@ struct MfmaSelector
static constexpr auto GetMfma();
// Specialization for double operands with a 16x16 per-XDL tile: selects the
// f64 16x16x4 MFMA instruction. The tile size in the template arguments has to
// agree with the instruction's own 16x16 block shape, so the selector is keyed
// on <double, 16, 16> (a <double, 32, 32> key would route a 32x32 request to a
// 16x16 instruction).
template <>
static constexpr auto GetMfma<double, 16, 16>()
{
    return MfmaInstr::mfma_f64_16x16x4f64;
}
......
......@@ -298,13 +298,18 @@ template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;
template <>
struct intrin_mfma_f64_16x16x4f64<32, 32>
struct intrin_mfma_f64_16x16x4f64<16, 16>
{
template <class FloatC>
__device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
{
#ifdef __gxf90a__
reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
#else
reg_c.template AsType<double4_t>()(Number<0>{}) = {reg_a, reg_a, reg_b, reg_b};
#endif
}
};
} // namespace ck
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment