"...composable_kernel_rocm.git" did not exist on "bb4b82a95a27248c713ed93cd4b47384663eefca"
Commit 74f0d5de authored by aska-0096's avatar aska-0096
Browse files

save debugging progress

parent 5df713ef
@@ -39,8 +39,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ADataType        = F16;
 using B0DataType       = F16;
 using B1DataType       = F16;
 using Acc0DataType     = F32;
 using Acc1DataType     = F32;
 using CShuffleDataType = F32;
 using CDataType        = F16;
 using Acc0BiasDataType = ck::Tuple<>;
@@ -125,12 +125,12 @@ using DeviceGemmInstance =
     S<4, 64, 1>, // B1BlockTransfer
     S<1, 0, 2>,
     S<1, 0, 2>,
-    1,
+    2,
     8,
     8,
     false,
-    1,              // CShuffleMXdlPerWavePerShuffle
-    2,              // CShuffleNXdlPerWavePerShuffle
+    1,              // CShuffleMWmmaPerWavePerShuffle
+    2,              // CShuffleNWmmaPerWavePerShuffle
     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
     4,              // CShuffleBlockTransferScalarPerVector_NPerBlock
     MaskingSpec>;   // MaskingSpecialization
......
@@ -117,6 +117,26 @@ int run(int argc, char* argv[])
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
         break;
+    case 4: // A, B0, B1 all 1
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        break;
+    case 5: // Rand: b1 ; unit: a b0    fail
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
+        break;
+    case 6: // Rand: b0 ; unit: a b1    fail
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        break;
+    case 7: // Rand: a ; unit: b0 b1    pass
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        break;
     default:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
@@ -220,6 +240,12 @@ int run(int argc, char* argv[])
         a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
     ref_gemm0_invoker.Run(ref_gemm0_argument);
+    // for(int i = 0; i < 128; i++){
+    //     for(int j = 0; j < 128; j++){
+    //         printf("%0.2lf ", acc0_g_m_n.mData[i*128 + j]);
+    //     }
+    //     printf("\n");
+    // }
     // masking
     const auto mask = DeviceGemmInstance::C0MatrixMask(N);
......
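Note: the new init cases isolate which operand feeds bad data into the fused kernel by making all but one input analytically trivial (the author's notes: random B1 or B0 fails, random A alone passes). With all-ones inputs the reference output is known in closed form, so any mismatch points at the data-movement path rather than the math. A minimal host-side sketch of that idea (standalone, not part of the commit; dimensions are illustrative):

```cpp
#include <cassert>
#include <vector>

// Naive reference GEMM: C[m][n] = sum_k A[m][k] * B[k][n].
static std::vector<float> gemm(const std::vector<float>& a,
                               const std::vector<float>& b,
                               int M, int N, int K)
{
    std::vector<float> c(M * N, 0.f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += a[m * K + k] * b[k * N + n];
    return c;
}

int main()
{
    const int M = 4, N = 4, K = 8;
    std::vector<float> a(M * K, 1.f); // GeneratorTensor_1-style fill
    std::vector<float> b(K * N, 1.f);

    // With A == B == 1, every output element must equal K exactly,
    // so any deviation in the device result is a layout/copy bug.
    for(float v : gemm(a, b, M, N, K))
        assert(v == float(K));
    return 0;
}
```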
@@ -29,12 +29,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
               block_dim.y,
               block_dim.z);

-    const int nrepeat = 10;
-    printf("Warm up 1 time\n");
+    const int nrepeat = 1;
+    // printf("Warm up 1 time\n");

     // warm up
-    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+    // kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

     printf("Start running %d times...\n", nrepeat);
......
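Note: with nrepeat forced to 1 and the warm-up launch commented out, every timed run is a cold run; that keeps one-shot printf debugging output readable, but the reported time then includes cold-start effects. A minimal sketch of the event-based timing pattern this helper implements (standalone; the empty kernel is a placeholder):

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void noop_kernel() {}

int main()
{
    hipEvent_t start, stop;
    (void)hipEventCreate(&start);
    (void)hipEventCreate(&stop);

    const int nrepeat = 10; // 1 while printf-debugging, larger for stable timings

    noop_kernel<<<dim3(1), dim3(64)>>>(); // warm up (skipped in this commit)

    (void)hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
        noop_kernel<<<dim3(1), dim3(64)>>>();
    (void)hipEventRecord(stop, nullptr);
    (void)hipEventSynchronize(stop);

    float ms = 0.f;
    (void)hipEventElapsedTime(&ms, start, stop);
    std::printf("avg %f ms over %d runs\n", ms / nrepeat, nrepeat);
    return 0;
}
```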
@@ -16,19 +16,36 @@ template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
           typename FloatAcc,
-          typename AK0MK1BlockDesc,
-          typename BK0NK1BlockDesc,
+          typename ABlockDesc,
+          typename BBlockDesc,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t KPerBlock,
           index_t MPerWMMA,
           index_t NPerWMMA,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack>
-/* A: K0PerBlock x MPerBlock x K1
+          index_t KPack,
+          bool TransposeC = false>
+/* Option: Read from LDS, big buffer hold all threads required data
+ * Source
+ * A: K0PerBlock x MPerBlock x K1
  * B: K0PerBlock x NPerBlock x K1
- * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
+ * Destination
+ * C, non-transpose
+ *   thread level: MRepeat x NRepeat x MAccVgprs
+ *   block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
  * KPACK == WMMA_K = 16
+ *
+ * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
+ * Source:
+ * A(if skip LDS): MRepeat x KPack
+ * B(if skip LDS): NRepeat x KPack
+ * Destination
+ * C, non-transpose
+ *   block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
  */
-struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
+struct BlockwiseGemmWMMA
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -42,18 +59,10 @@ struct BlockwiseGemmWMMA
     // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
     static constexpr index_t WaveSize = 32;

-    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t KPerBlock =
-        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
-    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
-    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
-    static constexpr auto wmma_gemm =
-        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
+    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
+    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);
+    static constexpr auto wmma_gemm =
+        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
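Note: the block tile is carved into waves as MWaves = MPerBlock / (MRepeat * MPerWMMA) (same for N), and the constructor below statically requires the thread block to supply exactly MWaves * NWaves * WaveSize threads. A quick constexpr check with illustrative numbers (a 128x128 tile, 16x16 WMMA, 4 repeats per wave; not necessarily this commit's actual config):

```cpp
#include <cstdio>

int main()
{
    constexpr int WaveSize  = 32; // hardcoded in this header for the same reason
    constexpr int MPerBlock = 128, NPerBlock = 128;
    constexpr int MPerWMMA  = 16, NPerWMMA = 16;
    constexpr int MRepeat   = 4, NRepeat = 4;

    constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA); // 128 / (4*16) = 2
    constexpr int NWaves = NPerBlock / (NRepeat * NPerWMMA); // 2

    static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0, "tile must divide evenly");
    static_assert(NPerBlock % (NPerWMMA * NRepeat) == 0, "tile must divide evenly");

    // 2 * 2 * 32 = 128 threads: the required thread-block size.
    std::printf("block size = %d\n", MWaves * NWaves * WaveSize);
    return 0;
}
```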
@@ -79,6 +88,7 @@ struct BlockwiseGemmWMMA
         return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
     }

+    // Default, Block buffer in LDS, thread level offset enabled
     __device__ static auto CalculateAThreadOriginDataIndex()
     {
         const auto wave_idx = GetWaveIdx();
@@ -129,23 +139,63 @@ struct BlockwiseGemmWMMA
         return make_tuple(c_thread_m, c_thread_n);
     }

-    // using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
-    // __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle(
-    //     Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
-    //     Tuple4 b_origin = CalculateBThreadOriginDataIndex())
-    //     : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
-    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
+    template <index_t m0, index_t n0>
+    __device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
+    {
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+        const auto blk_idx  = wmma_gemm.GetBeginOfThreadBlk3D();
+
+        return make_tuple(Number<m0>{},
+                          blk_idx[I0],
+                          waveId_m,
+                          Number<n0>{},
+                          waveId_n,
+                          blk_idx[I1],
+                          blk_idx[I2]);
+    }
+
+    using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
+    __host__ __device__ BlockwiseGemmWMMA(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
+                                          Tuple5 b_origin = CalculateBThreadOriginDataIndex())
+        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
     {
-        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
-                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
+        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

         static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                       "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

-        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
-                          NPerBlock % (NPerWMMA * NRepeat) == 0,
+        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
                       "wrong!");
+        // printf("tid %03d, Mat-B offset %d\n", get_thread_local_1d_id()%32, CalculateBThreadOriginDataIndex().At(Number<3>{}));
     }

+    // transposed WMMA output C' = B' * A'
+    __host__ __device__ static constexpr auto
+    GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
+    {
+        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
+            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
+        // constexpr auto NSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
+        // constexpr auto MThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
+        constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
+
+        return make_naive_tensor_descriptor_packed(
+            // |MRepeat |MWave |MSubGroup |NRepeat |NWave |NThreadPerSubGroup |MAccVgprs
+            make_tuple(Number<MRepeat>{},
+                       I1,
+                       I1,
+                       Number<NRepeat>{},
+                       I1,
+                       I1,
+                       NAccVgprs));
+    }

     // Thread level, register decriptor. Vector-write
@@ -171,9 +221,31 @@ struct BlockwiseGemmWMMA
                        MAccVgprs));
     }

-    // Provide dimension size
+    template <typename CGridDesc_M_N>
     __host__ __device__ static constexpr auto
-    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
+    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+        const CGridDesc_M_N& c_grid_desc_m_n)
     {
+        const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);
+
+        const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
+            transform_tensor_descriptor(
+                c_grid_desc_m_n,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
+                    make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+
+        return wmma_gemm
+            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+                c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
+    }
+
+    // transposed WMMA output C' = B' * A'
+    __host__ __device__ static constexpr auto
+    GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
+    {
         constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
             make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
@@ -184,37 +256,31 @@ struct BlockwiseGemmWMMA
                                                            Number<NPerWMMA>{}));

         return wmma_gemm
-            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+            .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
                 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
     }

-    __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
+    // Provide dimension size
+    __host__ __device__ static constexpr auto
+    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
     {
-        return transform_tensor_descriptor(
-            AK0MK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<A_K0>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
-                       make_pass_through_transform(Number<A_K1>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
-    }
-
-    __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
-    {
-        return transform_tensor_descriptor(
-            BK0NK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<B_K0>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
-                       make_pass_through_transform(Number<B_K1>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<MPerWMMA>{},
+                                                           Number<NRepeat>{},
+                                                           Number<NWaves>{},
+                                                           Number<NPerWMMA>{}));
+
+        return wmma_gemm
+            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+                c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
     }

+    // Describe how data allocated in thread copy src buffer
     // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
-    static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
-    static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
+    static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
+    static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;

     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
@@ -235,6 +301,9 @@ struct BlockwiseGemmWMMA
                                    a_thread_desc_,
                                    make_tuple(I0, m0, I0, I0, I0),
                                    a_thread_buf);
+                // static_for<0, a_thread_buf.size(), 1>{}([&](auto i) {
+                //     a_thread_buf(i) = 1;
+                // });

                 static_for<0, NRepeat, 1>{}([&](auto n0) {
                     // read B
@@ -254,6 +323,9 @@ struct BlockwiseGemmWMMA
                         b_thread_vec.template AsType<FloatB>()(i) =
                             b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                 make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
+                        // a_thread_vec.template AsType<FloatA>()(i) = 1;
+                        // b_thread_vec.template AsType<FloatB>()(i) = 1;
                     });

                     using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
@@ -261,6 +333,12 @@ struct BlockwiseGemmWMMA
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                    // printf("GPU Gemm0 input, Tid %03d, A%2d = %04x, B%2d = %0x4\n",
+                    //        get_thread_local_1d_id(),
+                    //        i.value, *(reinterpret_cast<uint16_t*>(&a_thread_vec.template AsType<FloatA>()(i))),
+                    //        i.value, *(reinterpret_cast<uint16_t*>(&b_thread_vec.template AsType<FloatB>()(i))));

                     wmma_gemm.template Run(
                         a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
@@ -304,10 +382,8 @@ struct BlockwiseGemmWMMA
                                                B_K1,
                                                B_K1>;

-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
-    // AThreadCopy a_thread_copy_;
-    // BThreadCopy b_thread_copy_;
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
 };

 // block wise level pipe designed for inline asm
@@ -601,7 +677,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
             b_thread_desc_.GetElementSpaceSize());

+        // TODO: Fix it, MRepeat < NRepeat
         constexpr auto RepeatDiff = MRepeat - NRepeat;
+
         // Read all Mrepeat, Nrepeat
         static_for<0, NRepeat, 1>{}([&](auto iN) {
             b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
......
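Note: the new TransposeC path in this header leans on the identity (A·B)ᵀ = Bᵀ·Aᵀ, the "C' = B' * A'" in the comments: swapping the operand roles lets the same WMMA instruction emit the transposed output layout. A tiny numeric check of the identity (standalone sketch, not CK code):

```cpp
#include <cassert>

int main()
{
    const double A[2][2] = {{1, 2}, {3, 4}};
    const double B[2][2] = {{5, 6}, {7, 8}};
    double At[2][2], Bt[2][2];
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 2; ++j)
        {
            At[i][j] = A[j][i];
            Bt[i][j] = B[j][i];
        }

    for(int m = 0; m < 2; ++m)
        for(int n = 0; n < 2; ++n)
        {
            double c = 0, ct = 0;
            for(int k = 0; k < 2; ++k)
            {
                c  += A[m][k] * B[k][n];   // C  = A  * B
                ct += Bt[n][k] * At[k][m]; // C' = B' * A'
            }
            assert(ct == c); // C'[n][m] == C[m][n]
        }
    return 0;
}
```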
@@ -145,14 +145,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                                                            B0Spec,
                                                            B1Spec,
                                                            CSpec>;

-    // K1 = Max Vector Access Pixels
-    static constexpr auto K1Number = Number<K1>{};
-
-    static constexpr auto matrix_padder =
-        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
     static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                               const std::vector<index_t>& a_gs_ms_ks_strides_vec)
     {
         return Transform::MakeAGridDescriptor_AK0_M_AK1(
             Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
@@ -160,20 +155,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     }

     static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
                                                const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
     {
         return Transform::MakeB0GridDescriptor_BK0_N_BK1(
             Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
             Number<K1>{});
     }

-    static auto
-    MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
-                                   const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
+    static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
+                                               const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
     {
         return Transform::MakeB1GridDescriptor_BK0_N_BK1(
-            Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
-                                                b1_gs_ns_ls_strides_vec),
+            Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
             Number<L1>{});
     }
@@ -462,8 +455,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

         auto launch_kernel = [&](auto has_main_k_block_loop) {
-            constexpr bool has_main_loop = has_main_k_block_loop.value;
-
             const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
                 GridwiseOp,
                 ADataType,
@@ -482,7 +473,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                 ComputeBasePtrOfStridedBatch,
                 C0MatrixMask,
                 typename GridwiseOp::DefaultBlock2CTileMap,
-                has_main_loop>;
+                has_main_k_block_loop>;

             return launch_and_time_kernel(stream_config,
                                           kernel,
@@ -754,11 +745,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             << K0PerBlock << ", "
             << K1 << ", "
             << MPerBlock << ", "
-            << NPerWMMA << ", "
-            << MPerBlock << ", "
             << NPerBlock << ", "
             << L0PerBlock << ", "
             << L1
+            << getGemmSpecializationString(GemmSpec) << ", "
+            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
+            << "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
+            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
+            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
+            << getMaskingSpecializationString(MaskingSpec)
             << ">"
             << " NumPrefetch: "
             << NumPrefetch << ", "
......
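Note on the `has_main_loop` removal above: a lambda parameter is not itself a constant expression, but the static `::value` member of an `integral_constant`-style tag is tied to its *type*, so the tag can feed a template argument without an intermediate `constexpr bool`. A minimal sketch of that dispatch pattern (standalone; names are illustrative, not CK's):

```cpp
#include <cstdio>
#include <type_traits>

template <bool HasMainLoop>
void run_kernel()
{
    std::printf("main K loop: %d\n", HasMainLoop);
}

template <typename HasMainLoopT>
void launch(HasMainLoopT)
{
    // The tag's ::value is a static compile-time constant, so it can be
    // used directly as a template argument inside the generic lambda/function.
    run_kernel<HasMainLoopT::value>();
}

int main()
{
    bool need_main_loop = true; // runtime decision
    if(need_main_loop)
        launch(std::integral_constant<bool, true>{});
    else
        launch(std::integral_constant<bool, false>{});
    return 0;
}
```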
@@ -190,22 +190,89 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

-    // K1Value should be Number<...>
     static constexpr auto AK0 = Number<K0PerBlock>{};
     static constexpr auto AK1 = Number<K1Value>{};
     static constexpr auto BK0 = Number<K0PerBlock>{};
     static constexpr auto BK1 = Number<K1Value>{};
-    static constexpr auto L0  = Number<L0PerBlock>{};
-    static constexpr auto L1  = Number<L1Value>{};

-    static constexpr auto Gemm0MWaves = MPerBlock / (MPerWmma * MRepeat);
-    static constexpr auto Gemm0LWaves = L0PerBlock * L1Value / (LPerWmma * LRepeat);
+    static constexpr auto AL0 = Number<L0PerBlock / 2>{};
+    static constexpr auto AL1 = Number<L1Value>{};
+    static constexpr auto BL0 = Number<L0PerBlock>{};
+    static constexpr auto BL1 = Number<L1Value>{};

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
+    template <typename A0BlockDesc_AK0_M_AK1>
+    __host__ __device__ static constexpr auto
+    MakeA0BlockDescriptor_K0_M0_M1_M2_K1(const A0BlockDesc_AK0_M_AK1&)
+    {
+        constexpr index_t A_K0   = A0BlockDesc_AK0_M_AK1{}.GetLength(I0);
+        constexpr index_t A_K1   = A0BlockDesc_AK0_M_AK1{}.GetLength(I2);
+        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
+        return transform_tensor_descriptor(
+            A0BlockDesc_AK0_M_AK1{},
+            make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
+                       make_pass_through_transform(Number<A_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
+    template <typename B0BlockDesc_BK0_L_BK1>
+    __host__ __device__ static constexpr auto
+    MakeB0BlockDescriptor_K0_L0_L1_L2_K1(const B0BlockDesc_BK0_L_BK1&)
+    {
+        constexpr index_t B_K0   = B0BlockDesc_BK0_L_BK1{}.GetLength(I0);
+        constexpr index_t B_K1   = B0BlockDesc_BK0_L_BK1{}.GetLength(I2);
+        constexpr index_t LWaves = LPerBlock / (LRepeat * LPerWmma);
+        return transform_tensor_descriptor(
+            B0BlockDesc_BK0_L_BK1{},
+            make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
+                       make_pass_through_transform(Number<B_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
+    template <typename A1BlockDesc_AL0_M_AL1>
+    __host__ __device__ static constexpr auto
+    MakeA1BlockDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
+    {
+        constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
+        constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
+        return transform_tensor_descriptor(
+            A1BlockDesc_AL0_M_AL1{},
+            make_tuple(make_pass_through_transform(Number<A_L0>{}),
+                       make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)),
+                       make_pass_through_transform(Number<A_L1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
+    template <typename B1BlockDesc_BL0_N_BL1>
+    __host__ __device__ static constexpr auto
+    MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
+    {
+        constexpr index_t B_K0   = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
+        constexpr index_t B_K1   = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
+        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
+        return transform_tensor_descriptor(
+            B1BlockDesc_BL0_N_BL1{},
+            make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                       make_pass_through_transform(Number<B_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
@@ -226,8 +293,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     {
         // B1 matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(L0, Number<NPerBlock>{}, L1),
-            make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * L1, L1, I1));
+            make_tuple(BL0, Number<NPerBlock>{}, BL1),
+            make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * BL1, BL1, I1));
     }

     __host__ __device__ static constexpr auto
@@ -374,7 +441,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         static constexpr auto b1_block_desc_bl0_n_bl1 =
             GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();

-        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), L1);
+        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1);

         static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -451,8 +518,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         /*******************************************************************************/
         // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
         const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
-        // constexpr auto max_lds_align = K1Value;
+
         constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
         constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();

         // A matrix blockwise copy
@@ -491,7 +558,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                 B0ElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
-                                                Sequence<BK0, NPerBlock, BK1>,
+                                                Sequence<BK0, LPerBlock, BK1>,
                                                 B0BlockTransferThreadClusterLengths_K0_L_K1,
                                                 B0BlockTransferThreadClusterArrangeOrder,
                                                 FloatB0,
@@ -520,23 +587,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         constexpr auto WmmaK = 16;
         constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);

-        auto blockwise_gemm0 =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
-                                                                       FloatA,
-                                                                       FloatB0,
-                                                                       FloatAcc0,
-                                                                       decltype(a_block_desc_k0perblock_mperblock_k1),
-                                                                       decltype(b0_block_desc_k0perblock_lperblock_k1),
-                                                                       MPerWmma,
-                                                                       LPerWmma,
-                                                                       MRepeat,
-                                                                       LRepeat,
-                                                                       KPack>{};
+        auto blockwise_gemm0 = BlockwiseGemmWMMA<
+            BlockSize,
+            FloatA,
+            FloatB0,
+            FloatAcc0,
+            decltype(MakeA0BlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
+            decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
+            MPerBlock,
+            LPerBlock,
+            K0PerBlock * K1Value,
+            MPerWmma,
+            LPerWmma,
+            MRepeat,
+            LRepeat,
+            KPack,
+            true>{}; // C' = B' x A'

         // Prepare Register for A*B0 matrix
         auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();

-        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
         constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
             blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
@@ -550,7 +621,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
             acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
-            make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lrepeat, lsubgroup)),
+            make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
                        make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)),
                        make_pass_through_transform(laccvgprs)),
             make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}),
@@ -564,9 +635,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             b0_block_desc_k0perblock_lperblock_k1.GetElementSpaceSize());

         // Shift Per SUB_K
         constexpr auto a_block_slice_copy_step  = make_multi_index(K0PerBlock, 0, 0);
         constexpr auto b0_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
         const auto a_block_reset_copy_step  = make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
         const auto b0_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);

         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
@@ -587,7 +658,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         constexpr auto t_lwave     = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
         constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
         constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);
+        if(get_thread_local_1d_id() == 0){
+            printf("t_mrepeat %d, t_mwave %d, t_mthreadpersubgroup %d, t_lrepeat %d, t_lwave %d, t_lsubgroup %d, t_laccvgprs %d \n",
+                   t_mrepeat.value,
+                   t_mwave.value,
+                   t_mthreadpersubgroup.value,
+                   t_lrepeat.value,
+                   t_lwave.value,
+                   t_lsubgroup.value,
+                   t_laccvgprs.value);
+        }
         // get acc0 thread map
         constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
@@ -628,11 +708,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         /*******************************************************************************/
         // B1 matrix in LDS memory, dst of blockwise copy
         constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
-        constexpr auto b1_block_slice_copy_step = make_multi_index(L0PerBlock, 0, 0);
+        constexpr auto b1_block_slice_copy_step = make_multi_index(BL0, 0, 0);

         // A1 matrix in VGPR
         constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
-            Number<L0PerBlock * L1Value / laccvgprs>{},
+            Number<AL0 * AL1 / laccvgprs>{},
             Number<mrepeat * mwave * mthreadpersubgroup>{},
             Number<laccvgprs>{}); // Data duplicated dimension
@@ -665,10 +745,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         // B1 matrix blockwise copy
         auto b1_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                B0ElementwiseOperation,
+                                                B1ElementwiseOperation,
                                                 tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
-                                                Sequence<L0, NPerBlock, L1>,
+                                                Sequence<BL0, NPerBlock, BL1>,
                                                 B1BlockTransferThreadClusterLengths_L0_N_L1,
                                                 B1BlockTransferThreadClusterArrangeOrder,
                                                 FloatB1,
@@ -700,22 +780,25 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             b1_block_desc_l0perblock_nperblock_l1.GetElementSpaceSize());

         auto blockwise_gemm1 =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
-                                                                       FloatA,
-                                                                       FloatB1,
-                                                                       FloatAcc1,
-                                                                       decltype(a1_thread_desc_l0perblock_mperblock_l1),
-                                                                       decltype(b1_block_desc_l0perblock_nperblock_l1),
-                                                                       MPerWmma,
-                                                                       NPerWmma,
-                                                                       MRepeat,
-                                                                       NRepeat,
-                                                                       KPack>{};
+            BlockwiseGemmWMMA<BlockSize,
+                              FloatA,
+                              FloatB1,
+                              FloatAcc1,
+                              decltype(MakeA1BlockDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
+                              decltype(MakeB1BlockDescriptor_L0_N0_N1_N2_L1(b1_block_desc_l0perblock_nperblock_l1)),
+                              MPerBlock,
+                              NPerBlock,
+                              BL0 * BL1,
+                              MPerWmma,
+                              NPerWmma,
+                              MRepeat,
+                              NRepeat,
+                              KPack>{make_tuple(0, 0, 0, 0, 0)};

         auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();

         const index_t num_gemm1_l_block_outer_loop = b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
-        constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (L0PerBlock * L1Value);
+        constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (BL0 * BL1);

         // Initialize C
         StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
@@ -809,15 +892,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             }
             block_sync_lds();
             // gemm0 end
+            // gemm0 incorrect

             // Tiled softmax start
             // softmax
             SoftmaxBuf& max = blockwise_softmax.max_value_buf;
             SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;

+            // printf("GPU Gemm 0, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);
+            // static_for<0, acc0_thread_buf.Size(), 1>{}([&](auto i) {
+            //     printf("GPU Gemm0, Tid %03d, GPU acc%d = %lf\n", get_thread_local_1d_id(), i.value, acc0_thread_buf[i]);
+            // });
             blockwise_softmax.Run(acc0_thread_buf, workspace_buf);
+            // printf("GPU SoftMax, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);

             // TODO: may convert to log domain
             running_max_new = mathext::max(max, running_max);
@@ -862,6 +949,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                 block_sync_lds();

+                // printf("GPU permute lanex, Tid %03d, GPU 0 = %04x\n", get_thread_local_1d_id(), *(reinterpret_cast<const uint16_t*>(&a1_thread_buf[I0])));
                 blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);

                 block_sync_lds();
@@ -934,11 +1023,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         // write out to C, implement shuffle
         {
             constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
-                blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
+                blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

             // This API Provide All dimension (size) you need
             constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
-                blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
+                blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

             constexpr auto MWave     = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
             constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
@@ -973,7 +1062,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             // calculate origin of thread output tensor on global memory
             // blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block = blockwise_gemm0.CalculateCThreadOriginDataIndex(I0, I0);
+            const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0);

             const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
             const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
......
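Note: the running_max / running_sum bookkeeping around blockwise_softmax above is the standard online (streaming) softmax used in fused attention kernels: each new L-tile updates the running maximum and rescales the accumulated sum so the final normalization is exact even though the row is processed in chunks. A scalar sketch of the recurrence (standalone, illustrative):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<float> row = {0.5f, 2.0f, -1.0f, 3.0f, 1.5f, 0.0f};
    const size_t chunk = 2; // tile width along L

    float running_max = -INFINITY;
    float running_sum = 0.f;

    for(size_t base = 0; base < row.size(); base += chunk)
    {
        // local max of this chunk
        float m = -INFINITY;
        for(size_t i = 0; i < chunk; ++i)
            m = std::fmax(m, row[base + i]);

        // new running max; rescale the old sum into the new frame
        const float new_max = std::fmax(running_max, m);
        running_sum *= std::exp(running_max - new_max);
        for(size_t i = 0; i < chunk; ++i)
            running_sum += std::exp(row[base + i] - new_max);
        running_max = new_max;
    }

    // softmax(x_i) = exp(x_i - running_max) / running_sum
    std::printf("max=%f sum=%f\n", running_max, running_sum);
    return 0;
}
```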
@@ -140,6 +140,39 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

+    template <typename ABlockDesc_AK0_M_AK1>
+    __host__ __device__ static constexpr auto
+    MakeABlockDescriptor_K0_M0_M1_M2_K1(const ABlockDesc_AK0_M_AK1&)
+    {
+        constexpr index_t A_K0   = ABlockDesc_AK0_M_AK1{}.GetLength(I0);
+        constexpr index_t A_K1   = ABlockDesc_AK0_M_AK1{}.GetLength(I2);
+        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
+        return transform_tensor_descriptor(
+            ABlockDesc_AK0_M_AK1{},
+            make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
+                       make_pass_through_transform(Number<A_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
+    template <typename BBlockDesc_BK0_N_BK1>
+    __host__ __device__ static constexpr auto
+    MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
+    {
+        constexpr index_t B_K0   = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
+        constexpr index_t B_K1   = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
+        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
+        return transform_tensor_descriptor(
+            BBlockDesc_BK0_N_BK1{},
+            make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                       make_pass_through_transform(Number<B_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;
@@ -414,17 +447,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
-                                                                  FloatA,
-                                                                  FloatB,
-                                                                  FloatAcc,
-                                                                  decltype(a_block_desc_k0perblock_mperblock_k1),
-                                                                  decltype(b_block_desc_k0perblock_nperblock_k1),
-                                                                  MPerWmma,
-                                                                  NPerWmma,
-                                                                  MRepeat,
-                                                                  NRepeat,
-                                                                  KPack>{};
+            BlockwiseGemmWMMA<BlockSize,
+                              FloatA,
+                              FloatB,
+                              FloatAcc,
+                              decltype(MakeABlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
+                              decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc_k0perblock_nperblock_k1)),
+                              MPerBlock,
+                              NPerBlock,
+                              K0PerBlock * K1,
+                              MPerWmma,
+                              NPerWmma,
+                              MRepeat,
+                              NRepeat,
+                              KPack>{};

         // Prepare Register for C matrix
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
......
@@ -1382,6 +1382,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
         // copy data from src_buf into dst_vector
         static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+            // idx_md err. as dst access 2 strided elements while src visit 1 per loop
             constexpr index_t src_offset = src_desc.CalculateOffset(
                 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
@@ -1396,13 +1397,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
             if(get_thread_local_1d_id() % 32 > 16){
                 // apply type convert
                 dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
-                dst_buf(Number<dst_offset + dst_buf.size()/2>{}) = __builtin_amdgcn_permlanex16(type_convert<DstData>(dst_buf(Number<dst_offset + dst_buf.size()/2>{})),
+                dst_buf(Number<dst_offset + DstScalarPerVector>{}) = __builtin_amdgcn_permlanex16(type_convert<DstData>(dst_buf(Number<dst_offset + DstScalarPerVector>{})),
                                                                                                   type_convert<DstData>(v),
                                                                                                   LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
             }
             else{
                 // apply type convert
-                dst_buf(Number<dst_offset + dst_buf.size()/2>{}) = type_convert<DstData>(v);
+                dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v);
                 dst_buf(Number<dst_offset>{}) = __builtin_amdgcn_permlanex16(type_convert<DstData>(dst_buf(Number<dst_offset>{})),
                                                                              type_convert<DstData>(v),
                                                                              LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
......
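Note: the fix above replaces `dst_buf.size()/2` with `DstScalarPerVector` as the offset between the two register slots that `__builtin_amdgcn_permlanex16` stitches together, so the exchanged halves line up with the copy's per-vector stride instead of the whole buffer's midpoint. A heavily simplified plain-C++ simulation of the cross-half lane exchange (standalone; real permlanex16 takes per-group lane selectors, here reduced to "read the same slot from the other 16-lane half"):

```cpp
#include <cassert>

int main()
{
    // Model one register across a 32-wide wave: lane i holds value i.
    int reg[32];
    for(int lane = 0; lane < 32; ++lane)
        reg[lane] = lane;

    // permlanex16-style exchange: each lane reads from the *other*
    // 16-lane half of the wave, at the same position within that half.
    int out[32];
    for(int lane = 0; lane < 32; ++lane)
        out[lane] = reg[(lane + 16) % 32];

    for(int lane = 0; lane < 32; ++lane)
        assert(out[lane] == (lane + 16) % 32);
    return 0;
}
```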
@@ -517,12 +517,12 @@ struct WmmaGemm
     __host__ __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        return GetSwizzledLaneIdLow();
+        return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
     }

     __host__ __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        return GetLaneIdUnderSubGroup();
+        return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
     }

     __device__ static CIndex GetBeginOfThreadBlk()
......
@@ -55,6 +55,18 @@ struct GeneratorTensor_1<int8_t>
     }
 };

+template <typename T>
+struct GeneratorTensor_dec1
+{
+    T value = 0.1;
+
+    template <typename... Is>
+    T operator()(Is...)
+    {
+        return value;
+    }
+};
+
 template <typename T>
 struct GeneratorTensor_2
 {
......
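Note: GeneratorTensor_dec1 fills a tensor with the constant 0.1, which, unlike 1.0, is not exactly representable in binary floating point, so it exercises rounding in the f16 data path (the nearest half-precision value to 0.1 is 0.0999755859375). A standalone sketch of the same functor shape and its use:

```cpp
#include <cstdio>
#include <vector>

template <typename T>
struct ConstantGenerator // mirrors the GeneratorTensor_dec1 shape
{
    T value = T(0.1);

    template <typename... Is>
    T operator()(Is...) const
    {
        return value; // index arguments are accepted and ignored
    }
};

int main()
{
    ConstantGenerator<float> gen;
    std::vector<float> t(8);
    for(size_t i = 0; i < t.size(); ++i)
        t[i] = gen(i);

    // 0.1 is inexact in every binary format; fp16 rounds it to
    // 0.0999755859375, so exact-match checks against 0.1 will fail.
    std::printf("%.17g\n", double(t[0]));
    return 0;
}
```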