Commit 2732d06c authored by rocking's avatar rocking
Browse files

Merge commit '75891161' into gemm_layernorm_welford

parents dd0255ba 75891161
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_reduce_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// DataType
using ADataType = BF16;
using BDataType = BF16;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using ReduceAccDataType = F32;
using R0DataType = F32;
using RsDataType = ck::Tuple<R0DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough>;
using RsElementOp = ck::Tuple<PassThrough>;
// ReduceOp
using RsThreadReduceOp = ck::Tuple<ck::reduce::Max>;
using RsGlobalReduceOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
<ALayout, // ALayout
BLayout, // BLayout
ELayout, // ELayout
ADataType, // ADataType
BDataType, // BDataType
GemmAccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType
EDataType, // EDataType
ReduceAccDataType, // ReduceAccDataType
RsDataType, // RsDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CDEElementOp, // CDE ElementwiseOperation
QsElementOp, // Qs Elementwise Operation
RsElementOp, // Rs Elementwise Operation
RsThreadReduceOp, // Thread Reduce Operation
RsGlobalReduceOp, // Global Reduce Operation
GemmDefault, // GEMM Specialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
8, // ABlockTransfer SrcScalarPerVector
8, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
8, // BBlockTransfer SrcScalarPerVector
8, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
ReduceAccDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1152;
ck::index_t K = 256;
ck::index_t StrideA = 256;
ck::index_t StrideB = 256;
ck::index_t StrideE = 1152;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
else
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< " arg3: Measure kernel execution time (1=ON, 0=Off)\n"
<< " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"
<< std::endl;
exit(EXIT_SUCCESS);
}
return run_gemm_reduce_max_xdl<ADataType,
BDataType,
EDataType,
R0DataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp,
QsElementOp,
RsElementOp,
RsThreadReduceOp,
ReduceAccDataType,
DeviceOpInstance,
ReferenceGemmInstance>(
M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel);
}
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include "gemm_reduce_xdl_common.hpp"
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType // DataType
using ADataType = F16; using ADataType = F16;
...@@ -45,7 +24,6 @@ using BLayout = Col; ...@@ -45,7 +24,6 @@ using BLayout = Col;
using ELayout = Row; using ELayout = Row;
// Elementwise op // Elementwise op
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CDEElementOp = PassThrough; using CDEElementOp = PassThrough;
...@@ -54,7 +32,6 @@ using RsElementOp = ck::Tuple<PassThrough>; ...@@ -54,7 +32,6 @@ using RsElementOp = ck::Tuple<PassThrough>;
// ReduceOp // ReduceOp
using RsThreadReduceOp = ck::Tuple<ck::reduce::Max>; using RsThreadReduceOp = ck::Tuple<ck::reduce::Max>;
using RsGlobalReduceOp = using RsGlobalReduceOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>; ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
...@@ -62,56 +39,72 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -62,56 +39,72 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer| <ALayout, // ALayout
//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector| BLayout, // BLayout
//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| ELayout, // ELayout
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | ADataType, // ADataType
< ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; BDataType, // BDataType
GemmAccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType
EDataType, // EDataType
ReduceAccDataType, // ReduceAccDataType
RsDataType, // RsDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CDEElementOp, // CDE ElementwiseOperation
QsElementOp, // Qs Elementwise Operation
RsElementOp, // Rs Elementwise Operation
RsThreadReduceOp, // Thread Reduce Operation
RsGlobalReduceOp, // Global Reduce Operation
GemmDefault, // GEMM Specialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
8, // ABlockTransfer SrcScalarPerVector
8, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
8, // BBlockTransfer SrcScalarPerVector
8, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
EDataType, ReduceAccDataType,
GemmAccDataType, GemmAccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDEElementOp>; CDEElementOp>;
template <typename ADataType, typename BDataType, typename EDataType, typename R0DataType> int main(int argc, char* argv[])
void DumpPerf(float ave_time, int M, int N, int K)
{ {
std::size_t flop = std::size_t(2) * M * N * K; bool do_verification = true;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + int init_method = 1;
sizeof(EDataType) * M * N + sizeof(R0DataType) * M; bool time_kernel = true;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec
<< " GB/s, " << std::endl;
}
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
int main() // GEMM shape
{
ck::index_t M = 1024; ck::index_t M = 1024;
ck::index_t N = 1024; ck::index_t N = 1024;
ck::index_t K = 1024; ck::index_t K = 1024;
...@@ -120,108 +113,55 @@ int main() ...@@ -120,108 +113,55 @@ int main()
ck::index_t StrideB = 1024; ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024; ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); if(argc == 1)
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_m(f_host_tensor_descriptor1d(M, 1));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{};
// Prepare GEMM, max
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
{r0_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!device_op.IsSupportedArgument(argument))
{ {
throw std::runtime_error("wrong! this device_op instance does not support this problem"); // do nothing
} }
else if(argc == 4)
// [CAUSION]: launch_and_time_kernel will not initialize D.
// If we evaluate kernel multiple time but without initialize D. Verification will fail
r0_device_buf.SetValue(ck::NumericLimits<R0DataType>::Lowest());
invoker.Run(argument, StreamConfig{nullptr, false});
bool do_verification = true;
bool pass = true;
if(do_verification)
{
auto I0 = ck::Number<0>{};
Tensor<EDataType> e_m_n_host(e_m_n.mDesc);
Tensor<R0DataType> r0_m_host(r0_m.mDesc);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
auto reduce0_op = RsThreadReduceOp{}[I0];
for(int m = 0; m < M; ++m)
{ {
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>(); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
for(int n = 0; n < N; ++n) time_kernel = std::stoi(argv[3]);
{
auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n_host(m, n));
reduce0_op(reduce0_acc, e_val);
};
r0_m_host(m) = ck::type_convert<R0DataType>(reduce0_acc);
} }
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
e_device_buf.FromDevice(e_m_n.mData.data()); M = std::stoi(argv[4]);
r0_device_buf.FromDevice(r0_m.mData.data()); N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
pass = ck::utils::check_err( StrideA = std::stoi(argv[7]);
e_m_n.mData, e_m_n_host.mData, "Error: Incorrect results e", 1e-2, 1e-2); StrideB = std::stoi(argv[8]);
pass &= ck::utils::check_err( StrideE = std::stoi(argv[9]);
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
} }
else
bool time_kernel = true;
if(time_kernel)
{ {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::cout << "arg1: verification (0=no, 1=yes)\n"
DumpPerf<ADataType, BDataType, EDataType, R0DataType>(ave_time, M, N, K); << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< " arg3: Measure kernel execution time (1=ON, 0=Off)\n"
<< " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"
<< std::endl;
exit(EXIT_SUCCESS);
} }
return pass ? 0 : 1; return run_gemm_reduce_max_xdl<ADataType,
BDataType,
EDataType,
R0DataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp,
QsElementOp,
RsElementOp,
RsThreadReduceOp,
ReduceAccDataType,
DeviceOpInstance,
ReferenceGemmInstance>(
M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel);
} }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_reduce_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// DataType
using ADataType = F32;
using BDataType = F32;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
using ReduceAccDataType = F32;
using R0DataType = F32;
using RsDataType = ck::Tuple<R0DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough>;
using RsElementOp = ck::Tuple<PassThrough>;
// ReduceOp
using RsThreadReduceOp = ck::Tuple<ck::reduce::Max>;
using RsGlobalReduceOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
<ALayout, // ALayout
BLayout, // BLayout
ELayout, // ELayout
ADataType, // ADataType
BDataType, // BDataType
GemmAccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType
EDataType, // EDataType
ReduceAccDataType, // ReduceAccDataType
RsDataType, // RsDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CDEElementOp, // CDE ElementwiseOperation
QsElementOp, // Qs Elementwise Operation
RsElementOp, // Rs Elementwise Operation
RsThreadReduceOp, // Thread Reduce Operation
RsGlobalReduceOp, // Global Reduce Operation
GemmDefault, // GEMM Specialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
16, // KPerBlock
4, // AK1
4, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
4, // ABlockTransfer SrcScalarPerVector
4, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
4, // BBlockTransfer SrcScalarPerVector
4, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
ReduceAccDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: Measure kernel execution time (1=ON, 0=Off)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
exit(0);
}
return run_gemm_reduce_max_xdl<ADataType,
BDataType,
EDataType,
R0DataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp,
QsElementOp,
RsElementOp,
RsThreadReduceOp,
ReduceAccDataType,
DeviceOpInstance,
ReferenceGemmInstance>(
M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_reduce_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
using ADataType = INT4;
using ADataKernelType = INT8;
using BDataType = INT4;
using BDataKernelType = INT8;
using GemmAccDataType = INT32;
using CShuffleDataType = INT32;
using DsDataType = ck::Tuple<>;
using EDataType = INT4;
using EDataKernelType = INT8;
using ReduceAccDataType = INT32;
using R0DataType = INT32;
using RsDataType = ck::Tuple<R0DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough>;
using RsElementOp = ck::Tuple<PassThrough>;
// ReduceOp
using RsThreadReduceOp = ck::Tuple<ck::reduce::Max>;
using RsGlobalReduceOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
<ALayout, // ALayout
BLayout, // BLayout
ELayout, // ELayout
ADataKernelType, // ADataType
BDataKernelType, // BDataType
GemmAccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType
EDataKernelType, // EDataType
ReduceAccDataType, // ReduceAccDataType
RsDataType, // RsDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CDEElementOp, // CDE ElementwiseOperation
QsElementOp, // Qs Elementwise Operation
RsElementOp, // Rs Elementwise Operation
RsThreadReduceOp, // Thread Reduce Operation
RsGlobalReduceOp, // Global Reduce Operation
GemmDefault, // GEMM Specialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
16, // ABlockTransfer SrcScalarPerVector
16, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
16, // BBlockTransfer SrcScalarPerVector
16, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
ReduceAccDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1152;
ck::index_t K = 256;
ck::index_t StrideA = 256;
ck::index_t StrideB = 256;
ck::index_t StrideE = 1152;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
else
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< " arg3: Measure kernel execution time (1=ON, 0=Off)\n"
<< " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"
<< std::endl;
exit(EXIT_SUCCESS);
}
return run_gemm_reduce_max_xdl<ADataType,
BDataType,
EDataType,
R0DataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp,
QsElementOp,
RsElementOp,
RsThreadReduceOp,
ReduceAccDataType,
DeviceOpInstance,
ReferenceGemmInstance,
ADataKernelType,
BDataKernelType,
EDataKernelType>(
M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_reduce_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
using ADataType = INT8;
using BDataType = INT8;
using GemmAccDataType = INT32;
using CShuffleDataType = INT32;
using DsDataType = ck::Tuple<>;
using EDataType = INT8;
using ReduceAccDataType = INT32;
using R0DataType = INT32;
using RsDataType = ck::Tuple<R0DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough>;
using RsElementOp = ck::Tuple<PassThrough>;
// ReduceOp
using RsThreadReduceOp = ck::Tuple<ck::reduce::Max>;
using RsGlobalReduceOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
<ALayout, // ALayout
BLayout, // BLayout
ELayout, // ELayout
ADataType, // ADataType
BDataType, // BDataType
GemmAccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType
EDataType, // EDataType
ReduceAccDataType, // ReduceAccDataType
RsDataType, // RsDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CDEElementOp, // CDE ElementwiseOperation
QsElementOp, // Qs Elementwise Operation
RsElementOp, // Rs Elementwise Operation
RsThreadReduceOp, // Thread Reduce Operation
RsGlobalReduceOp, // Global Reduce Operation
GemmDefault, // GEMM Specialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
16, // ABlockTransfer SrcScalarPerVector
16, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
16, // BBlockTransfer SrcScalarPerVector
16, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
ReduceAccDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1152;
ck::index_t K = 512;
ck::index_t StrideA = 512;
ck::index_t StrideB = 512;
ck::index_t StrideE = 1152;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
else
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< " arg3: Measure kernel execution time (1=ON, 0=Off)\n"
<< " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"
<< std::endl;
exit(EXIT_SUCCESS);
}
return run_gemm_reduce_max_xdl<ADataType,
BDataType,
EDataType,
R0DataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp,
QsElementOp,
RsElementOp,
RsThreadReduceOp,
ReduceAccDataType,
DeviceOpInstance,
ReferenceGemmInstance>(
M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_reduce_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// DataType
using ADataType = BF16;
using BDataType = BF16;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using ReduceAccDataType = F32;
using R0DataType = F32;
using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using Square = ck::tensor_operation::element_wise::UnarySquare;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough, Square>;
using RsElementOp = ck::Tuple<Div, Div>;
// ReduceOp
using R0ThreadReduceOp = ck::reduce::Add;
using R1ThreadReduceOp = ck::reduce::Add;
using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
<ALayout, // ALayout
BLayout, // BLayout
ELayout, // ELayout
ADataType, // ADataType
BDataType, // BDataType
GemmAccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType
EDataType, // EDataType
ReduceAccDataType, // ReduceAccDataType
RsDataType, // RsDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CDEElementOp, // CDE ElementwiseOperation
QsElementOp, // Qs Elementwise Operation
RsElementOp, // Rs Elementwise Operation
RsThreadReduceOp, // Thread Reduce Operation
RsGlobalReduceOp, // Global Reduce Operation
GemmDefault, // GEMM Specialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
8, // ABlockTransfer SrcScalarPerVector
8, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
8, // BBlockTransfer SrcScalarPerVector
8, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
ReduceAccDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1152;
ck::index_t K = 192;
ck::index_t StrideA = 192;
ck::index_t StrideB = 192;
ck::index_t StrideE = 1152;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
else
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< " arg3: Measure kernel execution time (1=ON, 0=Off)\n"
<< " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"
<< std::endl;
exit(EXIT_SUCCESS);
}
return !run_gemm_reduce_mean_meansquare_xdl<ADataType,
BDataType,
EDataType,
R0DataType,
R1DataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp,
QsElementOp,
RsElementOp,
RsThreadReduceOp,
ReduceAccDataType,
DeviceOpInstance,
ReferenceGemmInstance>(
M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel);
}
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include "gemm_reduce_xdl_common.hpp"
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType // DataType
using ADataType = F16; using ADataType = F16;
...@@ -45,7 +25,6 @@ using BLayout = Col; ...@@ -45,7 +25,6 @@ using BLayout = Col;
using ELayout = Row; using ELayout = Row;
// Elementwise op // Elementwise op
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare; using Square = ck::tensor_operation::element_wise::UnarySquare;
using Div = ck::tensor_operation::element_wise::UnaryDivide; using Div = ck::tensor_operation::element_wise::UnaryDivide;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -67,61 +46,71 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -67,61 +46,71 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer| <ALayout, // ALayout
//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector| BLayout, // BLayout
//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| ELayout, // ELayout
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | ADataType, // ADataType
< ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; BDataType, // BDataType
GemmAccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType
EDataType, // EDataType
ReduceAccDataType, // ReduceAccDataType
RsDataType, // RsDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CDEElementOp, // CDE ElementwiseOperation
QsElementOp, // Qs Elementwise Operation
RsElementOp, // Rs Elementwise Operation
RsThreadReduceOp, // Thread Reduce Operation
RsGlobalReduceOp, // Global Reduce Operation
GemmDefault, // GEMM Specialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
8, // ABlockTransfer SrcScalarPerVector
8, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
8, // BBlockTransfer SrcScalarPerVector
8, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
EDataType, ReduceAccDataType,
GemmAccDataType, GemmAccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CDEElementOp>; CDEElementOp>;
template <typename ADataType, int main(int argc, char* argv[])
typename BDataType,
typename EDataType,
typename R0DataType,
typename R1DataType>
void DumpPerf(float ave_time, int M, int N, int K)
{ {
std::size_t flop = std::size_t(2) * M * N * K; bool do_verification = true;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + int init_method = 1;
sizeof(EDataType) * M * N + sizeof(R0DataType) * M + bool time_kernel = true;
sizeof(R1DataType) * M;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec
<< " GB/s, " << std::endl;
}
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
int main() // GEMM shape
{
ck::index_t M = 1024; ck::index_t M = 1024;
ck::index_t N = 1024; ck::index_t N = 1024;
ck::index_t K = 1024; ck::index_t K = 1024;
...@@ -130,125 +119,56 @@ int main() ...@@ -130,125 +119,56 @@ int main()
ck::index_t StrideB = 1024; ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024; ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); if(argc == 1)
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_m(f_host_tensor_descriptor1d(M, 1));
Tensor<R1DataType> r1_m(f_host_tensor_descriptor1d(M, 1));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize());
DeviceMem r1_device_buf(sizeof(R1DataType) * r1_m.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{N, N};
// Prepare GEMM, mean, mean_square
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
{r0_device_buf.GetDeviceBuffer(), r1_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!device_op.IsSupportedArgument(argument))
{ {
throw std::runtime_error("wrong! this device_op instance does not support this problem"); // do nothing
} }
else if(argc == 4)
// init reducetion buffer to 0
r0_device_buf.SetZero();
r1_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
bool do_verification = true;
bool pass = true;
if(do_verification)
{
auto I0 = ck::Number<0>{};
auto I1 = ck::Number<1>{};
Tensor<EDataType> e_m_n_host(e_m_n.mDesc);
Tensor<R0DataType> r0_m_host(r0_m.mDesc);
Tensor<R1DataType> r1_m_host(r1_m.mDesc);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
auto reduce0_op = R0ThreadReduceOp{};
auto reduce1_op = R1ThreadReduceOp{};
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{ {
ReduceAccDataType square_e_val; do_verification = std::stoi(argv[1]);
auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n_host(m, n)); init_method = std::stoi(argv[2]);
qs_element_op[I1](square_e_val, e_val); time_kernel = std::stoi(argv[3]);
reduce0_op(reduce0_acc, e_val);
reduce1_op(reduce1_acc, square_e_val);
}
rs_element_op[I0](reduce0_acc, reduce0_acc);
rs_element_op[I1](reduce1_acc, reduce1_acc);
r0_m_host(m) = ck::type_convert<R0DataType>(reduce0_acc);
r1_m_host(m) = ck::type_convert<R1DataType>(reduce1_acc);
} }
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
e_device_buf.FromDevice(e_m_n.mData.data()); M = std::stoi(argv[4]);
r0_device_buf.FromDevice(r0_m.mData.data()); N = std::stoi(argv[5]);
r1_device_buf.FromDevice(r1_m.mData.data()); K = std::stoi(argv[6]);
pass = ck::utils::check_err( StrideA = std::stoi(argv[7]);
e_m_n.mData, e_m_n_host.mData, "Error: Incorrect results c", 1e-2, 1e-2); StrideB = std::stoi(argv[8]);
pass &= ck::utils::check_err( StrideE = std::stoi(argv[9]);
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r1_m.mData, r1_m_host.mData, "Error: Incorrect results d1", 1e-2, 1e-2);
} }
else
bool time_kernel = true;
if(time_kernel)
{ {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::cout << "arg1: verification (0=no, 1=yes)\n"
DumpPerf<ADataType, BDataType, EDataType, R0DataType, R1DataType>(ave_time, M, N, K); << " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< " arg3: Measure kernel execution time (1=ON, 0=Off)\n"
<< " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"
<< std::endl;
exit(EXIT_SUCCESS);
} }
return pass ? 0 : 1; return !run_gemm_reduce_mean_meansquare_xdl<ADataType,
BDataType,
EDataType,
R0DataType,
R1DataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp,
QsElementOp,
RsElementOp,
RsThreadReduceOp,
ReduceAccDataType,
DeviceOpInstance,
ReferenceGemmInstance>(
M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel);
} }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_reduce_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
// DataType
using ADataType = F32;
using BDataType = F32;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
using ReduceAccDataType = F32;
using R0DataType = F32;
using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using Square = ck::tensor_operation::element_wise::UnarySquare;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough, Square>;
using RsElementOp = ck::Tuple<Div, Div>;
// ReduceOp
using R0ThreadReduceOp = ck::reduce::Add;
using R1ThreadReduceOp = ck::reduce::Add;
using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
<ALayout, // ALayout
BLayout, // BLayout
ELayout, // ELayout
ADataType, // ADataType
BDataType, // BDataType
GemmAccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType
EDataType, // EDataType
ReduceAccDataType, // ReduceAccDataType
RsDataType, // RsDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CDEElementOp, // CDE ElementwiseOperation
QsElementOp, // Qs Elementwise Operation
RsElementOp, // Rs Elementwise Operation
RsThreadReduceOp, // Thread Reduce Operation
RsGlobalReduceOp, // Global Reduce Operation
GemmDefault, // GEMM Specialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
16, // KPerBlock
4, // AK1
4, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
4, // ABlockTransfer SrcScalarPerVector
4, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
4, // BBlockTransfer SrcScalarPerVector
4, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock
4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock
1>; // RThread DstScalarPerVector _MPerBlock
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
ReduceAccDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideE = std::stoi(argv[9]);
}
else
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< " arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< " arg3: Measure kernel execution time (1=ON, 0=Off)\n"
<< " arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"
<< std::endl;
exit(EXIT_SUCCESS);
}
return !run_gemm_reduce_mean_meansquare_xdl<ADataType,
BDataType,
EDataType,
R0DataType,
R1DataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp,
QsElementOp,
RsElementOp,
RsThreadReduceOp,
ReduceAccDataType,
DeviceOpInstance,
ReferenceGemmInstance>(
M, N, K, StrideA, StrideB, StrideE, do_verification, init_method, time_kernel);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/host_utility/io.hpp"
#include "ck/stream_config.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using F64 = double;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using INT4 = ck::int4_t;
#endif
using INT8 = std::int8_t;
using INT32 = std::int32_t;
template <typename ADataType, typename BDataType, typename EDataType, typename R0DataType>
void DumpGemmReduceMaxPerf(float ave_time, int M, int N, int K)
{
using namespace ck::literals;
std::size_t flop = 2_uz * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(EDataType) * M * N + sizeof(R0DataType) * M;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec
<< " GB/s, " << std::endl;
}
template <typename ADataType,
typename BDataType,
typename EDataType,
typename R0DataType,
typename R1DataType>
void DumpGemmReduceMeanSquareMeanPerf(float ave_time, int M, int N, int K)
{
using namespace ck::literals;
std::size_t flop = 2_uz * M * N * K + M * (3_uz * N + 2_uz);
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(R1DataType) * M;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec
<< " GB/s, " << std::endl;
}
template <typename ADataType,
typename BDataType,
typename EDataType,
typename R0DataType,
typename ALayout,
typename BLayout,
typename ELayout,
typename AElementOp,
typename BElementOp,
typename CDEElementOp,
typename QsElementOp,
typename RsElementOp,
typename RsThreadReduceOp,
typename ReduceAccDataType,
typename DeviceOpInstance,
typename ReferenceGemmInstance,
typename ADataKernelType = ADataType,
typename BDataKernelType = BDataType,
typename EDataKernelType = EDataType>
auto run_gemm_reduce_max_xdl(ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideE,
bool do_verification,
int init_method,
bool time_kernel)
{
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
static_assert(sizeof(ADataType) == sizeof(ADataKernelType));
static_assert(sizeof(BDataType) == sizeof(BDataKernelType));
static_assert(sizeof(EDataType) == sizeof(EDataKernelType));
#endif
using namespace ck::literals;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor({len}, {stride});
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<EDataKernelType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_m(f_host_tensor_descriptor1d(M, 1));
switch(init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
a_m_k.end());
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
b_k_n.end());
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
break;
}
DeviceMem a_device_buf(sizeof(ADataKernelType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataKernelType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataKernelType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{
Tensor<ADataKernelType> a_m_k_converted = a_m_k.template CopyAsType<ADataKernelType>();
Tensor<BDataKernelType> b_k_n_converted = b_k_n.template CopyAsType<BDataKernelType>();
a_device_buf.ToDevice(a_m_k_converted.mData.data());
b_device_buf.ToDevice(b_k_n_converted.mData.data());
}
else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{};
// Prepare GEMM, max
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
{r0_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
// [CAUTION]: launch_and_time_kernel will not initialize D.
// If we evaluate kernel multiple time but without initialize D. Verification will fail
r0_device_buf.SetValue(ck::NumericLimits<R0DataType>::Lowest());
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
if(do_verification)
{
auto I0 = ck::Number<0>{};
Tensor<ReduceAccDataType> e_m_n_host(e_m_n.mDesc);
Tensor<R0DataType> r0_m_host(r0_m.mDesc);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
auto reduce0_op = RsThreadReduceOp{}[I0];
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.template GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
auto e_val = e_m_n_host(m, n);
reduce0_op(reduce0_acc, e_val);
};
r0_m_host(m) = ck::type_convert<R0DataType>(reduce0_acc);
}
e_device_buf.FromDevice(e_m_n.mData.data());
Tensor<EDataType> e_m_n_host_converted(e_m_n_host);
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{
Tensor<EDataType> e_m_n_device_converted(e_m_n);
pass = ck::utils::check_err(e_m_n_device_converted.mData,
e_m_n_host_converted.mData,
"Error: Incorrect results c",
1e-2,
1e-2);
}
else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host_converted.mData, "Error: Incorrect results c", 1e-2, 1e-2);
}
r0_device_buf.FromDevice(r0_m.mData.data());
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
if(pass)
{
std::cout << "Success!" << std::endl;
}
}
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
DumpGemmReduceMaxPerf<ADataType, BDataType, EDataType, R0DataType>(ave_time, M, N, K);
}
return pass ? 0 : 1;
}
template <typename ADataType,
typename BDataType,
typename EDataType,
typename R0DataType,
typename R1DataType,
typename ALayout,
typename BLayout,
typename ELayout,
typename AElementOp,
typename BElementOp,
typename CDEElementOp,
typename QsElementOp,
typename RsElementOp,
typename RsThreadReduceOp,
typename ReduceAccDataType,
typename DeviceOpInstance,
typename ReferenceGemmInstance,
typename ADataKernelType = ADataType,
typename BDataKernelType = BDataType,
typename EDataKernelType = EDataType>
bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideE,
bool do_verification,
int init_method,
bool time_kernel)
{
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
static_assert(sizeof(ADataType) == sizeof(ADataKernelType));
static_assert(sizeof(BDataType) == sizeof(BDataKernelType));
static_assert(sizeof(EDataType) == sizeof(EDataKernelType));
#endif
using namespace ck::literals;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor({len}, {stride});
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<EDataKernelType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_m(f_host_tensor_descriptor1d(M, 1));
Tensor<R1DataType> r1_m(f_host_tensor_descriptor1d(M, 1));
switch(init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
a_m_k.end());
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
b_k_n.end());
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
break;
}
DeviceMem a_device_buf(sizeof(ADataKernelType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataKernelType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataKernelType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize());
DeviceMem r1_device_buf(sizeof(R1DataType) * r1_m.mDesc.GetElementSpaceSize());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{
Tensor<ADataKernelType> a_m_k_converted = a_m_k.template CopyAsType<ADataKernelType>();
Tensor<BDataKernelType> b_k_n_converted = b_k_n.template CopyAsType<BDataKernelType>();
a_device_buf.ToDevice(a_m_k_converted.mData.data());
b_device_buf.ToDevice(b_k_n_converted.mData.data());
}
else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{N, N};
// Prepare GEMM, mean, mean_square
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
{r0_device_buf.GetDeviceBuffer(), r1_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
// init reducetion buffer to 0
r0_device_buf.SetZero();
r1_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
if(do_verification)
{
auto I0 = ck::Number<0>{};
auto I1 = ck::Number<1>{};
Tensor<ReduceAccDataType> e_m_n_host(e_m_n.mDesc);
Tensor<R0DataType> r0_m_host(r0_m.mDesc);
Tensor<R1DataType> r1_m_host(r1_m.mDesc);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
auto reduce0_op = RsThreadReduceOp{}[I0];
auto reduce1_op = RsThreadReduceOp{}[I1];
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.template GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.template GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
ReduceAccDataType square_e_val;
auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n_host(m, n));
qs_element_op[I1](square_e_val, e_val);
reduce0_op(reduce0_acc, e_val);
reduce1_op(reduce1_acc, square_e_val);
}
rs_element_op[I0](reduce0_acc, reduce0_acc);
rs_element_op[I1](reduce1_acc, reduce1_acc);
r0_m_host(m) = ck::type_convert<R0DataType>(reduce0_acc);
r1_m_host(m) = ck::type_convert<R1DataType>(reduce1_acc);
}
e_device_buf.FromDevice(e_m_n.mData.data());
Tensor<EDataType> e_m_n_host_converted(e_m_n_host);
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{
Tensor<EDataType> e_m_n_device_converted(e_m_n);
pass = ck::utils::check_err(e_m_n_device_converted.mData,
e_m_n_host_converted.mData,
"Error: Incorrect results c",
1e-2,
1e-2);
}
else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host_converted.mData, "Error: Incorrect results c", 1e-2, 1e-2);
}
r0_device_buf.FromDevice(r0_m.mData.data());
r1_device_buf.FromDevice(r1_m.mData.data());
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r1_m.mData, r1_m_host.mData, "Error: Incorrect results d1", 1e-2, 1e-2);
if(pass)
{
std::cout << "Success!" << std::endl;
}
}
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
DumpGemmReduceMeanSquareMeanPerf<ADataType, BDataType, EDataType, R0DataType, R1DataType>(
ave_time, M, N, K);
}
return pass;
}
...@@ -163,9 +163,9 @@ int main(int argc, char* argv[]) ...@@ -163,9 +163,9 @@ int main(int argc, char* argv[])
{conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]}, {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
{ {
conv_param.K_, // g conv_param.K_, // g
0, // k 0, // n
1, // c 1, // k
0 // x 0 // wo
}); });
const auto residual_g_n_k_wos_desc = HostTensorDescriptor( const auto residual_g_n_k_wos_desc = HostTensorDescriptor(
......
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
|------------|
Gemm0
|---------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = BF16;
using B0DataType = BF16;
using B1DataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = BF16;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmDefault,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
ADataType,
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_gemm_example.inc"
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
...@@ -121,6 +121,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -121,6 +121,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
AElementOp, AElementOp,
B0ElementOp, B0ElementOp,
CElementOp>; CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType, B1DataType,
CDataType, CDataType,
...@@ -129,244 +130,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -129,244 +130,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
B1ElementOp, B1ElementOp,
CElementOp>; CElementOp>;
int main(int argc, char* argv[]) #include "run_batched_gemm_gemm_example.inc"
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideC = -1;
ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
}
else if(argc == 17)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
StrideA = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideB1 = std::stoi(argv[11]);
StrideC = std::stoi(argv[12]);
BatchStrideA = std::stoi(argv[13]);
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
exit(0);
}
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? O : M;
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// C_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
Tensor<CDataType> c_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) *
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
if(do_verification)
{
// Output of Gemm0 is input A of Gemm1
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, a1_g_m_n, a_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1;
}
return 0; int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
|------------|
Gemm0
|---------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F32;
using B0DataType = F32;
using B1DataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F32;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmDefault,
1,
256,
128, // MPerBlock
128, // NPerBlock
16, // KPerBlock
128, // Gemm1NPerBlock
16, // Gemm1KPerBlock
4, // AK1
4, // BK1
1, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
4,
4,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
1,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
ADataType,
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_gemm_example.inc"
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
|------------|
Gemm0
|---------------------|
Gemm1
*/
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::int4_t;
using B0DataType = ck::int4_t;
using B1DataType = ck::int4_t;
using KernelADataType = int8_t;
using KernelB0DataType = int8_t;
using KernelB1DataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using CDataType = ck::int4_t;
using KernelCDataType = int8_t;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
KernelADataType,
KernelB0DataType,
KernelB1DataType,
KernelCDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmDefault,
1,
256,
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
16, // AK1
16, // BK1
4, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
16,
16,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
4,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
ADataType,
AccDataType,
AElementOp,
B0ElementOp,
CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
#define BUILD_INT4_EXAMPLE
#include "run_batched_gemm_gemm_example.inc"
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
This diff is collapsed.
This diff is collapsed.
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment