Commit 0dbe5370 authored by aska-0096's avatar aska-0096
Browse files

refine weight preshuffle format.

parent 72c1ddac
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker) # target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
...@@ -39,7 +39,7 @@ using CShuffleDataType = F32; ...@@ -39,7 +39,7 @@ using CShuffleDataType = F32;
using D0DataType = F32; using D0DataType = F32;
using D1DataType = F32; using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>; using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = BF16; using EDataType = F16;
using A0Layout = Row; using A0Layout = Row;
using B0Layout = Col; using B0Layout = Col;
...@@ -97,63 +97,32 @@ struct MultiplyMultiply ...@@ -97,63 +97,32 @@ struct MultiplyMultiply
} }
}; };
// Reorder a column-major [N, K] weight matrix into the layout consumed by the
// b-preshuffle GEMM kernels.
//
//   src index:  n * K + k                               (K contiguous per column)
//   dst layout: [K0][N0][KLane][NLane][KPack]           (KPack innermost)
//
// where
//   NLane = NXdl        -- lanes mapped to the N dimension of one XDL MFMA tile
//   KLane = 64 / NLane  -- remaining lanes of a 64-wide wavefront cover K
//   KPack = 16          -- elements each lane holds per K step
//
// Preconditions (not checked): NXdl divides 64, NLane divides N, and
// KLane * KPack divides K. Generalized over the element type so it works for
// FP8 as well as any other trivially copyable weight type (mirrors the
// library's InOutDataType version).
template <typename InOutDataType>
void preShuffleBuffer(const InOutDataType* src, InOutDataType* dst, int N, int K, int NXdl)
{
    const int KPack = 16;
    const int NLane = NXdl;
    const int KLane = 64 / NLane; // wavefront size 64 split between N and K
    const int N0    = N / NLane;

    // K -> K0 KLane KPack ; N -> N0 NLane ; (N, K) -> K0 N0 KLane NLane KPack
    for(int n = 0; n < N; ++n)
    {
        for(int k = 0; k < K; ++k)
        {
            const int n0 = n / NLane;
            const int n1 = n % NLane;

            const int k0  = k / (KLane * KPack);
            const int rem = k % (KLane * KPack);
            const int k1  = rem / KPack; // KLane index
            const int k2  = rem % KPack; // KPack index

            const int outputIndex = k0 * KPack * NLane * KLane * N0 +
                                    n0 * KPack * NLane * KLane + k1 * KPack * NLane +
                                    n1 * KPack + k2;

            dst[outputIndex] = src[n * K + k];
        }
    }
}
...@@ -179,13 +148,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -179,13 +148,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
256, 256, 128, 32, 256, 256,
16, 16, 16, 16,
32, 32, 32, 32,
8, 2, 1, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// kernel 2: 128->32x128x128 // kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
...@@ -319,18 +288,9 @@ int main(int argc, char* argv[]) ...@@ -319,18 +288,9 @@ int main(int argc, char* argv[])
// do GEMM // do GEMM
auto device_op = DeviceOpInstance{}; auto device_op = DeviceOpInstance{};
auto preshuffle_params = device_op.GetPreShuffleParameters(); int NPerXdl = device_op.GetPreShuffleParameters();
preShuffleBuffer(b0_k_n.mData.data(), preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl);
b0_preshuffled.mData.data(),
N,
K,
preshuffle_params[0],
preshuffle_params[1],
preshuffle_params[2],
preshuffle_params[3],
preshuffle_params[4],
preshuffle_params[5]);
b0_device_buf.ToDevice(b0_preshuffled.mData.data()); b0_device_buf.ToDevice(b0_preshuffled.mData.data());
......
...@@ -118,12 +118,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -118,12 +118,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack>; KPack>;
using Base::A_K1;
using Base::I0; using Base::I0;
using Base::I1; using Base::I1;
using Base::KRepeat; using Base::KRepeat;
using Base::xdlops_gemm; using Base::xdlops_gemm;
using typename Base::HotLoopInstList; using typename Base::HotLoopInstList;
using Base::a_block_desc_m0_m1_m2_k;
using Base::CalculateCThreadOriginDataIndex; using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D; using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
...@@ -136,8 +138,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -136,8 +138,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::AMmaKStride; using Base::AMmaKStride;
using Base::BMmaKStride; using Base::BMmaKStride;
...@@ -145,6 +145,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -145,6 +145,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static constexpr index_t PrefillStages = 1; static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1; static constexpr index_t GlobalBufferNum = 1;
// Build the 6-D A-tile descriptor (M0, M1, M2, K0, K1, K2) from the 4-D LDS
// block descriptor (M0, M1, M2, K): the three M dimensions pass through
// unchanged and the flat K dimension is unmerged into
//   K0 = KRepeat, K1 = 64 / NPerXDL, K2 = KPack.
// NOTE(review): K1 = 64 / NPerXDL presumably corresponds to the KLane count of
// a 64-wide wavefront, matching the new b-preshuffle layout — confirm against
// the gridwise gemm's KLane definition.
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
{
// Lengths of the pass-through M dimensions, read from the input descriptor.
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
// Factors of the K unmerge; K2 is innermost (contiguous), K0 outermost.
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
make_tuple(
make_pass_through_transform(Number<M0>{}),
make_pass_through_transform(Number<M1>{}),
make_pass_through_transform(Number<M2>{}),
// Split source dim 3 (K) into output dims 3, 4, 5 = (K0, K1, K2).
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{ {
return num_loop > PrefetchStages; return num_loop > PrefetchStages;
...@@ -275,11 +299,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -275,11 +299,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
block_sync_lds(); block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}), make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf0, a_block_buf0,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
}); });
...@@ -305,12 +329,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -305,12 +329,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy b_blockwise_copy
.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>, .template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<0>{}>(); Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, I0, k0, I0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -332,11 +356,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -332,11 +356,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}), make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf1, a_block_buf1,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
}); });
...@@ -357,15 +381,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -357,15 +381,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
// b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0,
// k0, 0>,
b_blockwise_copy b_blockwise_copy
.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>, .template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<1>{}>(); Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, I0, k0, I0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -387,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -387,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}), make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf0, a_block_buf0,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
}); });
...@@ -411,12 +433,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -411,12 +433,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>, b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<0>{}>(); Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, I0, k0, I0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -436,11 +458,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -436,11 +458,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}), make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf1, a_block_buf1,
a_thread_desc_, a_thread_desc_,
make_tuple(m0, I0, k0, I0), make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
}); });
...@@ -452,12 +474,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -452,12 +474,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>, b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<1>{}>(); Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, I0, k0, I0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -483,12 +506,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -483,12 +506,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec = vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>, b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<0>{}>(); Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, I0, k0, I0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -507,9 +530,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -507,9 +530,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
} }
protected: protected:
using Base::a_thread_copy_; // MRepeat MWave MLane KRepeat KLane KPack
using Base::a_thread_desc_; // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
using Base::b_thread_desc_; static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
using Base::c_thread_desc_; using Base::c_thread_desc_;
}; };
......
...@@ -113,6 +113,17 @@ struct BlockwiseGemmXdlops_pipeline_base ...@@ -113,6 +113,17 @@ struct BlockwiseGemmXdlops_pipeline_base
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
} }
// 6-D A-thread origin for the (M0, M1, M2, K0, K1, K2) tile descriptor used by
// the b-preshuffle pipeline. Compare with the 4-D variant above: the K offset
// (KPerThread * xdlops_a_idx[I0]) is replaced by placing xdlops_a_idx[I0]
// directly into the K1 slot, with K0 and K2 starting at 0.
// NOTE(review): assumes xdlops_a_idx[I0] is the per-lane K index and
// xdlops_a_idx[I1] the per-lane M index, as in CalculateAThreadOriginDataIndex
// — confirm against xdlops_gemm.
__device__ static auto CalculateAThreadOriginDataIndex6D()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0]; // wave id along M (M1 slot)
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], 0, xdlops_a_idx[I0], 0);
}
__device__ static auto CalculateBThreadOriginDataIndex() __device__ static auto CalculateBThreadOriginDataIndex()
{ {
const auto wave_idx = GetWaveIdx(); const auto wave_idx = GetWaveIdx();
......
...@@ -138,7 +138,7 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator ...@@ -138,7 +138,7 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual std::array<int, 6> GetPreShuffleParameters() = 0; virtual int GetPreShuffleParameters() = 0;
}; };
} // namespace device } // namespace device
......
...@@ -139,16 +139,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -139,16 +139,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
using Argument = typename GridwiseGemm::Argument; using Argument = typename GridwiseGemm::Argument;
std::array<int, 6> GetPreShuffleParameters() override int GetPreShuffleParameters() override
{ {
std::array<int, 6> preshuffle_params{NXdlPerWave, return NPerXDL;
GridwiseGemm::KRepeat,
GridwiseGemm::NWave,
GridwiseGemm::KLane,
GridwiseGemm::NLane,
GridwiseGemm::KPack};
return preshuffle_params;
} }
// Invoker // Invoker
...@@ -240,8 +233,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -240,8 +233,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
} }
}; };
constexpr index_t minimum_occupancy = constexpr index_t minimum_occupancy = []() {
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock/ BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && // static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right // has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
...@@ -306,12 +307,38 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -306,12 +307,38 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
else else
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm, GridwiseGemm,
false, false,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>; minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel); Run(kernel);
} }
else else
...@@ -320,10 +347,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ...@@ -320,10 +347,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
GridwiseGemm, GridwiseGemm,
false, false,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
minimum_occupancy>; minimum_occupancy,
TailNumber::Even>;
Run(kernel); Run(kernel);
} }
} }
}
return ave_time; return ave_time;
} }
......
...@@ -141,8 +141,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -141,8 +141,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr index_t KRepeat = KPerBlock / KLane / KPack; static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl; static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NLane * NWave * KLane == BlockSize); static_assert(NWave * warpSize == BlockSize);
// static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week");
static constexpr auto MakeDsGridPointer() static constexpr auto MakeDsGridPointer()
{ {
...@@ -176,7 +175,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -176,7 +175,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto CalculateBN0Shuffled(index_t N) __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
{ {
return math::integer_divide_ceil(N, NLane * NWave); return math::integer_divide_ceil(N, NLane);
} }
__host__ __device__ static auto CalculateBK0Shuffled(index_t K) __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
{ {
...@@ -322,9 +321,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -322,9 +321,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{ {
constexpr index_t NkSwizzleNumber = Number<BlockSize * KPack>{}; constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
return make_naive_tensor_descriptor(make_tuple(N0, K0, NkSwizzleNumber), return make_naive_tensor_descriptor(make_tuple(K0, N0/NWave, NWave, NkSwizzleNumber),
make_tuple(K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); make_tuple(N0*NkSwizzleNumber, NWave*NkSwizzleNumber,NkSwizzleNumber, I1));
} }
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
...@@ -649,8 +648,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -649,8 +648,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
// KPack * NLane * KLane * NWave * KRepeat * K0* NRepeat * N0 // KPack * NLane * KLane * N0 * K0
b_k_split_offset = k_id * karg.KRead * NLane * NWave; b_k_split_offset = k_id * karg.KRead * karg.N;
} }
if(k_id < karg.KBatch - 1) if(k_id < karg.KBatch - 1)
...@@ -1159,6 +1158,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1159,6 +1158,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
} }
// check gridwise gemm pipeline // check gridwise gemm pipeline
#if 0
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
...@@ -1168,7 +1168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1168,7 +1168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
return false; return false;
} }
} }
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true; return true;
} }
...@@ -1252,6 +1252,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1252,6 +1252,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{ {
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled = const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>( const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
...@@ -1294,7 +1295,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1294,7 +1295,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // dummy
constexpr auto b_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1));
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
...@@ -1335,17 +1338,17 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1335,17 +1338,17 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<NXdlPerWave, KRepeat, KPack * BlockSize>, Sequence<KRepeat, NXdlPerWave, NWave, KPack * warpSize>,
Sequence<1, 1, BlockSize>, // BThreadClusterLengths, Sequence<1, 1, NWave, warpSize>, // BThreadClusterLengths,
Sequence<0, 1, 2>, // BBlockTransferClusterArrangeOrder, Sequence<0, 1, 2, 3>, // BBlockTransferClusterArrangeOrder,
BDataType, BDataType,
LDSTypeB, LDSTypeB,
decltype(b_grid_desc_bpreshuffled), decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3>, // BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>, Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorDim, 3,
2, 3,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_BK1,
1, 1,
...@@ -1353,10 +1356,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1353,10 +1356,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
true, true,
2>(b_grid_desc_bpreshuffled, 2>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid, 0, 0), make_multi_index(0, n_block_data_idx_on_grid, 0, 0),
b_element_op, b_element_op,
b_block_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -1367,7 +1370,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1367,7 +1370,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, KRepeat, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KRepeat, 0, 0, 0);
// Blockwise GEMM pipeline // Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>); static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
......
...@@ -29,40 +29,31 @@ void preShuffleBuffer(const InOutDataType* src, ...@@ -29,40 +29,31 @@ void preShuffleBuffer(const InOutDataType* src,
InOutDataType* dst, InOutDataType* dst,
int N, int N,
int K, int K,
int NRepeat, int NXdl)
int KRepeat,
int NWave,
int KLane,
int NLane,
int KPack)
{ {
int K0 = K / (KRepeat * KLane * KPack); int KPack = 16;
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRpeat KLane KPack, move klane inner to make all int NLane = NXdl;
// lanes contiguous N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1 int KLane = 64 / NLane;
int tempn, tempk;
int N0 = N / NLane;
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> K0 N0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
int n0 = n / (NRepeat * NLane * NWave); int n0 = n / NLane;
int k0 = k / (KRepeat * KLane * KPack); int n1 = n % NLane;
tempn = n % (NRepeat * NLane * NWave);
tempk = k % (KRepeat * KLane * KPack); int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int n1 = tempn / (NLane * NWave); int k1 = tempk / KPack;
int k1 = tempk / (KRepeat * KPack); // Klane int k2 = tempk % KPack;
tempn = tempn % (NLane * NWave);
tempk = tempk % (KRepeat * KPack); int outputIndex = k0 * KPack * NLane * KLane * N0 + n0 * KPack * NLane * KLane +
int n2 = tempn / NLane; k1 * KPack * NLane + n1 * KPack + k2;
int k2 = tempk / KPack; // KRepeat
int n3 = tempn % NLane;
int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * K0 * NRepeat +
n1 * KPack * NLane * KLane * NWave * KRepeat * K0 +
k0 * KPack * NLane * KLane * NWave * KRepeat +
k2 * KPack * NLane * KLane * NWave + n2 * KPack * NLane * KLane +
k1 * KPack * NLane + n3 * KPack + k3;
dst[outputIndex] = src[n * K + k]; dst[outputIndex] = src[n * K + k];
} }
...@@ -116,7 +107,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -116,7 +107,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_preshuffled( Tensor<BDataType> b_preshuffled_mfma16(
f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size
Tensor<BDataType> b_preshuffled_mfma32(
f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
...@@ -154,6 +147,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -154,6 +147,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
} }
preShuffleBuffer(b_k_n.mData.data(), b_preshuffled_mfma16.mData.data(), N, K, 16);
preShuffleBuffer(b_k_n.mData.data(), b_preshuffled_mfma32.mData.data(), N, K, 32);
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply; using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
...@@ -166,12 +162,16 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -166,12 +162,16 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
const auto c_element_op = CElementOp{}; const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf_mfma16(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_mfma32(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf_mfma16.ToDevice(b_preshuffled_mfma16.mData.data());
b_device_buf_mfma32.ToDevice(b_preshuffled_mfma32.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data());
...@@ -234,20 +234,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -234,20 +234,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
// profile device GEMM instances // profile device GEMM instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto preshuffle_params = op_ptr->GetPreShuffleParameters(); int NPerXdl = op_ptr->GetPreShuffleParameters();
preShuffleBuffer(b_k_n.mData.data(),
b_preshuffled.mData.data(),
N,
K,
preshuffle_params[0],
preshuffle_params[1],
preshuffle_params[2],
preshuffle_params[3],
preshuffle_params[4],
preshuffle_params[5]);
b_device_buf.ToDevice(b_preshuffled.mData.data());
std::vector<int> kbatch_list = {1, 2, 4, 8}; std::vector<int> kbatch_list = {1, 2, 4, 8};
...@@ -262,7 +249,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -262,7 +249,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
auto argument_ptr = op_ptr->MakeArgumentPointer( auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(NPerXdl == 16 ? b_device_buf_mfma16.GetDeviceBuffer()
: b_device_buf_mfma32.GetDeviceBuffer()),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(), std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()}, d1_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
...@@ -298,8 +286,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, ...@@ -298,8 +286,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
is_same_v<EDataType, int8_t>)) is_same_v<EDataType, int8_t>))
{ {
std::string msg = "Error: Incorrect results!"; std::string msg = "Error: Incorrect results!";
double rtol = 1e-1; double rtol = 1e-3;
double atol = 1e-1; double atol = 5e-2;
pass = pass & ck::utils::check_err( pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
} }
......
...@@ -50,7 +50,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") ...@@ -50,7 +50,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
# endif() # endif()
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) # list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") # if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) # list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_weight_preshuffle.cpp) list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_weight_preshuffle.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) # list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
# endif() # endif()
...@@ -137,7 +137,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") ...@@ -137,7 +137,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") # if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_weight_preshuffle_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_weight_preshuffle_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
# endif() # endif()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment