Commit f60f9d59 authored by aska-0096's avatar aska-0096
Browse files

sanity pass, most tile size enabled. TODO: NWave!=4

parent 482ca684
...@@ -78,15 +78,18 @@ struct MultiplyMultiply ...@@ -78,15 +78,18 @@ struct MultiplyMultiply
} }
}; };
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) void preShuffleBuffer(const FP8* src,
FP8* dst,
int N,
int K,
int NRepeat,
int KRepeat,
int NWave,
int KLane,
int NLane,
int KPack)
{ {
const int NRepeat = 4; int K0 = K / (KRepeat * KLane * KPack);
const int KRepeat = 4;
const int NWave = 2;
const int KLane = 2;
const int NLane = 32;
const int KPack = 16;
int K0 = K / (KRepeat * KLane * KPack);
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRpeat KLane KPack, move klane inner to make all // K -> src: K0 KLane KRepeat KPack -> dst: K0 KRpeat KLane KPack, move klane inner to make all
// lanes contiguous N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1 // lanes contiguous N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1
int tempn, tempk; int tempn, tempk;
...@@ -108,12 +111,30 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) ...@@ -108,12 +111,30 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst)
int n3 = tempn % NLane; int n3 = tempn % NLane;
int k3 = tempk % KPack; // Kpack int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0 + int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * K0 * NRepeat +
k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat + n1 * KPack * NLane * KLane * NWave * KRepeat * K0 +
n1 * KPack * NLane * KLane * NWave * KRepeat + k0 * KPack * NLane * KLane * NWave * KRepeat +
k2 * KPack * NLane * KLane * NWave // switch k1, k2 k2 * KPack * NLane * KLane * NWave + n2 * KPack * NLane * KLane +
+ n2 * KPack * NLane * KLane + k1 * KPack * NLane + n3 * KPack + k3; k1 * KPack * NLane + n3 * KPack + k3;
#if 0
int k1 = tempk / (KLane * KPack); //KRepeat
int n1 = tempn / (NLane * NWave); //NRepeat
tempn = tempn % (NLane * NWave);
tempk = tempk % (KLane * KPack);
int n2 = tempn / NLane; // NWave
int k2 = tempk / KPack; // KLane
int n3 = tempn % NLane; // NLane
int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * NRepeat * KRepeat * K0 +
k0 * KPack * NLane * KLane * NWave * NRepeat * KRepeat +
k1 * KPack * NLane * KLane * NWave * NRepeat +
n1 * KPack * NLane * KLane * NWave +
n2 * KPack * NLane * KLane +
k2 * KPack * NLane +
n3 * KPack +
k3;
#endif
dst[outputIndex] = src[n * K + k]; dst[outputIndex] = src[n * K + k];
} }
} }
...@@ -124,7 +145,7 @@ using AElementOp = PassThrough; ...@@ -124,7 +145,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply; using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
...@@ -139,10 +160,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -139,10 +160,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
256, 256, 128, 32, 256, 128,
16, 16, 16, 16,
32, 32, 32, 32,
4, 4, 1, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
...@@ -245,6 +266,12 @@ int main(int argc, char* argv[]) ...@@ -245,6 +266,12 @@ int main(int argc, char* argv[])
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2}); d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2}); d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break; break;
case 2:
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_m_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
break;
default: default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
...@@ -256,10 +283,8 @@ int main(int argc, char* argv[]) ...@@ -256,10 +283,8 @@ int main(int argc, char* argv[])
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
preShuffleBuffer(b0_k_n.mData.data(), N, K, b0_preshuffled.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data()); a0_device_buf.ToDevice(a0_m_k.mData.data());
// b0_device_buf.ToDevice(b0_preshuffled.mData.data());
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data()); e_device_buf.ToDevice(e_m_n_device_result.mData.data());
...@@ -274,7 +299,23 @@ int main(int argc, char* argv[]) ...@@ -274,7 +299,23 @@ int main(int argc, char* argv[])
// do GEMM // do GEMM
auto device_op = DeviceOpInstance{}; auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto preshuffle_params = device_op.GetPreShuffleParameters();
preShuffleBuffer(b0_k_n.mData.data(),
b0_preshuffled.mData.data(),
N,
K,
preshuffle_params[0],
preshuffle_params[1],
preshuffle_params[2],
preshuffle_params[3],
preshuffle_params[4],
preshuffle_params[5]);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument = auto argument =
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(),
...@@ -300,7 +341,7 @@ int main(int argc, char* argv[]) ...@@ -300,7 +341,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, true, 50});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -315,7 +356,7 @@ int main(int argc, char* argv[]) ...@@ -315,7 +356,7 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
invoker.Run(argument, StreamConfig{nullptr, false, 0, 1, 1}); invoker.Run(argument, StreamConfig{nullptr, false});
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
......
...@@ -152,8 +152,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -152,8 +152,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{ {
ignore = num_loop;
return TailNumber::Full; return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
} }
__device__ static constexpr auto HotLoopScheduler() __device__ static constexpr auto HotLoopScheduler()
...@@ -342,8 +342,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -342,8 +342,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, b_blockwise_copy
Number<0>{}>(); .template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
...@@ -394,8 +395,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -394,8 +395,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, // b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0,
Number<1>{}>(); // k0, 0>,
b_blockwise_copy
.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
...@@ -435,7 +439,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -435,7 +439,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
} while(i < (num_loop - 2)); } while(i < (num_loop - 2));
} }
// tail // tail
if constexpr(TailNum == TailNumber::Full) if constexpr(TailNum == TailNumber::Even)
{ {
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{}); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
...@@ -445,8 +449,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -445,8 +449,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<0>{}>(); Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
...@@ -486,8 +490,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -486,8 +490,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>(); Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
...@@ -510,6 +514,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -510,6 +514,34 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
// latency // latency
// __builtin_amdgcn_sched_barrier(0); // __builtin_amdgcn_sched_barrier(0);
} }
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
} }
protected: protected:
......
...@@ -96,6 +96,51 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator ...@@ -96,6 +96,51 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual std::array<int, 6> GetPreShuffleParameters() = 0;
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -10,9 +10,10 @@ ...@@ -10,9 +10,10 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp" #include "ck/host_utility/flush_cache.hpp"
...@@ -69,55 +70,17 @@ template <typename ALayout, ...@@ -69,55 +70,17 @@ template <typename ALayout,
typename LDSTypeA = ComputeTypeA, typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB> typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
: public DeviceGemmMultiD_Xdl_CShuffle_V3< : public DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
ALayout, BLayout,
BLayout, DsLayout,
DsLayout, CLayout,
CLayout, ADataType,
ADataType, BDataType,
BDataType, DsDataType,
DsDataType, CDataType,
CDataType, AElementwiseOperation,
GemmAccDataType, BElementwiseOperation,
CShuffleDataType, CElementwiseOperation>
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -176,6 +139,18 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -176,6 +139,18 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
using Argument = typename GridwiseGemm::Argument; using Argument = typename GridwiseGemm::Argument;
std::array<int, 6> GetPreShuffleParameters() override
{
std::array<int, 6> preshuffle_params{NXdlPerWave,
GridwiseGemm::KRepeat,
GridwiseGemm::NWave,
GridwiseGemm::KLane,
GridwiseGemm::NLane,
GridwiseGemm::KPack};
return preshuffle_params;
}
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
...@@ -278,21 +253,49 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -278,21 +253,49 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
GridwiseGemm, {
true, const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
InMemoryDataOperationEnum::AtomicAdd, GridwiseGemm,
minimum_occupancy>; true,
Run(kernel); InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
} }
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
GridwiseGemm, {
true, const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
InMemoryDataOperationEnum::Set, GridwiseGemm,
minimum_occupancy>; true,
Run(kernel); InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
} }
} }
else else
...@@ -436,6 +439,57 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -436,6 +439,57 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
b_element_op, b_element_op,
c_element_op); c_element_op);
} }
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmXdlUniversal"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
}; };
} // namespace device } // namespace device
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp"
// #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
...@@ -31,7 +31,7 @@ template <typename GridwiseGemm, ...@@ -31,7 +31,7 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1, index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full> TailNumber TailNum = TailNumber::Even>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
...@@ -142,7 +142,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -142,7 +142,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr index_t NLane = NPerXdl; static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NLane * NWave * KLane == BlockSize); static_assert(NLane * NWave * KLane == BlockSize);
static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week"); // static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week");
static constexpr auto MakeDsGridPointer() static constexpr auto MakeDsGridPointer()
{ {
...@@ -322,10 +322,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -322,10 +322,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{ {
constexpr index_t NkSwizzle = BlockSize * KPack; constexpr index_t NkSwizzleNumber = Number<BlockSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<NkSwizzle>{};
return make_naive_tensor_descriptor(make_tuple(N0, K0, NkSwizzleNumber), return make_naive_tensor_descriptor(make_tuple(N0, K0, NkSwizzleNumber),
make_tuple(K0 * NkSwizzle, NkSwizzleNumber, I1)); make_tuple(K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
} }
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
...@@ -650,7 +649,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -650,7 +649,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
b_k_split_offset = k_id * karg.KRead; // KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0*N0
b_k_split_offset = k_id * karg.KRead * NLane * NWave * NXdlPerWave;
} }
if(k_id < karg.KBatch - 1) if(k_id < karg.KBatch - 1)
...@@ -1286,8 +1286,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1286,8 +1286,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const index_t m_block_data_idx_on_grid = const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * (NPerBlock / NLane / NWave)); __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
...@@ -1334,7 +1335,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1334,7 +1335,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<1, KRepeat, KPack * BlockSize>, Sequence<NXdlPerWave, KRepeat, KPack * BlockSize>,
Sequence<1, 1, BlockSize>, // BThreadClusterLengths, Sequence<1, 1, BlockSize>, // BThreadClusterLengths,
Sequence<0, 1, 2>, // BBlockTransferClusterArrangeOrder, Sequence<0, 1, 2>, // BBlockTransferClusterArrangeOrder,
BDataType, BDataType,
......
...@@ -20,7 +20,7 @@ namespace instance { ...@@ -20,7 +20,7 @@ namespace instance {
#if 0 #if 0
#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
...@@ -33,7 +33,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -33,7 +33,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
...@@ -46,7 +46,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -46,7 +46,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
...@@ -59,7 +59,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -59,7 +59,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
...@@ -72,7 +72,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -72,7 +72,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
...@@ -85,7 +85,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -85,7 +85,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
...@@ -101,82 +101,88 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m ...@@ -101,82 +101,88 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>&
instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>&
instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>&
instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>&
instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>&
instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>&
instances);
#endif #endif
template <typename ADataType, template <typename ADataType,
...@@ -185,31 +191,32 @@ template <typename ADataType, ...@@ -185,31 +191,32 @@ template <typename ADataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout> typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK< struct DeviceOperationInstanceFactory<
ALayout, ck::tensor_operation::device::DeviceGemmMultipleDSplitKBPreShuffle<
BLayout, ALayout,
Tuple<Row, Col>, BLayout,
CLayout, Tuple<Row, Col>,
ADataType, CLayout,
BDataType, ADataType,
Tuple<F32, F32>, BDataType,
CDataType, Tuple<F32, F32>,
ck::tensor_operation::element_wise::PassThrough, CDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>> ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>>
{ {
using DeviceOp = using DeviceOp =
DeviceGemmMultipleDSplitK<ALayout, DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
BLayout, BLayout,
Tuple<Row, Col>, Tuple<Row, Col>,
CLayout, CLayout,
ADataType, ADataType,
BDataType, BDataType,
Tuple<F32, F32>, Tuple<F32, F32>,
CDataType, CDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>; ck::tensor_operation::element_wise::MultiplyMultiply>;
static auto GetInstances() static auto GetInstances()
{ {
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -9,17 +9,17 @@ namespace device { ...@@ -9,17 +9,17 @@ namespace device {
namespace instance { namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances( void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances) MultiplyMultiply>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, instances,
......
...@@ -24,6 +24,51 @@ ...@@ -24,6 +24,51 @@
namespace ck { namespace ck {
namespace profiler { namespace profiler {
template <typename InOutDataType>
void preShuffleBuffer(const InOutDataType* src,
InOutDataType* dst,
int N,
int K,
int NRepeat,
int KRepeat,
int NWave,
int KLane,
int NLane,
int KPack)
{
int K0 = K / (KRepeat * KLane * KPack);
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRpeat KLane KPack, move klane inner to make all
// lanes contiguous N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1
int tempn, tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (NRepeat * NLane * NWave);
int k0 = k / (KRepeat * KLane * KPack);
tempn = n % (NRepeat * NLane * NWave);
tempk = k % (KRepeat * KLane * KPack);
int n1 = tempn / (NLane * NWave);
int k1 = tempk / (KRepeat * KPack); // Klane
tempn = tempn % (NLane * NWave);
tempk = tempk % (KRepeat * KPack);
int n2 = tempn / NLane;
int k2 = tempk / KPack; // KRepeat
int n3 = tempn % NLane;
int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * K0 * NRepeat +
n1 * KPack * NLane * KLane * NWave * KRepeat * K0 +
k0 * KPack * NLane * KLane * NWave * KRepeat +
k2 * KPack * NLane * KLane * NWave + n2 * KPack * NLane * KLane +
k1 * KPack * NLane + n3 * KPack + k3;
dst[outputIndex] = src[n * K + k];
}
}
}
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename ComputeDataType, typename ComputeDataType,
...@@ -71,6 +116,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -71,6 +116,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_preshuffled(
f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
...@@ -125,22 +172,21 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -125,22 +172,21 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data());
using DeviceOp = using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitKBPreShuffle<
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout, ALayout,
BLayout, BLayout,
ck::Tuple<D0Layout, D1Layout>, ck::Tuple<D0Layout, D1Layout>,
ELayout, ELayout,
ADataType, ADataType,
BDataType, BDataType,
ck::Tuple<D0DataType, D1DataType>, ck::Tuple<D0DataType, D1DataType>,
EDataType, EDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CElementOp>; CElementOp>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
...@@ -188,8 +234,20 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -188,8 +234,20 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
// profile device GEMM instances // profile device GEMM instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
// TODO: Shuffle the weight auto preshuffle_params = op_ptr->GetPreShuffleParameters();
// ...
preShuffleBuffer(b_k_n.mData.data(),
b_preshuffled.mData.data(),
N,
K,
preshuffle_params[0],
preshuffle_params[1],
preshuffle_params[2],
preshuffle_params[3],
preshuffle_params[4],
preshuffle_params[5]);
b_device_buf.ToDevice(b_preshuffled.mData.data());
std::vector<int> kbatch_list = {1, 2, 4, 8, 16}; std::vector<int> kbatch_list = {1, 2, 4, 8, 16};
...@@ -224,12 +282,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -224,12 +282,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification) if(do_verification)
{ {
......
...@@ -74,10 +74,10 @@ int profile_gemm_multiply_multiply_weight_preshuffle(int argc, char* argv[]) ...@@ -74,10 +74,10 @@ int profile_gemm_multiply_multiply_weight_preshuffle(int argc, char* argv[])
using F32 = float; using F32 = float;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F16 = ck::half_t; // using F16 = ck::half_t;
using F8 = ck::f8_t; using F8 = ck::f8_t;
using I8 = int8_t; // using I8 = int8_t;
using I32 = int; // using I32 = int;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment