Commit 3f9dbcac authored by coderfeli's avatar coderfeli
Browse files

use new pipeline for b preshuffle, run ok; revert olds to fix ckprofiler

parent 54f44e62
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
...@@ -27,8 +27,6 @@ using S = ck::Sequence<Is...>; ...@@ -27,8 +27,6 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
// using I8 = int8_t;
// using I32 = int;
using F16 = ck::half_t; using F16 = ck::half_t;
using FP8 = ck::f8_t; using FP8 = ck::f8_t;
using F32 = float; using F32 = float;
...@@ -79,109 +77,6 @@ struct MultiplyMultiply ...@@ -79,109 +77,6 @@ struct MultiplyMultiply
}; };
// struct MultiplyMultiply
// {
// template <typename E, typename C, typename D0, typename D1>
// __host__ __device__ constexpr void
// operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
// ck::half_t& e, const float& c, const float& d0, const float& d1) const
// {
// const float x0_f = c * d0 * d1;
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
// ck::half_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
// ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::bhalf_t>(x0_f);
// }
// };
// void reinit2(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// int kinit = 0;
// for (int k = 0; k < K; k+=1) {
// // dst[n * K + k] = n;
// if(k>0 && k%128==0){
// kinit += 1;
// }
// dst[n * K + k] = k % 128 + kinit;//rand() % 5 - 2;
// }
// }
// }
// void reinit(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; k+=1) {
// dst[n * K + k] = ck::type_convert<FP8>(float(1));
// }
// }
// }
void dump(FP8* dst, int N, int K) {
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K; ++k) {
printf("%.1f,", ck::type_convert<float>(dst[n * K + k]));
}
printf("\n");
}
}
// void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
// const int NRepeat = 1;
// const int KRepeat = 8;
// const int NWave = 4;
// const int KLane = 2;
// const int NLane = 32;
// const int KPack = 16;
// int K0 = K / (KRepeat * KLane * KPack);
// int tempn, tempk;
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; ++k) {
// int n0 = n / (NRepeat * NLane * NWave);
// int k0 = k / (KRepeat * KLane * KPack);
// tempn = n % (NRepeat * NLane * NWave);
// tempk = k % (KRepeat * KLane * KPack);
// int n1 = tempn / (NLane * NWave);
// int k1 = tempk / (KLane * KPack);
// tempn = tempn % (NLane * NWave);
// tempk = tempk % (KLane * KPack);
// int n2 = tempn / NLane;
// int k2 = tempk / KPack;
// int n3 = tempn % NLane;
// int k3 = tempk % KPack;
// int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
// + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
// + n1 * KPack * NLane * KLane * NWave * KRepeat
// + k1 * KPack * NLane * KLane * NWave
// + n2 * KPack * NLane * KLane
// + k2 * KPack * NLane
// + n3 * KPack
// + k3;
// dst[outputIndex] = src[n * K + k];
// }
// }
// }
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) { void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
const int NRepeat = 1; const int NRepeat = 1;
const int KRepeat = 8; const int KRepeat = 8;
...@@ -230,7 +125,8 @@ using CDEElementOp = MultiplyMultiply; ...@@ -230,7 +125,8 @@ using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
// clang-format off // clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
...@@ -349,10 +245,7 @@ int main(int argc, char* argv[]) ...@@ -349,10 +245,7 @@ int main(int argc, char* argv[])
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
// reinit2(a0_m_k.mData.data(), M, K);
// reinit2(b0_k_n.mData.data(), N, K);
preShuffleBuffer(b0_k_n.mData.data(), N, K, b0_preshuffled.mData.data()); preShuffleBuffer(b0_k_n.mData.data(), N, K, b0_preshuffled.mData.data());
// dump(b0_preshuffled.mData.data(), N, K);
a0_device_buf.ToDevice(a0_m_k.mData.data()); a0_device_buf.ToDevice(a0_m_k.mData.data());
// b0_device_buf.ToDevice(b0_preshuffled.mData.data()); // b0_device_buf.ToDevice(b0_preshuffled.mData.data());
b0_device_buf.ToDevice(b0_preshuffled.mData.data()); b0_device_buf.ToDevice(b0_preshuffled.mData.data());
......
...@@ -281,8 +281,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -281,8 +281,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
const ABlockDesc& a_block_desc, const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy, ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf, const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf0, ABlockBuffer& a_block_buf,
ABlockBuffer& a_block_buf1,
const ABlockTransferStep& a_block_copy_step, const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc, const BBlockDesc& b_block_desc,
...@@ -301,17 +300,21 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -301,17 +300,21 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// Global prefetch 1 // Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<0>{}); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// // Local prefill 1 // Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// // Global prefetch 2 // Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C // Initialize C
c_thread_buf.Clear(); c_thread_buf.Clear();
...@@ -322,12 +325,21 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -322,12 +325,21 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}), make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf0, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, k0, I0),
a_thread_buf); a_thread_buf);
}); });
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
});
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// main body // main body
...@@ -336,61 +348,13 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -336,61 +348,13 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
index_t i = 0; index_t i = 0;
do do
{ {
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) { a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) { b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf1,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf0);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<0>{}); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
...@@ -399,12 +363,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -399,12 +363,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec;
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -428,75 +395,43 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -428,75 +395,43 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}), make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf0, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, k0, I0),
a_thread_buf); a_thread_buf);
}); });
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
});
}); });
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
i += 2;
} while(i < (num_loop - 2)); i += 1;
} while(i < (num_loop - 1));
} }
// tail // tail
if constexpr(TailNum == TailNumber::Full) if constexpr(TailNum == TailNumber::Full)
{ {
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec;
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf1,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -520,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -520,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
protected: protected:
using Base::a_thread_copy_; using Base::a_thread_copy_;
using Base::a_thread_desc_; using Base::a_thread_desc_;
// using Base::b_thread_copy_; using Base::b_thread_copy_;
using Base::b_thread_desc_; using Base::b_thread_desc_;
using Base::c_thread_desc_; using Base::c_thread_desc_;
}; };
......
...@@ -486,52 +486,52 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo ...@@ -486,52 +486,52 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
// Tail number could be Odd or Even // Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{ {
// if(arg.KBatch > 1) if(arg.KBatch > 1)
// { {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// { {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
// GridwiseGemm, GridwiseGemm,
// true, true,
// InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy, minimum_occupancy,
// TailNumber::Odd>; TailNumber::Odd>;
// Run(kernel); Run(kernel);
// } }
// else else
// { {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
// GridwiseGemm, GridwiseGemm,
// true, true,
// InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy, minimum_occupancy,
// TailNumber::Even>; TailNumber::Even>;
// Run(kernel); Run(kernel);
// } }
// } }
// else else
// { {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// { {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
// GridwiseGemm, GridwiseGemm,
// true, true,
// InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
// minimum_occupancy, minimum_occupancy,
// TailNumber::Odd>; TailNumber::Odd>;
// Run(kernel); Run(kernel);
// } }
// else else
// { {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
// GridwiseGemm, GridwiseGemm,
// true, true,
// InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
// minimum_occupancy, minimum_occupancy,
// TailNumber::Even>; TailNumber::Even>;
// Run(kernel); Run(kernel);
// } }
// } }
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment