Commit 14b422d7 authored by Jing Zhang's avatar Jing Zhang
Browse files

debugging

parent f221c68e
......@@ -21,49 +21,47 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
< ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
128, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>,
8>;
// clang-format on
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
32, // BlockSize
16, // MPerBlock
16, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
1, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
1, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 8, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 8, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 16, 1, 2>,
8>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
......@@ -73,12 +73,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break;
case 3:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
break;
case 5:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
......
......@@ -453,6 +453,7 @@ struct BlockwiseGemmWMMA
A_K1>;
};
#if 0
template <>
struct AThreadCopySelector<false>
{
......@@ -467,6 +468,7 @@ struct BlockwiseGemmWMMA
5,
A_K1>;
};
#endif
template <bool EnableLds>
struct BThreadCopySelector;
......@@ -486,6 +488,7 @@ struct BlockwiseGemmWMMA
B_K1>;
};
#if 0
template <>
struct BThreadCopySelector<false>
{
......@@ -500,6 +503,7 @@ struct BlockwiseGemmWMMA
5,
B_K1>;
};
#endif
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
......
......@@ -97,8 +97,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static constexpr auto AEnableLds =
true; // AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds =
true; // BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......
......@@ -141,8 +141,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static constexpr index_t num_acc_vgprs_per_wave =
......@@ -158,6 +158,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
}
else if constexpr(wave_size == 64)
{
static_assert(1, "");
}
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment