Commit ed2d5e40 authored by Jing Zhang's avatar Jing Zhang
Browse files

enabled splitk_gemm_multi_d

parent 0fff2a66
...@@ -69,18 +69,19 @@ using AElementOp = PassThrough; ...@@ -69,18 +69,19 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply; using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off // clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| < Row, Col, DsLayout, ELayout,
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| AElementOp, BElementOp, CDEElementOp, GemmSpec,
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>| 256, 256, 128, 64,
///###### RRR 16, 16, 32, 32, 4, 2,
///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
///###### RCR S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// clang-format on // clang-format on
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -99,6 +100,8 @@ int main(int argc, char* argv[]) ...@@ -99,6 +100,8 @@ int main(int argc, char* argv[])
ck::index_t StrideD = 0; ck::index_t StrideD = 0;
ck::index_t StrideE = N; ck::index_t StrideE = N;
ck::index_t KBatch = 1;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -109,7 +112,7 @@ int main(int argc, char* argv[]) ...@@ -109,7 +112,7 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 11) else if(argc == 12)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -123,13 +126,15 @@ int main(int argc, char* argv[]) ...@@ -123,13 +126,15 @@ int main(int argc, char* argv[])
StrideB = std::stoi(argv[8]); StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]); StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]); StrideE = std::stoi(argv[10]);
KBatch = std::stoi(argv[11]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n");
exit(0); exit(0);
} }
...@@ -212,6 +217,7 @@ int main(int argc, char* argv[]) ...@@ -212,6 +217,7 @@ int main(int argc, char* argv[])
StrideB, StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0}, std::array<ck::index_t, NumDTensor>{I0, I0},
StrideE, StrideE,
KBatch,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
...@@ -236,10 +242,12 @@ int main(int argc, char* argv[]) ...@@ -236,10 +242,12 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification) if(do_verification)
{ {
invoker.Run(argument, StreamConfig{nullptr, false});
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
Tensor<CShuffleDataType> c_m_n({M, N}); Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
......
...@@ -60,7 +60,7 @@ static constexpr ck::index_t Scale_Block_N = 128; ...@@ -60,7 +60,7 @@ static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128; static constexpr ck::index_t Scale_Block_K = 128;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
// clang-format off // clang-format off
<Row, Col, DsLayout, ELayout, <Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
......
...@@ -53,6 +53,51 @@ struct DeviceGemmMultipleD : public BaseOperator ...@@ -53,6 +53,51 @@ struct DeviceGemmMultipleD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -69,7 +69,7 @@ template <typename ALayout, ...@@ -69,7 +69,7 @@ template <typename ALayout,
typename ComputeTypeB = ComputeTypeA, typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA, typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB> typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
CLayout, CLayout,
...@@ -192,15 +192,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -192,15 +192,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// rotating mem // rotating mem
rotating_mem.Next(); rotating_mem.Next();
// clear c mem // clear c mem
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg_.KBatch > 1) if(arg_.KBatch > 1)
hipGetErrorString( hipGetErrorString(
hipMemsetAsync(arg_.p_c_grid, hipMemsetAsync(arg_.p_c_grid,
0, 0,
arg_.M * arg_.N * sizeof(CDataType), arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_)); stream_config.stream_id_));
}
}; };
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>( ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
...@@ -234,6 +231,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -234,6 +231,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...@@ -246,6 +253,119 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -246,6 +253,119 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be One to Seven // Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{ {
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{ {
...@@ -361,7 +481,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -361,7 +481,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be Odd or Even // Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{ {
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
...@@ -387,6 +530,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -387,6 +530,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
} }
else else
{ {
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
...@@ -416,6 +583,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -416,6 +583,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number always 1 // Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...@@ -479,6 +656,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -479,6 +656,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB, index_t StrideB,
std::array<index_t, NumDTensor> StrideDs, std::array<index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
...@@ -494,7 +672,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -494,7 +672,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB, StrideB,
StrideDs, StrideDs,
StrideC, StrideC,
1, KBatch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op}; c_element_op};
...@@ -514,6 +692,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -514,6 +692,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB, index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs, std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC, index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op) override
...@@ -529,7 +708,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -529,7 +708,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB, StrideB,
StrideDs, StrideDs,
StrideC, StrideC,
1, KBatch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
......
...@@ -558,7 +558,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -558,7 +558,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Tail number always 1 // Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
const auto kernel = const auto kernel =
......
...@@ -38,8 +38,7 @@ __global__ void ...@@ -38,8 +38,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
...@@ -56,7 +55,7 @@ __global__ void ...@@ -56,7 +55,7 @@ __global__ void
karg.c_element_op); karg.c_element_op);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx9__))
} }
template <typename GridwiseGemm, template <typename GridwiseGemm,
...@@ -71,8 +70,7 @@ __global__ void ...@@ -71,8 +70,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// Pass two lds pointer is the key to tell compiler that ds_read/write // Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy // operate on different lds chunk at same time without order dependecy
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -93,7 +91,7 @@ __global__ void ...@@ -93,7 +91,7 @@ __global__ void
karg.c_element_op); karg.c_element_op);
#else #else
ignore = karg; ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx9__))
} }
template <typename ALayout, template <typename ALayout,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment