Commit 74f0d5de authored by aska-0096

save debugging progress

parent 5df713ef
......@@ -125,12 +125,12 @@ using DeviceGemmInstance =
S<4, 64, 1>, // B1BlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
1,
2,
8,
8,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
1, // CShuffleMWmmaPerWavePerShuffle
2, // CShuffleNWmmaPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
......
......@@ -117,6 +117,26 @@ int run(int argc, char* argv[])
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
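// Debug cases 4-7: selected operands are filled with all ones so the affected
// GEMM collapses to a plain sum over the contraction dimension; the pass/fail
// notes record which randomized operand exposes the mismatch.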
case 4: // A, B0, B1 all ones
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5: // random B1; A, B0 all ones (fails)
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // random B0; A, B1 all ones (fails)
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 7: // random A; B0, B1 all ones (passes)
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
......@@ -220,6 +240,12 @@ int run(int argc, char* argv[])
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
// for(int i =0; i< 128; i++){
// for(int j =0; j< 128; j++){
// printf("%0.2lf ", acc0_g_m_n.mData[i*128 +j]);
// }
// printf("\n");
// }
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
......
......@@ -29,12 +29,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
block_dim.y,
block_dim.z);
const int nrepeat = 10;
const int nrepeat = 1;
printf("Warm up 1 time\n");
// printf("Warm up 1 time\n");
// warm up
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
// kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
printf("Start running %d times...\n", nrepeat);
......
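For context: launch_and_time_kernel implements the usual warm-up plus event-timed repeated-launch pattern that the hunk above temporarily disables for debugging. A minimal HIP sketch of that pattern (illustrative names and a trivial kernel, not the library's actual implementation):

#include <hip/hip_runtime.h>

__global__ void dummy_kernel(float* p) { p[threadIdx.x] *= 2.0f; }

// Average per-launch time in milliseconds over nrepeat launches.
float time_kernel(dim3 grid_dim, dim3 block_dim, float* buf, int nrepeat)
{
    // warm up once so first-launch overhead is excluded from the measurement
    hipLaunchKernelGGL(dummy_kernel, grid_dim, block_dim, 0, nullptr, buf);

    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
        hipLaunchKernelGGL(dummy_kernel, grid_dim, block_dim, 0, nullptr, buf);
    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);

    float total_ms = 0.0f;
    hipEventElapsedTime(&total_ms, start, stop);
    hipEventDestroy(start);
    hipEventDestroy(stop);
    return total_ms / nrepeat; // average time per launch
}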
......@@ -16,19 +16,36 @@ template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
typename ABlockDesc,
typename BBlockDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
/* A: K0PerBlock x MPerBlock x K1
index_t KPack,
bool TransposeC = false>
/* Option: read from LDS, a big buffer holding all threads' required data
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: read from VMEM, a small buffer holding each thread's own required data (skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
struct BlockwiseGemmWMMA
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -42,18 +59,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
// Hard-code WaveSize, since the current HIP runtime (5.4.0-10984) does not return the correct value.
static constexpr index_t WaveSize = 32;
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm =
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr auto wmma_gemm = WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
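// Example (assumed tile sizes, for illustration): MPerBlock = 128, MRepeat = 2,
// MPerWMMA = 16  ->  MWaves = 128 / (2 * 16) = 4 waves tiling the M dimension;
// NWaves is derived the same way from the N-side parameters.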
......@@ -79,6 +88,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
// Default path: block buffer in LDS, thread-level offset enabled
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
......@@ -129,23 +139,63 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
return make_tuple(c_thread_m, c_thread_n);
}
// using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
// __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle(
// Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
// Tuple4 b_origin = CalculateBThreadOriginDataIndex())
// : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D();
return make_tuple(Number<m0>{},
blk_idx[I0],
waveId_m,
Number<n0>{},
waveId_n,
blk_idx[I1],
blk_idx[I2]);
}
using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmWMMA(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
Tuple5 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
// printf("tid %03d, Mat-B offset %d\n", get_thread_local_1d_id()%32, CalculateBThreadOriginDataIndex().At(Number<3>{}));
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
// constexpr auto NSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
// constexpr auto MThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{},
I1,
I1,
Number<NRepeat>{},
I1,
I1,
NAccVgprs));
}
// Thread-level register descriptor. Vector write.
......@@ -171,9 +221,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
MAccVgprs));
}
// Provide dimension size
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
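// Split M into (MBlock x MRepeat, MWaves, MPerWMMA) and N into
// (NBlock x NRepeat, NWaves, NPerWMMA) so wmma_gemm can attach its
// sub-group / accumulator-VGPR dimensions below.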
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
......@@ -184,37 +256,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
// Describes how data is laid out in the thread-copy source buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
......@@ -235,6 +301,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0),
a_thread_buf);
// static_for<0, a_thread_buf.size(), 1>{}([&](auto i) {
// a_thread_buf(i) = 1;
// });
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
......@@ -254,6 +323,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
b_thread_vec.template AsType<FloatB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
// a_thread_vec.template AsType<FloatA>()(i) = 1;
// b_thread_vec.template AsType<FloatB>()(i) = 1;
});
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
......@@ -262,6 +334,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// printf("GPU Gemm0 input, Tid %03d, A%2d = %04x, B%2d = %0x4\n",
// get_thread_local_1d_id(),
// i.value, *(reinterpret_cast<uint16_t*>(&a_thread_vec.template AsType<FloatA>()(i))),
// i.value, *(reinterpret_cast<uint16_t*>(&b_thread_vec.template AsType<FloatB>()(i))));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
......@@ -304,10 +382,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
// AThreadCopy a_thread_copy_;
// BThreadCopy b_thread_copy_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
// Block-wise pipeline designed for inline asm
......@@ -601,7 +677,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
b_thread_desc_.GetElementSpaceSize());
// TODO: fix the case where MRepeat < NRepeat
constexpr auto RepeatDiff = MRepeat - NRepeat;
// Read all Mrepeat, Nrepeat
static_for<0, NRepeat, 1>{}([&](auto iN) {
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
......
......@@ -145,11 +145,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B0Spec,
B1Spec,
CSpec>;
// K1 = maximum vector access width, in elements
static constexpr auto K1Number = Number<K1>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
......@@ -167,13 +162,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
Number<K1>{});
}
static auto
MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
b1_gs_ns_ls_strides_vec),
Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
Number<L1>{});
}
......@@ -462,8 +455,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
GridwiseOp,
ADataType,
......@@ -482,7 +473,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
typename GridwiseOp::DefaultBlock2CTileMap,
has_main_loop>;
has_main_k_block_loop>;
return launch_and_time_kernel(stream_config,
kernel,
......@@ -754,11 +745,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerBlock << ", "
<< NPerWMMA << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< L0PerBlock << ", "
<< L1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec)
<< ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
......
......@@ -190,22 +190,89 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1Value should be Number<...>
static constexpr auto AK0 = Number<K0PerBlock>{};
static constexpr auto AK1 = Number<K1Value>{};
static constexpr auto BK0 = Number<K0PerBlock>{};
static constexpr auto BK1 = Number<K1Value>{};
static constexpr auto L0 = Number<L0PerBlock>{};
static constexpr auto L1 = Number<L1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerWmma * MRepeat);
static constexpr auto Gemm0LWaves = L0PerBlock * L1Value / (LPerWmma * LRepeat);
static constexpr auto AL0 = Number<L0PerBlock / 2>{};
static constexpr auto AL1 = Number<L1Value>{};
static constexpr auto BL0 = Number<L0PerBlock>{};
static constexpr auto BL1 = Number<L1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
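// The Make*BlockDescriptor helpers below unmerge the block-level M/L/N dimension
// of a K0 x X x K1 descriptor into (Repeat, Wave, PerWmma) so the blockwise WMMA
// GEMM can address one wave-sized tile per repeat.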
template <typename A0BlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeA0BlockDescriptor_K0_M0_M1_M2_K1(const A0BlockDesc_AK0_M_AK1&)
{
constexpr index_t A_K0 = A0BlockDesc_AK0_M_AK1{}.GetLength(I0);
constexpr index_t A_K1 = A0BlockDesc_AK0_M_AK1{}.GetLength(I2);
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
return transform_tensor_descriptor(
A0BlockDesc_AK0_M_AK1{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename B0BlockDesc_BK0_L_BK1>
__host__ __device__ static constexpr auto
MakeB0BlockDescriptor_K0_L0_L1_L2_K1(const B0BlockDesc_BK0_L_BK1&)
{
constexpr index_t B_K0 = B0BlockDesc_BK0_L_BK1{}.GetLength(I0);
constexpr index_t B_K1 = B0BlockDesc_BK0_L_BK1{}.GetLength(I2);
constexpr index_t LWaves = LPerBlock / (LRepeat * LPerWmma);
return transform_tensor_descriptor(
B0BlockDesc_BK0_L_BK1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename A1BlockDesc_AL0_M_AL1>
__host__ __device__ static constexpr auto
MakeA1BlockDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
{
constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
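// A1 already lives in per-thread VGPRs, so its wave and per-WMMA dimensions
// collapse to 1 in the unmerge below; only MRepeat remains.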
return transform_tensor_descriptor(
A1BlockDesc_AL0_M_AL1{},
make_tuple(make_pass_through_transform(Number<A_L0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, I1, I1)),
make_pass_through_transform(Number<A_L1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename B1BlockDesc_BL0_N_BL1>
__host__ __device__ static constexpr auto MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
{
constexpr index_t B_K0 = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
constexpr index_t B_K1 = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
return transform_tensor_descriptor(
B1BlockDesc_BL0_N_BL1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
......@@ -226,8 +293,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(L0, Number<NPerBlock>{}, L1),
make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * L1, L1, I1));
make_tuple(BL0, Number<NPerBlock>{}, BL1),
make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * BL1, BL1, I1));
}
__host__ __device__ static constexpr auto
......@@ -374,7 +441,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
static constexpr auto b1_block_desc_bl0_n_bl1 =
GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), L1);
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
......@@ -451,7 +518,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/*******************************************************************************/
// Block level: A/B matrix thread mapping in LDS, as destination of blockwise copy
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// constexpr auto max_lds_align = K1Value;
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
......@@ -491,7 +558,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
B0ElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
Sequence<BK0, LPerBlock, BK1>,
B0BlockTransferThreadClusterLengths_K0_L_K1,
B0BlockTransferThreadClusterArrangeOrder,
FloatB0,
......@@ -520,23 +587,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);
auto blockwise_gemm0 =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
auto blockwise_gemm0 = BlockwiseGemmWMMA<
BlockSize,
FloatA,
FloatB0,
FloatAcc0,
decltype(a_block_desc_k0perblock_mperblock_k1),
decltype(b0_block_desc_k0perblock_lperblock_k1),
decltype(MakeA0BlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
MPerBlock,
LPerBlock,
K0PerBlock * K1Value,
MPerWmma,
LPerWmma,
MRepeat,
LRepeat,
KPack>{};
KPack,
true>{}; // C' = B' x A'
// Prepare registers for the A*B0 accumulator matrix
auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();
// Acc matrix threadwise copy: AccVGPR to VGPR, downcast to the WMMA input data type
constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
......@@ -550,7 +621,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lrepeat, lsubgroup)),
make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)),
make_pass_through_transform(laccvgprs)),
make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}),
......@@ -587,7 +658,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
constexpr auto t_lwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);
if(get_thread_local_1d_id()==0){
printf("t_mrepeat %d, t_mwave %d, t_mthreadpersubgroup %d, t_lrepeat %d, t_lwave %d, t_lsubgroup %d, t_laccvgprs %d \n",
t_mrepeat.value,
t_mwave.value,
t_mthreadpersubgroup.value,
t_lrepeat.value,
t_lwave.value,
t_lsubgroup.value,
t_laccvgprs.value);
}
// get acc0 thread map
constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
......@@ -628,11 +708,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/*******************************************************************************/
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
constexpr auto b1_block_slice_copy_step = make_multi_index(L0PerBlock, 0, 0);
constexpr auto b1_block_slice_copy_step = make_multi_index(BL0, 0, 0);
// A1 matrix in VGPR
constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
Number<L0PerBlock * L1Value / laccvgprs>{},
Number<AL0 * AL1 / laccvgprs>{},
Number<mrepeat * mwave * mthreadpersubgroup>{},
Number<laccvgprs>{}); // Data duplicated dimension
......@@ -665,10 +745,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
B0ElementwiseOperation,
B1ElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<L0, NPerBlock, L1>,
Sequence<BL0, NPerBlock, BL1>,
B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder,
FloatB1,
......@@ -700,22 +780,25 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
b1_block_desc_l0perblock_nperblock_l1.GetElementSpaceSize());
auto blockwise_gemm1 =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
BlockwiseGemmWMMA<BlockSize,
FloatA,
FloatB1,
FloatAcc1,
decltype(a1_thread_desc_l0perblock_mperblock_l1),
decltype(b1_block_desc_l0perblock_nperblock_l1),
decltype(MakeA1BlockDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
decltype(MakeB1BlockDescriptor_L0_N0_N1_N2_L1(b1_block_desc_l0perblock_nperblock_l1)),
MPerBlock,
NPerBlock,
BL0 * BL1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
KPack>{make_tuple(0, 0, 0, 0, 0)};
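// A1 is consumed directly from per-thread VGPRs, so its thread-copy origin is
// simply zero; B1 keeps the default LDS origin computed by the constructor.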
auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
const index_t num_gemm1_l_block_outer_loop = b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (L0PerBlock * L1Value);
constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (BL0 * BL1);
// Initialize C
StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
......@@ -811,13 +894,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
block_sync_lds();
// gemm0 end
// gemm0 incorrect
// Tiled softmax start
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
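// Tiled softmax: compute this L-tile's per-row max and sum, then fold them into
// the running max/sum (running_max_new below) carried across L-tile iterations.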
// printf("GPU Gemm 0, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);
// static_for<0, acc0_thread_buf.Size(), 1>{}([&](auto i) {
// printf("GPU Gemm0, Tid %03d, GPU acc%d = %lf\n", get_thread_local_1d_id(), i.value, acc0_thread_buf[i]);
// });
blockwise_softmax.Run(acc0_thread_buf, workspace_buf);
// printf("GPU SoftMax, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
......@@ -862,6 +949,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
block_sync_lds();
// printf("GPU permute lanex, Tid %03d, GPU 0 = %04x\n", get_thread_local_1d_id(), *(reinterpret_cast<const uint16_t*>(&a1_thread_buf[I0])));
blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
block_sync_lds();
......@@ -934,11 +1023,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// write out to C, implement shuffle
{
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
// This API provides all the dimension sizes you need
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
......@@ -973,7 +1062,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm0.CalculateCThreadOriginDataIndex(I0, I0);
const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
......
......@@ -140,6 +140,39 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t A_K0 = ABlockDesc_AK0_M_AK1{}.GetLength(I0);
constexpr index_t A_K1 = ABlockDesc_AK0_M_AK1{}.GetLength(I2);
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
return transform_tensor_descriptor(
ABlockDesc_AK0_M_AK1{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
constexpr index_t B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
return transform_tensor_descriptor(
BBlockDesc_BK0_N_BK1{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
......@@ -414,12 +447,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
BlockwiseGemmWMMA<BlockSize,
FloatA,
FloatB,
FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
decltype(MakeABlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc_k0perblock_nperblock_k1)),
MPerBlock,
NPerBlock,
K0PerBlock * K1,
MPerWmma,
NPerWmma,
MRepeat,
......
......@@ -1382,6 +1382,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// idx_md is off here: dst accesses 2 strided elements per iteration while src visits only 1
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
......@@ -1396,13 +1397,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
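// Lanes with (lane % 32 > 16) store v at dst_offset and fetch the companion
// element from the other 16-lane half via __builtin_amdgcn_permlanex16; the
// remaining lanes perform the mirror-image exchange.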
if(get_thread_local_1d_id() % 32 > 16){
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
dst_buf(Number<dst_offset + dst_buf.size()/2>{}) = __builtin_amdgcn_permlanex16(type_convert<DstData>(dst_buf(Number<dst_offset + dst_buf.size()/2>{})),
dst_buf(Number<dst_offset + DstScalarPerVector>{}) = __builtin_amdgcn_permlanex16(type_convert<DstData>(dst_buf(Number<dst_offset + DstScalarPerVector>{})),
type_convert<DstData>(v),
LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
}
else{
// apply type convert
dst_buf(Number<dst_offset + dst_buf.size()/2>{}) = type_convert<DstData>(v);
dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v);
dst_buf(Number<dst_offset>{}) = __builtin_amdgcn_permlanex16(type_convert<DstData>(dst_buf(Number<dst_offset>{})),
type_convert<DstData>(v),
LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
......
......@@ -517,12 +517,12 @@ struct WmmaGemm
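// With TransposeC the WMMA computes C' = B' * A', so the A and B lane-origin
// mappings below are swapped relative to the non-transposed case.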
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
return GetSwizzledLaneIdLow();
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
return GetLaneIdUnderSubGroup();
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
}
__device__ static CIndex GetBeginOfThreadBlk()
......
......@@ -55,6 +55,18 @@ struct GeneratorTensor_1<int8_t>
}
};
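// Constant generator: fills every element with 0.1 (a fractional counterpart of
// GeneratorTensor_1, which fills with 1).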
template <typename T>
struct GeneratorTensor_dec1
{
T value = 0.1;
template <typename... Is>
T operator()(Is...)
{
return value;
}
};
template <typename T>
struct GeneratorTensor_2
{
......