Commit 0b914465 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed wmma

parent 5db68230
......@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
#-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
......
......@@ -34,24 +34,24 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp,
CElementOp,
GemmDefault,
1, // Prefetch stage
32, // BlockSize
16, // MPerBlock
16, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
1, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
1, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 8, 1>,
2, // Prefetch stage
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWmma
16, // NPerWmma
4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 8, 1>,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
......@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true,
1, // C shuffle (M Repeat) Per store
1, // C shuffle (N Repeat) Per store
S<1, 16, 1, 2>,
S<1, 32, 1, 8>,
8>;
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -108,7 +108,7 @@ struct BlockwiseGemmWMMA
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0);
return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0);
}
else
{
......@@ -125,7 +125,7 @@ struct BlockwiseGemmWMMA
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0);
return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0);
}
else
{
......@@ -300,6 +300,9 @@ struct BlockwiseGemmWMMA
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize());
static_assert(KPack % (A_K1 * A_KRow) == 0, "");
static_assert(KPack % (B_K1 * B_KRow) == 0, "");
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
{
......@@ -309,7 +312,7 @@ struct BlockwiseGemmWMMA
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
......@@ -319,7 +322,7 @@ struct BlockwiseGemmWMMA
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
......@@ -365,7 +368,7 @@ struct BlockwiseGemmWMMA
// read B
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
......@@ -373,7 +376,7 @@ struct BlockwiseGemmWMMA
// read A
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
......@@ -416,7 +419,7 @@ struct BlockwiseGemmWMMA
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, Number<MRepeat>{}, I1, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<A_KRow * A_K1>{},
Number<KPack / A_KRow>{},
Number<A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
......@@ -425,7 +428,7 @@ struct BlockwiseGemmWMMA
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<B_KRow * B_K1>{},
Number<KPack / B_KRow>{},
Number<B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
......
......@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = (K1 == 16) ? 32 : 16;
static constexpr auto WmmaK = 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......@@ -841,10 +841,6 @@ struct GridwiseGemm_Wmma
constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
static_assert(MSubGroup == 2, "");
static_assert(NThreadPerSubGroup == 16, "");
static_assert(MAccVgprs == 8, "");
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
......
......@@ -129,12 +129,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
{
// Absolute fixing property
// * Data Pixel
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
// static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
......@@ -145,9 +145,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static constexpr index_t num_acc_vgprs_per_wave =
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
......@@ -390,10 +389,12 @@ struct WmmaSelector
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
#if 0
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
selected_wmma.acc_data_size ==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Invalid Number of Accumulator Register");
#endif
}
};
......@@ -443,8 +444,6 @@ struct WmmaGemm
const auto NWave =
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
static_assert(wmma_instr.num_acc_vgprs_per_wave == 8, "");
return transform_tensor_descriptor(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(
......@@ -553,6 +552,9 @@ struct WmmaGemm
__device__ static auto GetSubGroupId()
{
static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
wmma_instr.wave_size,
"");
return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
}
......@@ -567,13 +569,11 @@ struct WmmaGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
// return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
return GetLaneIdUnderSubGroup();
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
// return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
return GetLaneIdUnderSubGroup();
}
......
......@@ -156,7 +156,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
// if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
......@@ -10,7 +10,7 @@ cmake
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D BUILD_DEV=OFF \
-D GPU_TARGETS="gfx1200" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment