Commit 14b422d7 authored by Jing Zhang
Browse files

debugging

parent f221c68e
...@@ -21,9 +21,8 @@ using CElementOp = PassThrough; ...@@ -21,9 +21,8 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle<
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ALayout,
< ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, ADataType,
...@@ -36,23 +35,23 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -36,23 +35,23 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
CElementOp, CElementOp,
GemmDefault, GemmDefault,
1, // Prefetch stage 1, // Prefetch stage
128, // BlockSize 32, // BlockSize
64, // MPerBlock 16, // MPerBlock
128, // NPerBlock 16, // NPerBlock
64, // KPerBlock 64, // KPerBlock
8, // K1 8, // K1
16, // MPerWmma 16, // MPerWmma
16, // NPerWmma 16, // NPerWmma
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 1, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave 1, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>, S<4, 8, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
8, 8,
8, 8,
true, true,
S<4, 32, 1>, S<4, 8, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
2, 2,
...@@ -61,9 +60,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -61,9 +60,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true, true,
1, // C shuffle (M Repeat) Per store 1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store 1, // C shuffle (N Repeat) Per store
S<1, 32, 1, 4>, S<1, 16, 1, 2>,
8>; 8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
......
...@@ -73,12 +73,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -73,12 +73,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n); ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
break; break;
case 3: case 3:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k); ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n); ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break; break;
case 4: case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n); ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
break; break;
case 5: case 5:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
......
...@@ -453,6 +453,7 @@ struct BlockwiseGemmWMMA ...@@ -453,6 +453,7 @@ struct BlockwiseGemmWMMA
A_K1>; A_K1>;
}; };
#if 0
template <> template <>
struct AThreadCopySelector<false> struct AThreadCopySelector<false>
{ {
...@@ -467,6 +468,7 @@ struct BlockwiseGemmWMMA ...@@ -467,6 +468,7 @@ struct BlockwiseGemmWMMA
5, 5,
A_K1>; A_K1>;
}; };
#endif
template <bool EnableLds> template <bool EnableLds>
struct BThreadCopySelector; struct BThreadCopySelector;
...@@ -486,6 +488,7 @@ struct BlockwiseGemmWMMA ...@@ -486,6 +488,7 @@ struct BlockwiseGemmWMMA
B_K1>; B_K1>;
}; };
#if 0
template <> template <>
struct BThreadCopySelector<false> struct BThreadCopySelector<false>
{ {
...@@ -500,6 +503,7 @@ struct BlockwiseGemmWMMA ...@@ -500,6 +503,7 @@ struct BlockwiseGemmWMMA
5, 5,
B_K1>; B_K1>;
}; };
#endif
typename AThreadCopySelector<AEnableLds>::type a_thread_copy_; typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
typename BThreadCopySelector<BEnableLds>::type b_thread_copy_; typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
......
...@@ -97,8 +97,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout, ...@@ -97,8 +97,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static constexpr auto AEnableLds_manu = false; static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false; static constexpr auto BEnableLds_manu = false;
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); static constexpr auto AEnableLds =
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); true; // AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds =
true; // BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
......
...@@ -141,8 +141,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12, ...@@ -141,8 +141,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// Wave mode dependent property // Wave mode dependent property
static constexpr index_t wave_size = Number<WaveSize>{}; static constexpr index_t wave_size = Number<WaveSize>{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x // * Fixed in Navi3x, Will be wave mode dependent on Navi4x
static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4; // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave along M direction // * num_acc_vgprs_per_wave along M direction
// * num_subgroups along M direction // * num_subgroups along M direction
static constexpr index_t num_acc_vgprs_per_wave = static constexpr index_t num_acc_vgprs_per_wave =
...@@ -158,6 +158,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12, ...@@ -158,6 +158,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
} }
else if constexpr(wave_size == 64) else if constexpr(wave_size == 64)
{ {
static_assert(1, "");
} }
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment