Commit 0dbe5370 authored by aska-0096

refine weight preshuffle format.

parent 72c1ddac
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
......@@ -39,7 +39,7 @@ using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = BF16;
using EDataType = F16;
using A0Layout = Row;
using B0Layout = Col;
......@@ -97,63 +97,32 @@ struct MultiplyMultiply
}
};
void preShuffleBuffer(const FP8* src,
FP8* dst,
int N,
int K,
int NRepeat,
int KRepeat,
int NWave,
int KLane,
int NLane,
int KPack)
void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
{
int K0 = K / (KRepeat * KLane * KPack);
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRepeat KLane KPack; move KLane inner to make
// all lanes contiguous. N -> N0 NRepeat NWave NLane // todo: is NRepeat outer or inner? now it's 1
int tempn, tempk;
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int N0 = N / NLane;
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> K0 N0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (NRepeat * NLane * NWave);
int k0 = k / (KRepeat * KLane * KPack);
tempn = n % (NRepeat * NLane * NWave);
tempk = k % (KRepeat * KLane * KPack);
int n1 = tempn / (NLane * NWave);
int k1 = tempk / (KRepeat * KPack); // Klane
tempn = tempn % (NLane * NWave);
tempk = tempk % (KRepeat * KPack);
int n2 = tempn / NLane;
int k2 = tempk / KPack; // KRepeat
int n3 = tempn % NLane;
int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * K0 * NRepeat +
n1 * KPack * NLane * KLane * NWave * KRepeat * K0 +
k0 * KPack * NLane * KLane * NWave * KRepeat +
k2 * KPack * NLane * KLane * NWave + n2 * KPack * NLane * KLane +
k1 * KPack * NLane + n3 * KPack + k3;
#if 0
int k1 = tempk / (KLane * KPack); //KRepeat
int n1 = tempn / (NLane * NWave); //NRepeat
tempn = tempn % (NLane * NWave);
tempk = tempk % (KLane * KPack);
int n2 = tempn / NLane; // NWave
int k2 = tempk / KPack; // KLane
int n3 = tempn % NLane; // NLane
int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * NRepeat * KRepeat * K0 +
k0 * KPack * NLane * KLane * NWave * NRepeat * KRepeat +
k1 * KPack * NLane * KLane * NWave * NRepeat +
n1 * KPack * NLane * KLane * NWave +
n2 * KPack * NLane * KLane +
k2 * KPack * NLane +
n3 * KPack +
k3;
#endif
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = k0 * KPack * NLane * KLane * N0 + n0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K + k];
}
}
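// --- Illustrative sketch (not part of this commit): a standalone round-trip check of the
// --- new K0 N0 KLane NLane KPack mapping above. The sizes in main are examples only and
// --- assume N % NXdl == 0 and K % ((64 / NXdl) * KPack) == 0.
#include <cassert>
#include <vector>
int preshuffledIndex(int n, int k, int N, int NXdl)
{
    // same index math as preShuffleBuffer above, factored out for checking
    const int KPack = 16;
    const int NLane = NXdl;
    const int KLane = 64 / NLane;
    const int N0    = N / NLane;
    const int n0 = n / NLane, n1 = n % NLane;
    const int k0 = k / (KLane * KPack);
    const int k1 = (k % (KLane * KPack)) / KPack;
    const int k2 = k % KPack;
    return ((k0 * N0 + n0) * KLane + k1) * NLane * KPack + n1 * KPack + k2;
}
int main()
{
    const int NXdl = 16, N = 32, K = 128; // illustrative sizes only
    std::vector<int> hits(N * K, 0);
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
            hits[preshuffledIndex(n, k, N, NXdl)]++;
    for(int v : hits)
        assert(v == 1); // the mapping is a permutation of all N * K elements
    return 0;
}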
......@@ -179,13 +148,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
256, 256, 128,
32, 256, 256,
16, 16,
32, 32,
8, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
32, 32,
1, 2,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
......@@ -319,18 +288,9 @@ int main(int argc, char* argv[])
// do GEMM
auto device_op = DeviceOpInstance{};
auto preshuffle_params = device_op.GetPreShuffleParameters();
preShuffleBuffer(b0_k_n.mData.data(),
b0_preshuffled.mData.data(),
N,
K,
preshuffle_params[0],
preshuffle_params[1],
preshuffle_params[2],
preshuffle_params[3],
preshuffle_params[4],
preshuffle_params[5]);
int NPerXdl = device_op.GetPreShuffleParameters();
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
......
......@@ -118,12 +118,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
MRepeat,
NRepeat,
KPack>;
using Base::A_K1;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
using Base::a_block_desc_m0_m1_m2_k;
using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
......@@ -136,8 +138,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::AMmaKStride;
using Base::BMmaKStride;
......@@ -145,6 +145,30 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
{
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
make_tuple(
make_pass_through_transform(Number<M0>{}),
make_pass_through_transform(Number<M1>{}),
make_pass_through_transform(Number<M2>{}),
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
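// Illustrative compile-time check (not part of this commit; ex_* values are hypothetical
// samples): the unmerge above splits the flat k coordinate into digits
// (K0, K1, K2) = (KRepeat, 64 / NPerXDL, KPack), and recombining the digits must
// reproduce the flat index.
static constexpr int ex_KPack = 16, ex_K1 = 64 / 16, ex_k = 100;
static_assert((ex_k / (ex_KPack * ex_K1) * ex_K1 + (ex_k / ex_KPack) % ex_K1) * ex_KPack +
                      ex_k % ex_KPack ==
                  ex_k,
              "K unmerge round-trips to the flat index");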
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
......@@ -275,11 +299,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf0,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
......@@ -305,12 +329,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy
.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
.template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
});
using mfma_input_type =
......@@ -332,11 +356,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf1,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
......@@ -357,15 +381,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
// b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0,
// k0, 0>,
b_blockwise_copy
.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
.template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
});
using mfma_input_type =
......@@ -387,11 +409,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf0,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
......@@ -411,12 +433,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
});
using mfma_input_type =
......@@ -436,11 +458,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf1,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
});
});
......@@ -452,12 +474,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
});
using mfma_input_type =
......@@ -483,12 +506,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<n0, k0, 0>,
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<k0, n0, 0, 0>,
Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
});
using mfma_input_type =
......@@ -507,9 +530,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_desc_;
// A thread buffer dims: MRepeat MWave MLane KRepeat KLane KPack
// iteration order: KRepeat -> MRepeat -> MWave -> KLane -> MLane -> KPack
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
using Base::c_thread_desc_;
};
......
......@@ -113,6 +113,17 @@ struct BlockwiseGemmXdlops_pipeline_base
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
}
__device__ static auto CalculateAThreadOriginDataIndex6D()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], 0, xdlops_a_idx[I0], 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
......
......@@ -138,7 +138,7 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual std::array<int, 6> GetPreShuffleParameters() = 0;
virtual int GetPreShuffleParameters() = 0;
};
} // namespace device
......
......@@ -139,16 +139,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
using Argument = typename GridwiseGemm::Argument;
std::array<int, 6> GetPreShuffleParameters() override
int GetPreShuffleParameters() override
{
std::array<int, 6> preshuffle_params{NXdlPerWave,
GridwiseGemm::KRepeat,
GridwiseGemm::NWave,
GridwiseGemm::KLane,
GridwiseGemm::NLane,
GridwiseGemm::KPack};
return preshuffle_params;
return NPerXDL;
}
// Invoker
......@@ -240,8 +233,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
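// Illustrative aside (not part of this commit): the heuristic keys off the per-thread
// C tile, MPerBlock * NPerBlock / BlockSize. The 256-thread 32x256 instance in the
// example above gives 32 * 256 / 256 = 32 <= 128, so two waves per CU are requested;
// a 256x256 tile on 256 threads would give 256 > 128 and fall back to 1.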
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
......@@ -307,21 +308,49 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
......
......@@ -141,8 +141,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NLane * NWave * KLane == BlockSize);
// static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week");
static_assert(NWave * warpSize == BlockSize);
static constexpr auto MakeDsGridPointer()
{
......@@ -176,7 +175,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto CalculateBN0Shuffled(index_t N)
{
return math::integer_divide_ceil(N, NLane * NWave);
return math::integer_divide_ceil(N, NLane);
}
__host__ __device__ static auto CalculateBK0Shuffled(index_t K)
{
......@@ -322,9 +321,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<BlockSize * KPack>{};
return make_naive_tensor_descriptor(make_tuple(N0, K0, NkSwizzleNumber),
make_tuple(K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
return make_naive_tensor_descriptor(
    make_tuple(K0, N0 / NWave, NWave, NkSwizzleNumber),
    make_tuple(N0 * NkSwizzleNumber, NWave * NkSwizzleNumber, NkSwizzleNumber, I1));
}
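// Illustrative stride check (not part of this commit; sizes are example values): the
// stride tuple above describes a packed row-major (K0, N0 / NWave, NWave,
// warpSize * KPack) layout. With K0 = 4, N0 = 8, NWave = 2, X = warpSize * KPack = 1024,
// the last element (3, 3, 1, 1023) must land at the end of a dense buffer of K0 * N0 * X
// elements.
static_assert(3 * 8 * 1024 + 3 * 2 * 1024 + 1 * 1024 + 1023 == 4 * 8 * 1024 - 1,
              "example: the preshuffled B layout covers its buffer densely");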
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
......@@ -649,8 +648,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
// KPack * NLane * KLane * NWave * KRepeat * K0* NRepeat * N0
b_k_split_offset = k_id * karg.KRead * NLane * NWave;
// layout: KPack * NLane * KLane * N0 * K0; one K0 block holds N * KLane * KPack elements,
// so skipping k_id * KRead rows of K advances k_id * KRead * N elements in total
b_k_split_offset = k_id * karg.KRead * karg.N;
}
if(k_id < karg.KBatch - 1)
......@@ -1159,6 +1158,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
// check gridwise gemm pipeline
#if 0
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
......@@ -1168,7 +1168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
return false;
}
}
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
......@@ -1252,6 +1252,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
......@@ -1294,7 +1295,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// dummy 1x1x1x1 descriptor: preshuffled B is consumed from the copy's thread scratch
// (see GetSrcThreadScratchIdx in the pipeline), so no real LDS layout is needed here
constexpr auto b_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1));
// A matrix blockwise copy
auto a_blockwise_copy =
......@@ -1335,17 +1338,17 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<NXdlPerWave, KRepeat, KPack * BlockSize>,
Sequence<1, 1, BlockSize>, // BThreadClusterLengths,
Sequence<0, 1, 2>, // BBlockTransferClusterArrangeOrder,
Sequence<KRepeat, NXdlPerWave, NWave, KPack * warpSize>,
Sequence<1, 1, NWave, warpSize>, // BThreadClusterLengths,
Sequence<0, 1, 2, 3>, // BBlockTransferClusterArrangeOrder,
BDataType,
LDSTypeB,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
Sequence<0, 1, 2, 3>, // BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
3,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
......@@ -1353,10 +1356,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
BThreadTransferSrcResetCoordinateAfterRun,
true,
2>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid, 0, 0),
make_multi_index(0, n_block_data_idx_on_grid, 0, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// LDS allocation for A and B: be careful of alignment
......@@ -1367,7 +1370,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, KRepeat, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KRepeat, 0, 0, 0);
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
......
......@@ -29,40 +29,31 @@ void preShuffleBuffer(const InOutDataType* src,
InOutDataType* dst,
int N,
int K,
int NRepeat,
int KRepeat,
int NWave,
int KLane,
int NLane,
int KPack)
int NXdl)
{
int K0 = K / (KRepeat * KLane * KPack);
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRepeat KLane KPack; move KLane inner to make
// all lanes contiguous. N -> N0 NRepeat NWave NLane // todo: is NRepeat outer or inner? now it's 1
int tempn, tempk;
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int N0 = N / NLane;
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> K0 N0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (NRepeat * NLane * NWave);
int k0 = k / (KRepeat * KLane * KPack);
tempn = n % (NRepeat * NLane * NWave);
tempk = k % (KRepeat * KLane * KPack);
int n1 = tempn / (NLane * NWave);
int k1 = tempk / (KRepeat * KPack); // Klane
tempn = tempn % (NLane * NWave);
tempk = tempk % (KRepeat * KPack);
int n2 = tempn / NLane;
int k2 = tempk / KPack; // KRepeat
int n3 = tempn % NLane;
int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * K0 * NRepeat +
n1 * KPack * NLane * KLane * NWave * KRepeat * K0 +
k0 * KPack * NLane * KLane * NWave * KRepeat +
k2 * KPack * NLane * KLane * NWave + n2 * KPack * NLane * KLane +
k1 * KPack * NLane + n3 * KPack + k3;
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = k0 * KPack * NLane * KLane * N0 + n0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K + k];
}
......@@ -116,7 +107,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_preshuffled(
Tensor<BDataType> b_preshuffled_mfma16(
f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size
Tensor<BDataType> b_preshuffled_mfma32(
f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
......@@ -154,6 +147,9 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
preShuffleBuffer(b_k_n.mData.data(), b_preshuffled_mfma16.mData.data(), N, K, 16);
preShuffleBuffer(b_k_n.mData.data(), b_preshuffled_mfma32.mData.data(), N, K, 32);
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
......@@ -166,12 +162,16 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_mfma16(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_mfma32(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf_mfma16.ToDevice(b_preshuffled_mfma16.mData.data());
b_device_buf_mfma32.ToDevice(b_preshuffled_mfma32.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
......@@ -234,20 +234,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
auto preshuffle_params = op_ptr->GetPreShuffleParameters();
preShuffleBuffer(b_k_n.mData.data(),
b_preshuffled.mData.data(),
N,
K,
preshuffle_params[0],
preshuffle_params[1],
preshuffle_params[2],
preshuffle_params[3],
preshuffle_params[4],
preshuffle_params[5]);
b_device_buf.ToDevice(b_preshuffled.mData.data());
int NPerXdl = op_ptr->GetPreShuffleParameters();
std::vector<int> kbatch_list = {1, 2, 4, 8};
......@@ -262,7 +249,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(NPerXdl == 16 ? b_device_buf_mfma16.GetDeviceBuffer()
: b_device_buf_mfma32.GetDeviceBuffer()),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
......@@ -298,8 +286,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
is_same_v<EDataType, int8_t>))
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
double rtol = 1e-3;
double atol = 5e-2;
pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
}
......
......@@ -50,7 +50,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
# endif()
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_weight_preshuffle.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
# endif()
......@@ -137,7 +137,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_weight_preshuffle_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
# endif()
......