Commit 1dbdab56 authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents d2e49b23 bac7df8f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16;
using BDataType = F16;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ReduceAccDataType = F32;
using R0DataType = F32;
using RsDataType = ck::Tuple<R0DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough>;
using RsElementOp = ck::Tuple<PassThrough>;
// ReduceOp
using RsThreadReduceOp = ck::Tuple<ck::reduce::Max>;
using RsGlobalReduceOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
template <typename ADataType, typename BDataType, typename EDataType, typename R0DataType>
void DumpPerf(float ave_time, int M, int N, int K)
{
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(EDataType) * M * N + sizeof(R0DataType) * M;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec
<< " GB/s, " << std::endl;
}
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
int main()
{
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_m(f_host_tensor_descriptor1d(M, 1));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{};
// Prepare GEMM, max
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
{r0_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
// [CAUSION]: launch_and_time_kernel will not initialize D.
// If we evaluate kernel multiple time but without initialize D. Verification will fail
r0_device_buf.SetValue(ck::NumericLimits<R0DataType>::Lowest());
invoker.Run(argument, StreamConfig{nullptr, false});
bool do_verification = true;
bool pass = true;
if(do_verification)
{
auto I0 = ck::Number<0>{};
Tensor<EDataType> e_m_n_host(e_m_n.mDesc);
Tensor<R0DataType> r0_m_host(r0_m.mDesc);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
auto reduce0_op = RsThreadReduceOp{}[I0];
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n_host(m, n));
reduce0_op(reduce0_acc, e_val);
};
r0_m_host(m) = ck::type_convert<R0DataType>(reduce0_acc);
}
e_device_buf.FromDevice(e_m_n.mData.data());
r0_device_buf.FromDevice(r0_m.mData.data());
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host.mData, "Error: Incorrect results c", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
}
bool time_kernel = true;
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
DumpPerf<ADataType, BDataType, EDataType, R0DataType>(ave_time, M, N, K);
}
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16;
using BDataType = F16;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ReduceAccDataType = F32;
using R0DataType = F32;
using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough, Square>;
using RsElementOp = ck::Tuple<Div, Div>;
// ReduceOp
using R0ThreadReduceOp = ck::reduce::Add;
using R1ThreadReduceOp = ck::reduce::Add;
using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
template <typename ADataType,
typename BDataType,
typename EDataType,
typename R0DataType,
typename R1DataType>
void DumpPerf(float ave_time, int M, int N, int K)
{
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(R1DataType) * M;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec
<< " GB/s, " << std::endl;
}
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
int main()
{
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_m(f_host_tensor_descriptor1d(M, 1));
Tensor<R1DataType> r1_m(f_host_tensor_descriptor1d(M, 1));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize());
DeviceMem r1_device_buf(sizeof(R1DataType) * r1_m.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{N, N};
// Prepare GEMM, mean, mean_square
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
{r0_device_buf.GetDeviceBuffer(), r1_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
// init reducetion buffer to 0
r0_device_buf.SetZero();
r1_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
bool do_verification = true;
bool pass = true;
if(do_verification)
{
auto I0 = ck::Number<0>{};
auto I1 = ck::Number<1>{};
Tensor<EDataType> e_m_n_host(e_m_n.mDesc);
Tensor<R0DataType> r0_m_host(r0_m.mDesc);
Tensor<R1DataType> r1_m_host(r1_m.mDesc);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
auto reduce0_op = R0ThreadReduceOp{};
auto reduce1_op = R1ThreadReduceOp{};
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
ReduceAccDataType square_e_val;
auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n_host(m, n));
qs_element_op[I1](square_e_val, e_val);
reduce0_op(reduce0_acc, e_val);
reduce1_op(reduce1_acc, square_e_val);
}
rs_element_op[I0](reduce0_acc, reduce0_acc);
rs_element_op[I1](reduce1_acc, reduce1_acc);
r0_m_host(m) = ck::type_convert<R0DataType>(reduce0_acc);
r1_m_host(m) = ck::type_convert<R1DataType>(reduce1_acc);
}
e_device_buf.FromDevice(e_m_n.mData.data());
r0_device_buf.FromDevice(r0_m.mData.data());
r1_device_buf.FromDevice(r1_m.mData.data());
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host.mData, "Error: Incorrect results c", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r1_m.mData, r1_m_host.mData, "Error: Incorrect results d1", 1e-2, 1e-2);
}
bool time_kernel = true;
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
DumpPerf<ADataType, BDataType, EDataType, R0DataType, R1DataType>(ave_time, M, N, K);
}
return pass ? 0 : 1;
}
add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp)
add_example_executable(example_gemm_reduce_xdl_mean_squaremean_fp16 gemm_reduce_xdl_mean_squaremean_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using ReduceOp0 = ck::reduce::Add;
using ReduceOp1 = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
std::cout << "gemm + reduce_mean + reduce_mean_square Perf: " << gemm_reduce_time << " ms, "
<< tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl;
}
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<ReduceDataType> reduce1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<ReduceDataType> reduce1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl;
std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
// do GEMM
auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
nullptr,
{},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{},
gemm_element_ops,
{},
reduce_in_element_ops,
reduce_out_element_ops);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
// init reducetion buffer to 0
reduce0_device_buf.SetZero();
reduce1_device_buf.SetZero();
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
// will not be correct. need to set time_kernel = false for correctness test
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
auto reduce0_op = ReduceOp0{};
auto reduce1_op = ReduceOp1{};
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType square_c_val;
square(square_c_val, c_val);
reduce0_op(reduce0_acc, c_val);
reduce1_op(reduce1_acc, square_c_val);
}
div(reduce0_acc, reduce0_acc);
div(reduce1_acc, reduce1_acc);
reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
}
pass = ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData,
"Error: Incorrect results c") &&
ck::utils::check_err(reduce0_m_device_result.mData,
reduce0_m_host_result.mData,
"Error: Incorrect results d0",
1e-4,
1e-5) &&
ck::utils::check_err(reduce1_m_device_result.mData,
reduce1_m_host_result.mData,
"Error: Incorrect results d1",
1e-3,
1e-5);
}
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(ave_time, M, N, K);
}
return pass ? 0 : 1;
}
......@@ -66,8 +66,14 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceBatchedGemmInstance =
ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
CDataType,
ReduceAccDataType,
AElementOp,
BElementOp,
CElementOp>;
int main(int argc, char* argv[])
{
......
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
......@@ -18,26 +18,21 @@ using F32 = float;
using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ck::Tuple<CDataType>,
Add,
2,
8,
8,
8,
8>;
ck::Sequence<8, 8>,
ck::Sequence<8>>;
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor,
int broadcastDim>
void host_broadcast2D(
......@@ -49,19 +44,19 @@ void host_broadcast2D(
{
for(int n = 0; n < N; ++n)
{
ComputeDataType Amn = ck::type_convert<ComputeDataType>(A(m, n));
ComputeDataType Cmn = 0;
auto Amn = A(m, n);
ctype Cmn = 0;
if constexpr(broadcastDim == 0)
{
ComputeDataType Bn = ck::type_convert<ComputeDataType>(B(n));
auto Bn = B(n);
functor(Cmn, Amn, Bn);
}
else
{
ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
auto Bm = B(m);
functor(Cmn, Amn, Bm);
}
C(m, n) = ck::type_convert<ctype>(Cmn);
C(m, n) = Cmn;
}
}
}
......@@ -103,18 +98,19 @@ int main()
b_n_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_n_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {Stride, 1};
std::vector<ck::index_t> b_strides = {0, 1};
std::vector<ck::index_t> c_strides = {Stride, 1};
std::array<ck::index_t, 2> abc_lengths = {M, N};
std::array<ck::index_t, 2> a_strides = {Stride, 1};
std::array<ck::index_t, 2> b_strides = {0, 1};
std::array<ck::index_t, 2> c_strides = {Stride, 1};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
input, output, {M, N}, {a_strides, b_strides}, {c_strides}, Add{});
abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!");
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -129,12 +125,8 @@ int main()
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
Tensor<CDataType> host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride));
host_broadcast2D<Tensor<ABDataType>,
Tensor<ABDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Add,
0>(host_c_m_n, a_m_n, b_n, M, N, Add{});
host_broadcast2D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add, 0>(
host_c_m_n, a_m_n, b_n, M, N, Add{});
pass &= ck::utils::check_err(
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3);
......
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
......@@ -18,27 +18,19 @@ using F32 = float;
using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ck::Tuple<CDataType>,
Add,
3,
8,
1,
8,
8>;
ck::Sequence<1, 8>,
ck::Sequence<8>>;
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor>
template <typename HostTensorA, typename HostTensorB, typename HostTensorC, typename Functor>
void host_broadcast3D_am_bmnk(HostTensorC& C,
const HostTensorA& A,
const HostTensorB& B,
......@@ -51,11 +43,11 @@ void host_broadcast3D_am_bmnk(HostTensorC& C,
for(std::size_t n = 0; n < shape[1]; ++n)
for(std::size_t k = 0; k < shape[2]; ++k)
{
ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(m));
ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(m, n, k));
ComputeDataType c_val = 0;
auto a_val = A(m);
auto b_val = B(m, n, k);
ctype c_val = 0;
functor(c_val, a_val, b_val);
C(m, n, k) = ck::type_convert<ctype>(c_val);
C(m, n, k) = c_val;
}
}
......@@ -85,25 +77,25 @@ int main()
b_m_n_k_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_n_k_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {1, 0, 0};
std::vector<ck::index_t> b_strides{b_m_n_k.mDesc.GetStrides().begin(),
b_m_n_k.mDesc.GetStrides().end()};
std::vector<ck::index_t> c_strides{c_m_n_k.mDesc.GetStrides().begin(),
c_m_n_k.mDesc.GetStrides().end()};
std::array<ck::index_t, 3> abc_lengths;
std::array<ck::index_t, 3> a_strides = {1, 0, 0};
std::array<ck::index_t, 3> b_strides;
std::array<ck::index_t, 3> c_strides;
std::copy(mnk.begin(), mnk.end(), abc_lengths.begin());
std::copy(
b_m_n_k.mDesc.GetStrides().begin(), b_m_n_k.mDesc.GetStrides().end(), b_strides.begin());
std::copy(
c_m_n_k.mDesc.GetStrides().begin(), c_m_n_k.mDesc.GetStrides().end(), c_strides.begin());
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument =
broadcastAdd.MakeArgumentPointer(input,
output,
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
{a_strides, b_strides},
{c_strides},
Add{});
auto argument = broadcastAdd.MakeArgumentPointer(
abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!");
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -118,11 +110,8 @@ int main()
c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data());
Tensor<CDataType> host_c_m_n_k(mnk);
host_broadcast3D_am_bmnk<Tensor<ABDataType>,
Tensor<ABDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
host_broadcast3D_am_bmnk<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
pass &= ck::utils::check_err(
c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3);
......
......@@ -5,7 +5,7 @@
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
......@@ -17,27 +17,19 @@ using F32 = float;
using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ck::Tuple<CDataType>,
Add,
1,
8,
8,
8,
8>;
ck::Sequence<8, 8>,
ck::Sequence<8>>;
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor>
template <typename HostTensorA, typename HostTensorB, typename HostTensorC, typename Functor>
void host_elementwise1D(
HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor)
{
......@@ -45,11 +37,11 @@ void host_elementwise1D(
for(int m = 0; m < M; ++m)
{
ComputeDataType Am = ck::type_convert<ComputeDataType>(A(m));
ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
ComputeDataType Cm = 0;
auto Am = A(m);
auto Bm = B(m);
ctype Cm = 0;
functor(Cm, Am, Bm);
C(m) = ck::type_convert<ctype>(Cm);
C(m) = Cm;
}
}
......@@ -83,18 +75,19 @@ int main()
b_m_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {1};
std::vector<ck::index_t> b_strides = {1};
std::vector<ck::index_t> c_strides = {1};
std::array<ck::index_t, 1> abc_lengths = {M};
std::array<ck::index_t, 1> a_strides = {1};
std::array<ck::index_t, 1> b_strides = {1};
std::array<ck::index_t, 1> c_strides = {1};
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer(
input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});
abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!");
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -109,11 +102,8 @@ int main()
c_m_device_buf.FromDevice(c_m.mData.data());
Tensor<CDataType> host_c_m(f_host_tensor_descriptor1d(M, 1));
host_elementwise1D<Tensor<ABDataType>,
Tensor<ABDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c_m, a_m, b_m, M, Add{});
host_elementwise1D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
host_c_m, a_m, b_m, M, Add{});
pass &= ck::utils::check_err(
c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3);
......
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
......@@ -18,27 +18,19 @@ using F32 = float;
using ABDataType = F16;
using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
ABDataType,
CDataType,
EltwiseComputeDataType,
ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ck::Tuple<CDataType>,
Add,
4,
8,
8,
8,
8>;
ck::Sequence<8, 8>,
ck::Sequence<8>>;
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor>
template <typename HostTensorA, typename HostTensorB, typename HostTensorC, typename Functor>
void host_elementwise4D(HostTensorC& C,
const HostTensorA& A,
const HostTensorB& B,
......@@ -52,11 +44,11 @@ void host_elementwise4D(HostTensorC& C,
for(std::size_t h = 0; h < shape[2]; ++h)
for(std::size_t w = 0; w < shape[3]; ++w)
{
ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(n, c, h, w));
ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(n, c, h, w));
ComputeDataType c_val = 0;
auto a_val = A(n, c, h, w);
auto b_val = B(n, c, h, w);
ctype c_val = 0;
functor(c_val, a_val, b_val);
C(n, c, h, w) = ck::type_convert<ctype>(c_val);
C(n, c, h, w) = c_val;
}
}
......@@ -85,23 +77,24 @@ int main()
b_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()};
std::vector<ck::index_t> b_strides{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()};
std::vector<ck::index_t> c_strides{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()};
std::array<ck::index_t, 4> abc_lengths;
std::array<ck::index_t, 4> a_strides;
std::array<ck::index_t, 4> b_strides;
std::array<ck::index_t, 4> c_strides;
std::copy(nchw.begin(), nchw.end(), abc_lengths.begin());
std::copy(a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end(), a_strides.begin());
std::copy(b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end(), b_strides.begin());
std::copy(c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end(), c_strides.begin());
auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument =
broadcastAdd.MakeArgumentPointer(input,
output,
std::vector<ck::index_t>{nchw.begin(), nchw.end()},
{{a_strides}, b_strides},
{c_strides},
Add{});
auto argument = broadcastAdd.MakeArgumentPointer(
abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!");
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
};
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
......@@ -116,11 +109,8 @@ int main()
c_device_buf.FromDevice(c.mData.data());
Tensor<CDataType> host_c(nchw);
host_elementwise4D<Tensor<ABDataType>,
Tensor<ABDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c, a, b, nchw, Add{});
host_elementwise4D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
host_c, a, b, nchw, Add{});
pass &=
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3);
......
......@@ -9,8 +9,8 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
......@@ -28,57 +28,64 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using BiasDataType = F32;
using D0DataType = F16;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using ReduceAccDataType = F32;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using R0DataType = F32;
using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
using NormalizeComputeDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
// Layout
using ALayout = Row;
using BLayout = Col;
using D1Layout = Row;
using ELayout = D1Layout;
// Elementwise op
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = ck::tensor_operation::element_wise::Relu;
using D0ElementOp = PassThrough;
using ReduceSumOp = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>;
using CDEElementOp = AddReluAdd;
using QsElementOp = ck::Tuple<PassThrough, Square>;
using RsElementOp = ck::Tuple<Div, Div>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
// ReduceOp
using R0ThreadReduceOp = ck::reduce::Add;
using R1ThreadReduceOp = ck::reduce::Add;
using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
EDataType,
GemmAccDataType,
AElementOp,
BElementOp,
......@@ -87,23 +94,18 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance =
ck::tensor_operation::device::Device5AryElementwise<CDataType,
ReduceDataType,
ReduceDataType,
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<EDataType,
R0DataType,
R1DataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType,
NormalizeComputeDataType,
BetaDataType>, // x(gemm_out), mean, meansquare, gamma, beta
ck::Tuple<LayerNormOutDataType>, // y
NormalizeFunctor,
2,
8,
8, // scalarPerVector: gemm_out
1, // scalarPerVector: reduce_mean
1, // scalarPerVector: reduce_mean_square
8, // scalarPerVector: Gamma
8, // scalarPerVector: Beta
8>; // scalarPerVector: LayerNorm_out
8, // MPerthread
ck::Sequence<8, 1, 1, 8, 8>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta
ck::Sequence<8>>; // scalarPerVector: y(layerNorm_out)
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
......@@ -124,41 +126,31 @@ auto f_host_tensor_descriptor2d =
}
};
template <typename CDataType,
typename ReduceDataType,
typename AccDataType,
typename BiasDataType,
typename D0DataType,
typename A_functor,
typename B_functor,
typename C_functor,
typename C1_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n,
const Tensor<BiasDataType>& bias_n,
const Tensor<D0DataType>& c1_m_n,
const Tensor<BDataType>& b_k_n,
const Tensor<D0DataType>& bias_n,
const Tensor<D1DataType>& d1_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<GammaDataType>& beta_n,
A_functor a_element_op,
B_functor b_element_op,
C_functor c_element_op,
C1_functor c1_element_op,
const Tensor<BetaDataType>& beta_n,
AElementOp a_element_op,
BElementOp b_element_op,
CDEElementOp cde_element_op,
int M,
int N)
{
int StrideC = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N};
int StrideE = N;
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<R1DataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = Div{N};
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_gemm.MakeArgument(a_m_k, b_k_n, e_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
......@@ -166,38 +158,32 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n)
{
AccDataType acc = ck::type_convert<AccDataType>(c_m_n(m, n)) +
ck::type_convert<AccDataType>(bias_n(n));
AccDataType c1 = ck::type_convert<AccDataType>(c1_m_n(m, n));
c_element_op(acc, acc);
c1_element_op(c1, c1);
acc += c1;
c_m_n(m, n) = ck::type_convert<CDataType>(acc);
auto acc = ck::type_convert<GemmAccDataType>(e_m_n(m, n));
cde_element_op(e_m_n(m, n), acc, bias_n(n), d1_m_n(m, n));
}
// reduce_mean and reduce_square_mean
auto reduceSumOpInst = ReduceSumOp{};
auto r0Op = R0ThreadReduceOp{};
auto r1Op = R1ThreadReduceOp{};
for(int m = 0; m < M; ++m)
{
auto mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>();
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>();
auto mean_acc = r0Op.GetIdentityValue<ReduceAccDataType>();
auto mean_square_acc = r1Op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
AccDataType c_val = ck::type_convert<AccDataType>(c_m_n(m, n));
AccDataType square_c_val = 0;
UnarySquareElementOp{}(square_c_val, c_val);
auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n(m, n));
ReduceAccDataType square_e_val = 0;
Square{}(square_e_val, e_val);
reduceSumOpInst(mean_acc, c_val);
reduceSumOpInst(square_mean_acc, square_c_val);
r0Op(mean_acc, e_val);
r1Op(mean_square_acc, square_e_val);
}
averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
averageOpInst(mean_square_acc, mean_square_acc);
mean_m(m) = ck::type_convert<R0DataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<R1DataType>(mean_square_acc);
}
// LayerNorm
......@@ -206,24 +192,20 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{
for(int n = 0; n < N; ++n)
{
AccDataType out_acc = 0;
layerNormInst(out_acc,
ck::type_convert<AccDataType>(c_m_n(m, n)),
ck::type_convert<AccDataType>(mean_m(m)),
ck::type_convert<AccDataType>(meanSquare_m(m)),
ck::type_convert<AccDataType>(gamma_n(n)),
ck::type_convert<AccDataType>(beta_n(n)));
out_m_n(m, n) = ck::type_convert<ReduceDataType>(out_acc);
LayerNormOutDataType out_val = 0;
layerNormInst(out_val, e_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
out_m_n(m, n) = out_val;
}
}
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename BiasDataType,
typename EDataType,
typename D0DataType,
typename ReduceDataType,
typename D1DataType,
typename R0DataType,
typename R1DataType,
typename GammaDataType,
typename BetaDataType,
typename NormalizeDataType>
......@@ -231,12 +213,12 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{
std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N +
sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
sizeof(EDataType) * M * N + sizeof(D0DataType) * M * N +
sizeof(D0DataType) * M * N + sizeof(R0DataType) * M +
sizeof(R1DataType) * M;
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
std::size_t normalize_num_byte = sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(R1DataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
......@@ -259,37 +241,37 @@ int main()
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024;
ck::index_t StrideD0 = 1024;
ck::index_t StrideD0 = 0;
ck::index_t StrideD1 = 1024;
ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<D0DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<D0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, ELayout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_Mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<R1DataType> r1_MeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
bias_n.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1, 1});
c1_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-5, 5});
bias_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-1, 1});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-5, 5});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) *
reduceMean_m.mDesc.GetElementSpaceSize());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
reduceMeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(D0DataType) * bias_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_Mean_device_buf(sizeof(R0DataType) * r0_Mean_m.mDesc.GetElementSpaceSize());
DeviceMem r1_MeanSquare_device_buf(sizeof(R1DataType) *
r1_MeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
......@@ -298,104 +280,94 @@ int main()
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
bias_device_buf.ToDevice(bias_n.mData.data());
d0_device_buf.ToDevice(c1_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto d_element_op = D0ElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{N, N};
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer()};
// Prepare GEMM, reduce_mean, reduce_mean_square
auto gemmReduce = DeviceGemmBiasAddReduceInstance{};
// Prepare GEMM, mean, mean_square
auto gemmReduce = DeviceOpInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(),
auto gemmReduce_argument = gemmReduce.MakeArgument(
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer()},
c_device_buf.GetDeviceBuffer(),
p_reduces,
{bias_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
{r0_Mean_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
StrideC,
{StrideD0},
gemm_element_ops,
{&d_element_op},
reduce_in_element_ops,
reduce_out_element_ops);
{StrideD0, StrideD1},
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
reduceMean_device_buf.SetZero();
reduceMeanSquare_device_buf.SetZero();
// init reducetion buffer to 0
r0_Mean_device_buf.SetZero();
r1_MeanSquare_device_buf.SetZero();
// Prepare LayerNorm
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
std::array<const void*, 5> input = {e_device_buf.GetDeviceBuffer(),
r0_Mean_device_buf.GetDeviceBuffer(),
r1_MeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 2> xyLengths = {M, N};
std::array<ck::index_t, 2> xyStrides = {StrideE, 1};
auto normalize = DeviceNormalizeInstance{};
auto normalize_invoker = normalize.MakeInvoker();
auto normalize_argument = normalize.MakeArgument(input,
auto normalize_argument_ptr =
normalize.MakeArgumentPointer(xyLengths,
{xyStrides, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
{xyStrides},
input,
output,
{M, N},
{StrideC, 1},
{1, 0},
{1, 0},
{0, 1},
{0, 1},
{StrideC, 1},
NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument))
if(!normalize.IsSupportedArgument(normalize_argument_ptr.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"Device5AryElementwise instance, exiting!");
throw std::runtime_error(
"The runtime parameters seems not supported by the device, exiting!");
}
// run kernel
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false});
normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, false});
bool pass = true;
{
// verification
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
host_gemm_layernorm<CDataType, ReduceDataType, ReduceAccDataType>(host_layerNorm_m_n,
host_gemm_layernorm(host_layerNorm_m_n,
a_m_k,
b_k_n,
bias_n,
c1_m_n,
d1_m_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
d_element_op,
cde_element_op,
M,
N);
......@@ -414,15 +386,16 @@ int main()
float gemm_reduce_mean_reduce_square_mean_ave_time =
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
float normalize_ave_time =
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
DumpGemmLayerNormPerf<ADataType,
BDataType,
CDataType,
BiasDataType,
EDataType,
D0DataType,
ReduceDataType,
D1DataType,
R0DataType,
R1DataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>(
......
......@@ -9,8 +9,8 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
......@@ -28,78 +28,83 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ReduceAccDataType = F32;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using R0DataType = F32;
using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
using NormalizeComputeDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using ReduceSumOp = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;
// Layout
using ALayout = Row;
using BLayout = Col;
using D1Layout = Row;
using ELayout = D1Layout;
// Elementwise op
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough, Square>;
using RsElementOp = ck::Tuple<Div, Div>;
// ReduceOp
using R0ThreadReduceOp = ck::reduce::Add;
using R1ThreadReduceOp = ck::reduce::Add;
using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
EDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CElementOp>;
PassThrough>;
using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance =
ck::tensor_operation::device::Device5AryElementwise<CDataType,
ReduceDataType,
ReduceDataType,
using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<EDataType,
R0DataType,
R1DataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType,
NormalizeComputeDataType,
BetaDataType>, // x(gemm_out), mean,
// meansquare,
// gamma, beta
ck::Tuple<LayerNormOutDataType>, // y
NormalizeFunctor,
2,
8,
8, // scalarPerVector: gemm_out
1, // scalarPerVector: reduce_mean
1, // scalarPerVector: reduce_mean_square
8, // scalarPerVector: Gamma
8, // scalarPerVector: Beta
8>; // scalarPerVector: LayerNorm_out
8, // MPerthread
ck::Sequence<8, 1, 1, 8, 8>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta
ck::Sequence<8>>; // scalarPerVector: y(layerNorm_out)
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
......@@ -120,60 +125,53 @@ auto f_host_tensor_descriptor2d =
}
};
template <typename CDataType,
typename ReduceDataType,
typename A_functor,
typename B_functor,
typename C_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n,
const Tensor<BDataType>& b_k_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n,
A_functor a_element_op,
B_functor b_element_op,
C_functor c_element_op,
AElementOp a_element_op,
BElementOp b_element_op,
CDEElementOp c_element_op,
int M,
int N)
{
using out_type = ck::remove_reference_t<decltype(out_m_n(0, 0))>;
int StrideC = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N};
int StrideE = N;
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<R1DataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = Div{N};
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op);
ref_gemm.MakeArgument(a_m_k, b_k_n, e_m_n, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
// reduce_mean and reduce_square_mean
auto reduceSumOpInst = ReduceSumOp{};
auto r0Op = R0ThreadReduceOp{};
auto r1Op = R1ThreadReduceOp{};
for(int m = 0; m < M; ++m)
{
auto mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
auto mean_acc = r0Op.GetIdentityValue<ReduceAccDataType>();
auto mean_square_acc = r1Op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n));
auto square_c_val = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
UnarySquareElementOp{}(square_c_val, c_val);
auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n(m, n));
ReduceAccDataType square_e_val = 0;
Square{}(square_e_val, e_val);
reduceSumOpInst(mean_acc, c_val);
reduceSumOpInst(square_mean_acc, square_c_val);
r0Op(mean_acc, e_val);
r1Op(mean_square_acc, square_e_val);
}
averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
averageOpInst(mean_square_acc, mean_square_acc);
mean_m(m) = ck::type_convert<R0DataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<R1DataType>(mean_square_acc);
}
// LayerNorm
......@@ -182,22 +180,18 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{
for(int n = 0; n < N; ++n)
{
float out_f32 = 0;
layerNormInst(out_f32,
static_cast<float>(c_m_n(m, n)),
static_cast<float>(mean_m(m)),
static_cast<float>(meanSquare_m(m)),
static_cast<float>(gamma_n(n)),
static_cast<float>(beta_n(n)));
out_m_n(m, n) = static_cast<out_type>(out_f32);
LayerNormOutDataType out_val = 0;
layerNormInst(out_val, e_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
out_m_n(m, n) = out_val;
}
}
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ReduceDataType,
typename EDataType,
typename R0DataType,
typename R1DataType,
typename GammaDataType,
typename BetaDataType,
typename NormalizeDataType>
......@@ -205,11 +199,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(R1DataType) * M;
std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
std::size_t normalize_num_btye = sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(R1DataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
......@@ -232,17 +226,17 @@ int main()
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024;
ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_Mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<R1DataType> r1_MeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
......@@ -251,11 +245,10 @@ int main()
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) *
reduceMean_m.mDesc.GetElementSpaceSize());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
reduceMeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_Mean_device_buf(sizeof(R0DataType) * r0_Mean_m.mDesc.GetElementSpaceSize());
DeviceMem r1_MeanSquare_device_buf(sizeof(R1DataType) *
r1_MeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
......@@ -268,38 +261,31 @@ int main()
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{N, N};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer()};
// Prepare GEMM, reduce_mean, reduce_mean_square
auto gemmReduce = DeviceGemmReduceInstance{};
// Prepare GEMM, mean, mean_square
auto gemmReduce = DeviceOpInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(),
auto gemmReduce_argument = gemmReduce.MakeArgument(
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
nullptr,
{},
c_device_buf.GetDeviceBuffer(),
p_reduces,
e_device_buf.GetDeviceBuffer(),
{r0_Mean_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
StrideC,
{},
gemm_element_ops,
{},
reduce_in_element_ops,
reduce_out_element_ops);
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
{
......@@ -308,54 +294,54 @@ int main()
"not support this GEMM problem");
}
reduceMean_device_buf.SetZero();
reduceMeanSquare_device_buf.SetZero();
r0_Mean_device_buf.SetZero();
r1_MeanSquare_device_buf.SetZero();
// Prepare LayerNorm
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(),
std::array<const void*, 5> input = {e_device_buf.GetDeviceBuffer(),
r0_Mean_device_buf.GetDeviceBuffer(),
r1_MeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 2> xyLengths = {M, N};
std::array<ck::index_t, 2> xyStrides = {StrideE, 1};
auto normalize = DeviceNormalizeInstance{};
auto normalize_invoker = normalize.MakeInvoker();
auto normalize_argument = normalize.MakeArgument(input,
auto normalize_argument_ptr =
normalize.MakeArgumentPointer(xyLengths,
{xyStrides, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
{xyStrides},
input,
output,
{M, N},
{StrideC, 1},
{1, 0},
{1, 0},
{0, 1},
{0, 1},
{StrideC, 1},
NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument))
if(!normalize.IsSupportedArgument(normalize_argument_ptr.get()))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"Device5AryElementwise instance, exiting!");
throw std::runtime_error(
"The runtime parameters seems not supported by the device, exiting");
}
// run kernel
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false});
normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, false});
bool pass = true;
{
// verification
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
host_gemm_layernorm<CDataType, ReduceDataType>(host_layerNorm_m_n,
host_gemm_layernorm(host_layerNorm_m_n,
a_m_k,
b_k_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
cde_element_op,
M,
N);
......@@ -374,13 +360,14 @@ int main()
float gemm_reduce_mean_reduce_square_mean_ave_time =
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
float normalize_ave_time =
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
DumpGemmLayerNormPerf<ADataType,
BDataType,
CDataType,
ReduceDataType,
EDataType,
R0DataType,
R1DataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>(
......
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl
// clang-format off
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
const int M = 256;
const int N = 128;
const int K = 64;
const int stride_A = K;
const int stride_B = K;
const int batch_stride_A = M * K;
const int batch_stride_B = K * N;
const int G0 = 16;
const int G1 = 8;
const int batch_count = G0 * G1;
// output layout - [G0, M, G1, N]
const int stride_G0 = M * G1 * N;
const int stride_G1 = N;
const int stride_M = G1 * N;
const int stride_N = 1;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
// GEMM shape
ck::tensor_operation::device::BatchedGemmEPermuteDesc batched_gemm_e_permute_desc{
G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
auto f_host_e_tensor_descriptor = [](std::size_t G0_,
std::size_t G1_,
std::size_t M_,
std::size_t N_,
std::size_t stride_G0_,
std::size_t stride_G1_,
std::size_t stride_M_,
std::size_t stride_N_) {
return HostTensorDescriptor(
std::vector<std::size_t>({G0_, G1_, M_, N_}),
std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
};
Tensor<EDataType> e_g0_g1_m_n_host_result(
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
Tensor<EDataType> e_g0_g1_m_n_device_result(
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "e_g0_g1_m_n: " << e_g0_g1_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
batch_stride_A,
batch_stride_B,
batched_gemm_e_permute_desc,
batch_count,
a_element_op,
b_element_op,
cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * K * N +
sizeof(EDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
e_device_buf.FromDevice(e_g0_g1_m_n_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor(
std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
for(int g0 = 0; g0 < G0; g0++)
{
for(int g1 = 0; g1 < G1; g1++)
{
for(int m = 0; m < M; m++)
{
for(int n = 0; n < N; n++)
{
int g = g0 * G1 + g1;
e_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
}
}
}
}
pass = ck::utils::check_err(e_g0_g1_m_n_host_result.mData,
e_g0_g1_m_n_device_result.mData,
"Error: Incorrect results c");
}
return pass ? 0 : 1;
}
add_example_executable(example_gemm_bias_e_permute_xdl_fp16 gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
static constexpr ck::index_t NumDimG = 1;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 3;
static constexpr ck::index_t NumDimK = 1;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
using DeviceOpInstanceKKNN = ck::tensor_operation::device::
//############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G1_M2_N3_K1::Argument;
float Run(const Argument& arg)
{
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, k0)));
arg.b_element_op_(
v_b,
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c;
};
make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M3_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t G0 = 1;
ck::index_t M0 = 4;
ck::index_t M1 = 256;
ck::index_t N0 = 4;
ck::index_t N1 = 16;
ck::index_t N2 = 32;
ck::index_t K0 = 256;
// A[M0, M1, M2, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{M0 * M1 * K0, M1 * K0, K0, 1};
// B[N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, N0, N1, N2, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{N0 * N1 * N2 * K0, N1 * N2 * K0, N2 * K0, K0, 1};
// D[N0, M0, N1, M1, N2]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2};
std::vector<ck::index_t> d_gs_ms_ns_strides{N0 * N1 * N2, 0, 0, N1 * N2, N2, 1};
// E[N0, M0, N1, M1, N2]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2};
std::vector<ck::index_t> e_gs_ms_ns_strides{
M0 * M1 * N0 * N1 * N2, N1 * M1 * N2, N2, M0 * N1 * M1 * N2, M1 * N2, 1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
Tensor<BDataType> b_gs_ns_ks(
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_gs_ms_ks,
b_gs_ns_ks,
c_gs_ms_ns_host_result,
a_element_op,
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1)
{
for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n2)
{
cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
d_gs_ms_ns(g0, m0, m1, n0, n1, n2));
}
}
}
}
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
static constexpr ck::index_t NumDimG = 1;
static constexpr ck::index_t NumDimM = 3;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 1;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
using DeviceOpInstanceKKNN = ck::tensor_operation::device::
//############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on
using DeviceOpInstance = DeviceOpInstanceKKNN;
template <ck::index_t NumDimG,
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 1 && NumDimM == 3 && NumDimN == 2 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G1_M3_N2_K1::Argument;
float Run(const Argument& arg)
{
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a,
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, m2, k0)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, m0, m1, m2, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_G1_M3_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t G0 = 1;
ck::index_t M0 = 4;
ck::index_t M1 = 8;
ck::index_t M2 = 256;
ck::index_t N0 = 32;
ck::index_t N1 = 128;
ck::index_t K0 = 1024;
// A[M0, M1, M2, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, M0, M1, M2, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{M0 * M1 * M2 * K0, M1 * M2 * K0, M2 * K0, K0, 1};
// B[N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{N0 * N1 * K0, N1 * K0, K0, 1};
// D[M0, N0, M1, N1, M2]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, M0, M1, M2, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{N0 * N1, 0, 0, 0, N1, 1};
// E[M1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, M0, M1, M2, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
M0 * M1 * M2 * N1 * N0, N0 * M1 * N1, N1, M0 * N0 * M1 * N1, M1 * N1, 1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
Tensor<BDataType> b_gs_ns_ks(
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin(),
e_gs_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_gs_ms_ks,
b_gs_ns_ks,
c_gs_ms_ns_host_result,
a_element_op,
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1)
{
for(size_t m2 = 0; m2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m2)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n1)
{
cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),
c_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),
d_gs_ms_ns(g0, m0, m1, m2, n0, n1));
}
}
}
}
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasEPermute_Xdl
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t M0 = 4;
ck::index_t M1 = 32;
ck::index_t M2 = 128;
ck::index_t N0 = 16;
ck::index_t N1 = 256;
// GEMM shape
ck::index_t M = M0 * M1 * M2;
ck::index_t N = N0 * N1;
ck::index_t K = 128;
ck::index_t stride_A = K;
ck::index_t stride_B = K;
#if 1
// E = [M0, N0, M1, N1, M2]
ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
ck::index_t stride_E_M1 = N1 * M2;
ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
#else
// D = [0, 0, 0, N0, N1]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
// E = [M0, M1, M2, N0, N1]
ck::index_t stride_E_M0 = M1 * M2 * N0 * N1;
ck::index_t stride_E_M1 = M2 * N0 * N1;
ck::index_t stride_E_M2 = N0 * N1;
ck::index_t stride_E_N0 = N1;
ck::index_t stride_E_N1 = 1;
#endif
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc{
M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1};
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc{
M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
auto f_host_de_tensor_descriptor =
[](ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 de_grid_desc) {
std::size_t m0 = de_grid_desc.M0_;
std::size_t m1 = de_grid_desc.M1_;
std::size_t m2 = de_grid_desc.M2_;
std::size_t n0 = de_grid_desc.N0_;
std::size_t n1 = de_grid_desc.N1_;
std::size_t stride_m0 = de_grid_desc.stride_M0_;
std::size_t stride_m1 = de_grid_desc.stride_M1_;
std::size_t stride_m2 = de_grid_desc.stride_M2_;
std::size_t stride_n0 = de_grid_desc.stride_N0_;
std::size_t stride_n1 = de_grid_desc.stride_N1_;
return HostTensorDescriptor(
std::vector<std::size_t>({m0, m1, m2, n0, n1}),
std::vector<std::size_t>({stride_m0, stride_m1, stride_m2, stride_n0, stride_n1}));
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
Tensor<DDataType> d_m0_m1_m2_n0_n1(f_host_de_tensor_descriptor(d_grid_desc));
Tensor<EDataType> e_m0_m1_m2_n0_n1_host_result(f_host_de_tensor_descriptor(e_grid_desc));
Tensor<EDataType> e_m0_m1_m2_n0_n1_device_result(f_host_de_tensor_descriptor(e_grid_desc));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m0_m1_m2_n0_n1: " << d_m0_m1_m2_n0_n1.mDesc << std::endl;
std::cout << "e_m0_m1_m2_n0_n1: " << e_m0_m1_m2_n0_n1_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) *
d_m0_m1_m2_n0_n1.mDesc.GetElementSpaceSize());
DeviceMem e_m0_m1_m2_n0_n1_device_buf(
sizeof(EDataType) * e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d_m0_m1_m2_n0_n1_device_buf.ToDevice(d_m0_m1_m2_n0_n1.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(),
d_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
e_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
d_grid_desc,
e_grid_desc,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< device_op.GetTypeString() << std::endl;
if(do_verification)
{
Tensor<AccDataType> c_m_n(HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m0 = 0; m0 < M0; ++m0)
for(int m1 = 0; m1 < M1; ++m1)
for(int m2 = 0; m2 < M2; ++m2)
for(int n0 = 0; n0 < N0; ++n0)
for(int n1 = 0; n1 < N1; ++n1)
{
int m = m0 * M1 * M2 + m1 * M2 + m2;
int n = n0 * N1 + n1;
cde_element_op(e_m0_m1_m2_n0_n1_host_result(m0, m1, m2, n0, n1),
ck::type_convert<EDataType>(c_m_n(m, n)),
d_m0_m1_m2_n0_n1(m0, m1, m2, n0, n1));
}
e_m0_m1_m2_n0_n1_device_buf.FromDevice(e_m0_m1_m2_n0_n1_device_result.mData.data());
return ck::utils::check_err(e_m0_m1_m2_n0_n1_device_result.mData,
e_m0_m1_m2_n0_n1_host_result.mData)
? 0
: 1;
}
return 0;
}
......@@ -9,7 +9,7 @@
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -29,7 +29,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
GammaDataType,
BetaDataType,
AccDataType,
......@@ -46,7 +46,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
8, // SrcScalarPerVector
8, // GammaScalarPerVector
8, // BetaScalarPerVector
1>; // OutScalarPerVector
8>; // OutScalarPerVector
int main()
{
......@@ -90,6 +90,7 @@ int main()
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
{1},
1e-4,
x_dev.GetDeviceBuffer(),
......
add_example_executable(example_grouped_gemm_bias_xdl_fp16 grouped_gemm_bias_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using DsLayout = ck::Tuple<DLayout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
int group_count = rand() % 16 + 1;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b;
std::vector<std::array<const void*, 1>> p_ds;
std::vector<void*> p_c;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
int M = 256 + 256 * i;
int N = 128 + 128 * i;
int K = 64 + 64 * i;
int stride_A = K;
int stride_B = K;
int stride_C = N;
std::vector<ck::index_t> stride_Ds = {0};
gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, stride_Ds});
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<DDataType>> d_tensors;
std::vector<Tensor<EDataType>> e_host_tensors;
std::vector<Tensor<EDataType>> e_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d_tensors.reserve(group_count);
e_host_tensors.reserve(group_count);
e_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d_tensors_device,
e_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d_tensors_device.reserve(group_count);
e_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
d_tensors.push_back(Tensor<DDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_Ds_[0], ELayout{})));
e_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
e_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << e_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize();
switch(init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
d_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
}
}
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
d_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(DDataType) * d_tensors[i].mDesc.GetElementSpaceSize()));
e_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSpaceSize()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
d_tensors_device[i]->ToDevice(d_tensors[i].mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_ds.push_back({d_tensors_device[i]->GetDeviceBuffer()});
p_c.push_back(e_tensors_device[i]->GetDeviceBuffer());
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, cde_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
e_host_tensors[i],
a_element_op,
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < gemm_descs[i].M_; ++m)
{
for(int n = 0; n < gemm_descs[i].N_; ++n)
{
cde_element_op(
e_host_tensors[i](m, n), e_host_tensors[i](m, n), d_tensors[i](m, n));
}
}
pass &= ck::utils::check_err(e_device_tensors[i].mData, e_host_tensors[i].mData);
}
}
return pass ? 0 : 1;
}
add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment