Commit 22fe522d authored by aska-0096

optimize software pipeline

parent 9dd74e0d
@@ -103,10 +103,10 @@ void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
int NLane = NXdl;
int KLane = 64 / NLane;
- int N0 = N / NLane;
+ int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
- // N, K -> K0 N0 KLane NLane KPack
+ // N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
@@ -120,7 +120,7 @@ void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
int k1 = tempk / KPack;
int k2 = tempk % KPack;
- int outputIndex = k0 * KPack * NLane * KLane * N0 + n0 * KPack * NLane * KLane +
+ int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K + k];
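The two hunks above reorder the preshuffled B buffer from a K0-outermost to an N0-outermost block layout. A minimal standalone sketch of the new index mapping, assuming the elided loop body decomposes the coordinates as n0 = n / NLane, n1 = n % NLane, k0 = k / (KLane * KPack), tempk = k % (KLane * KPack) (consistent with the layout comments); KPack = 16 as in the surrounding code, and the helper name is illustrative:

// Hedged sketch: destination index of element (n, k) under the new
// N0 K0 KLane NLane KPack layout.
int shuffledIndex(int n, int k, int K, int NXdl)
{
    const int KPack = 16;
    const int NLane = NXdl;       // lanes along N within one XDL tile
    const int KLane = 64 / NLane; // remaining lanes of the 64-lane wave cover K
    const int K0    = K / (KLane * KPack);

    const int n0 = n / NLane, n1 = n % NLane;
    const int k0 = k / (KLane * KPack), tempk = k % (KLane * KPack);
    const int k1 = tempk / KPack, k2 = tempk % KPack;

    // N0 is now the outermost dimension, so all K0 blocks of one N0 block are contiguous.
    return n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
           k1 * KPack * NLane + n1 * KPack + k2;
}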
@@ -148,14 +148,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
- 32, 256, 256,
+ 32, 128, 256,
16, 16,
32, 32,
- 1, 2,
+ 1, 1,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 16, 1, 16>, S<8, 8, 1>,
- ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
+ ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
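For orientation, a hedged reading of the positional values in the updated instance, using the parameter ordering implied by the commented-out reference instances above (an assumption, not documentation from the header):

// BlockSize   = 256
// MPerBlock   = 32, NPerBlock = 128, KPerBlock = 256    (previously 32 x 256 x 256)
// AK1 = 16, BK1 = 16
// MPerXDL = 32, NPerXDL = 32
// MXdlPerWave = 1, NXdlPerWave = 1                      (previously 1, 2)
// scheduler = Intrawave, pipeline = BlockGemmPipelineVersion::v1 (previously v3)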
@@ -366,7 +366,10 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
- return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) ? 0 : 1;
+ return ck::utils::check_err(
+     e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
+     ? 0
+     : 1;
}
return 0;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
namespace ck {
enum struct BlockGemmPipelineVersion
{
v1, // Single LDS buffer
v2, // Double LDS buffer
};
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmBPreshufflePipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v2<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
}
}
} // namespace ck
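The selector above returns a different class-template instantiation from each if constexpr branch, and callers recover the chosen type with decltype. A minimal standalone illustration of that pattern (plain C++ with made-up names, not CK's actual types):

#include <type_traits>

enum struct Version { v1, v2 };

struct PipelineV1 { static constexpr int lds_buffers = 1; };
struct PipelineV2 { static constexpr int lds_buffers = 2; };

template <Version V>
constexpr auto pipeline_selector()
{
    // Only the taken branch is instantiated, so each V yields exactly one return type.
    if constexpr(V == Version::v1) { return PipelineV1{}; }
    else if constexpr(V == Version::v2) { return PipelineV2{}; }
}

// The caller materializes the selected pipeline type from the selector's return type.
using SelectedPipeline = std::remove_cv_t<decltype(pipeline_selector<Version::v2>())>;
static_assert(SelectedPipeline::lds_buffers == 2);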
@@ -139,10 +139,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
using Argument = typename GridwiseGemm::Argument;
- int GetPreShuffleParameters() override
- {
-     return NPerXDL;
- }
+ int GetPreShuffleParameters() override { return NPerXDL; }
// Invoker
struct Invoker : public BaseInvoker
@@ -233,16 +230,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
}
};
- constexpr index_t minimum_occupancy = []() {
-     if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
-     {
-         return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
-     }
-     else
-     {
-         return 1;
-     }
- }();
+ constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
+                                  4 * (1 + GridwiseGemm::NWave);
+ constexpr auto estimated_reg_b =
+     NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
+ constexpr auto estimated_reg_c =
+     MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
+ constexpr auto estimated_reg_total =
+     estimated_reg_a + estimated_reg_b + estimated_reg_c;
+
+ constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
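A rough worked example of the new register-pressure heuristic for the 32x128x256, BlockSize = 256, FP8 instance used earlier in this commit, assuming GemmAccDataType is float and GridwiseGemm::NWave = NPerBlock / (NXdlPerWave * NPerXDL) = 4 (both are assumptions, not stated in this diff); the division by 4 reads as converting bytes per thread into 32-bit registers:

// estimated_reg_a     = 32 * 256 * 1 / 256 / 4 * (1 + 4) = 40
// estimated_reg_b     = 128 * 256 * 1 / 256 / 4 * 2      = 64
// estimated_reg_c     = 32 * 128 * 4 / 256 / 4           = 16
// estimated_reg_total = 120 < 256  ->  minimum_occupancy = 2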
@@ -250,7 +247,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
if(has_main_k_block_loop)
{
// Tail number always full
- if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+ if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
@@ -299,16 +296,72 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
- throw std::runtime_error("todo: only v3 support now");
+ throw std::runtime_error("todo: only v1 & v2 support now");
}
}
#if 0
else
{
- if(arg.KBatch > 1)
+ if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
- if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+ if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
@@ -328,10 +381,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
GridwiseGemm,
@@ -351,8 +404,67 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
TailNumber::Even>;
Run(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
throw std::runtime_error("todo: only v3 support now");
} }
}
#endif
return ave_time;
}
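The repeated Odd/Even blocks in this invoker are a runtime-to-compile-time dispatch: a runtime tail classification selects one of two kernel instantiations whose tail handling is fixed as a template argument. A generic standalone illustration of the pattern (names are illustrative, not CK's API):

enum struct TailNum { Odd, Even };

// Kernel template whose tail handling is baked in at compile time.
template <TailNum Tail>
void launch_gemm_kernel()
{
    // ... body specialized on Tail ...
}

// The runtime branch picks which compile-time instantiation to launch.
inline void dispatch_by_tail(int k_loop_count)
{
    if(k_loop_count % 2 == 1) { launch_gemm_kernel<TailNum::Odd>(); }
    else                      { launch_gemm_kernel<TailNum::Even>(); }
}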
@@ -490,11 +602,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
- {BlockGemmPipelineVersion::v1, "v1"},
- {BlockGemmPipelineVersion::v2, "v2"},
- {BlockGemmPipelineVersion::v3, "v3"},
- {BlockGemmPipelineVersion::v4, "v4"},
- {BlockGemmPipelineVersion::v5, "v5"}};
+ {BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemmXdlUniversal"
...
@@ -25,20 +25,16 @@ namespace ck {
namespace profiler {
template <typename InOutDataType>
- void preShuffleBuffer(const InOutDataType* src,
-                       InOutDataType* dst,
-                       int N,
-                       int K,
-                       int NXdl)
+ void preShuffleBuffer(const InOutDataType* src, InOutDataType* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
- int N0 = N / NLane;
+ int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
- // N, K -> K0 N0 KLane NLane KPack
+ // N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
@@ -52,7 +48,7 @@ void preShuffleBuffer(const InOutDataType* src,
int k1 = tempk / KPack;
int k2 = tempk % KPack;
- int outputIndex = k0 * KPack * NLane * KLane * N0 + n0 * KPack * NLane * KLane +
+ int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K + k];
...