Commit 1dbdab56 authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents d2e49b23 bac7df8f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16;
using BDataType = F16;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ReduceAccDataType = F32;
using R0DataType = F32;
using RsDataType = ck::Tuple<R0DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough>;
using RsElementOp = ck::Tuple<PassThrough>;
// ReduceOp
using RsThreadReduceOp = ck::Tuple<ck::reduce::Max>;
using RsGlobalReduceOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
template <typename ADataType, typename BDataType, typename EDataType, typename R0DataType>
void DumpPerf(float ave_time, int M, int N, int K)
{
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(EDataType) * M * N + sizeof(R0DataType) * M;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec
<< " GB/s, " << std::endl;
}
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
int main()
{
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_m(f_host_tensor_descriptor1d(M, 1));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{};
// Prepare GEMM, max
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
{r0_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
// [CAUSION]: launch_and_time_kernel will not initialize D.
// If we evaluate kernel multiple time but without initialize D. Verification will fail
r0_device_buf.SetValue(ck::NumericLimits<R0DataType>::Lowest());
invoker.Run(argument, StreamConfig{nullptr, false});
bool do_verification = true;
bool pass = true;
if(do_verification)
{
auto I0 = ck::Number<0>{};
Tensor<EDataType> e_m_n_host(e_m_n.mDesc);
Tensor<R0DataType> r0_m_host(r0_m.mDesc);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
auto reduce0_op = RsThreadReduceOp{}[I0];
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n_host(m, n));
reduce0_op(reduce0_acc, e_val);
};
r0_m_host(m) = ck::type_convert<R0DataType>(reduce0_acc);
}
e_device_buf.FromDevice(e_m_n.mData.data());
r0_device_buf.FromDevice(r0_m.mData.data());
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host.mData, "Error: Incorrect results c", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
}
bool time_kernel = true;
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
DumpPerf<ADataType, BDataType, EDataType, R0DataType>(ave_time, M, N, K);
}
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16;
using BDataType = F16;
using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ReduceAccDataType = F32;
using R0DataType = F32;
using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
// Layout
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
// Elementwise op
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
using QsElementOp = ck::Tuple<PassThrough, Square>;
using RsElementOp = ck::Tuple<Div, Div>;
// ReduceOp
using R0ThreadReduceOp = ck::reduce::Add;
using R1ThreadReduceOp = ck::reduce::Add;
using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
template <typename ADataType,
typename BDataType,
typename EDataType,
typename R0DataType,
typename R1DataType>
void DumpPerf(float ave_time, int M, int N, int K)
{
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(R1DataType) * M;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gemm_gb_per_sec
<< " GB/s, " << std::endl;
}
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
int main()
{
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<R0DataType> r0_m(f_host_tensor_descriptor1d(M, 1));
Tensor<R1DataType> r1_m(f_host_tensor_descriptor1d(M, 1));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem r0_device_buf(sizeof(R0DataType) * r0_m.mDesc.GetElementSpaceSize());
DeviceMem r1_device_buf(sizeof(R1DataType) * r1_m.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{N, N};
// Prepare GEMM, mean, mean_square
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{},
e_device_buf.GetDeviceBuffer(),
{r0_device_buf.GetDeviceBuffer(), r1_device_buf.GetDeviceBuffer()},
M,
N,
K,
StrideA,
StrideB,
{},
StrideE,
a_element_op,
b_element_op,
cde_element_op,
qs_element_op,
rs_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
// init reducetion buffer to 0
r0_device_buf.SetZero();
r1_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
bool do_verification = true;
bool pass = true;
if(do_verification)
{
auto I0 = ck::Number<0>{};
auto I1 = ck::Number<1>{};
Tensor<EDataType> e_m_n_host(e_m_n.mDesc);
Tensor<R0DataType> r0_m_host(r0_m.mDesc);
Tensor<R1DataType> r1_m_host(r1_m.mDesc);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
auto reduce0_op = R0ThreadReduceOp{};
auto reduce1_op = R1ThreadReduceOp{};
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
ReduceAccDataType square_e_val;
auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n_host(m, n));
qs_element_op[I1](square_e_val, e_val);
reduce0_op(reduce0_acc, e_val);
reduce1_op(reduce1_acc, square_e_val);
}
rs_element_op[I0](reduce0_acc, reduce0_acc);
rs_element_op[I1](reduce1_acc, reduce1_acc);
r0_m_host(m) = ck::type_convert<R0DataType>(reduce0_acc);
r1_m_host(m) = ck::type_convert<R1DataType>(reduce1_acc);
}
e_device_buf.FromDevice(e_m_n.mData.data());
r0_device_buf.FromDevice(r0_m.mData.data());
r1_device_buf.FromDevice(r1_m.mData.data());
pass = ck::utils::check_err(
e_m_n.mData, e_m_n_host.mData, "Error: Incorrect results c", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r0_m.mData, r0_m_host.mData, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(
r1_m.mData, r1_m_host.mData, "Error: Incorrect results d1", 1e-2, 1e-2);
}
bool time_kernel = true;
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
DumpPerf<ADataType, BDataType, EDataType, R0DataType, R1DataType>(ave_time, M, N, K);
}
return pass ? 0 : 1;
}
add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp)
add_example_executable(example_gemm_reduce_xdl_mean_squaremean_fp16 gemm_reduce_xdl_mean_squaremean_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using ReduceOp0 = ck::reduce::Add;
using ReduceOp1 = ck::reduce::Add;
using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using ReduceGlobalMemOps =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
GemmAccDataType,
AElementOp,
BElementOp,
CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
std::cout << "gemm + reduce_mean + reduce_mean_square Perf: " << gemm_reduce_time << " ms, "
<< tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl;
}
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
if(argc == 1)
{
// do nothing
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<ReduceDataType> reduce1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<ReduceDataType> reduce1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl;
std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
// do GEMM
auto gemm = DeviceGemmReduceInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
nullptr,
{},
c_device_buf.GetDeviceBuffer(),
p_reduces,
M,
N,
K,
StrideA,
StrideB,
StrideC,
{},
gemm_element_ops,
{},
reduce_in_element_ops,
reduce_out_element_ops);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
// init reducetion buffer to 0
reduce0_device_buf.SetZero();
reduce1_device_buf.SetZero();
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
// will not be correct. need to set time_kernel = false for correctness test
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
auto reduce0_op = ReduceOp0{};
auto reduce1_op = ReduceOp1{};
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n)
{
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType square_c_val;
square(square_c_val, c_val);
reduce0_op(reduce0_acc, c_val);
reduce1_op(reduce1_acc, square_c_val);
}
div(reduce0_acc, reduce0_acc);
div(reduce1_acc, reduce1_acc);
reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
}
pass = ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData,
"Error: Incorrect results c") &&
ck::utils::check_err(reduce0_m_device_result.mData,
reduce0_m_host_result.mData,
"Error: Incorrect results d0",
1e-4,
1e-5) &&
ck::utils::check_err(reduce1_m_device_result.mData,
reduce1_m_host_result.mData,
"Error: Incorrect results d1",
1e-3,
1e-5);
}
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(ave_time, M, N, K);
}
return pass ? 0 : 1;
}
...@@ -66,8 +66,14 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc ...@@ -66,8 +66,14 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: using ReferenceBatchedGemmInstance =
ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>; ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
CDataType,
ReduceAccDataType,
AElementOp,
BElementOp,
CElementOp>;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -16,28 +16,23 @@ ...@@ -16,28 +16,23 @@
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using ABDataType = F16; using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add; using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ABDataType, ck::Tuple<CDataType>,
CDataType, Add,
EltwiseComputeDataType, 2,
Add, 8,
2, ck::Sequence<8, 8>,
8, ck::Sequence<8>>;
8,
8,
8>;
template <typename HostTensorA, template <typename HostTensorA,
typename HostTensorB, typename HostTensorB,
typename HostTensorC, typename HostTensorC,
typename ComputeDataType,
typename Functor, typename Functor,
int broadcastDim> int broadcastDim>
void host_broadcast2D( void host_broadcast2D(
...@@ -49,19 +44,19 @@ void host_broadcast2D( ...@@ -49,19 +44,19 @@ void host_broadcast2D(
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
ComputeDataType Amn = ck::type_convert<ComputeDataType>(A(m, n)); auto Amn = A(m, n);
ComputeDataType Cmn = 0; ctype Cmn = 0;
if constexpr(broadcastDim == 0) if constexpr(broadcastDim == 0)
{ {
ComputeDataType Bn = ck::type_convert<ComputeDataType>(B(n)); auto Bn = B(n);
functor(Cmn, Amn, Bn); functor(Cmn, Amn, Bn);
} }
else else
{ {
ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m)); auto Bm = B(m);
functor(Cmn, Amn, Bm); functor(Cmn, Amn, Bm);
} }
C(m, n) = ck::type_convert<ctype>(Cmn); C(m, n) = Cmn;
} }
} }
} }
...@@ -103,18 +98,19 @@ int main() ...@@ -103,18 +98,19 @@ int main()
b_n_device_buf.GetDeviceBuffer()}; b_n_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_n_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {c_m_n_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {Stride, 1}; std::array<ck::index_t, 2> abc_lengths = {M, N};
std::vector<ck::index_t> b_strides = {0, 1}; std::array<ck::index_t, 2> a_strides = {Stride, 1};
std::vector<ck::index_t> c_strides = {Stride, 1}; std::array<ck::index_t, 2> b_strides = {0, 1};
std::array<ck::index_t, 2> c_strides = {Stride, 1};
auto broadcastAdd = DeviceElementwiseAddInstance{}; auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer( auto argument = broadcastAdd.MakeArgumentPointer(
input, output, {M, N}, {a_strides, b_strides}, {c_strides}, Add{}); abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error(
"DeviceBinaryElementwise instance, exiting!"); "The runtime parameters seems not supported by the device instance, exiting!");
}; };
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
...@@ -129,12 +125,8 @@ int main() ...@@ -129,12 +125,8 @@ int main()
c_m_n_device_buf.FromDevice(c_m_n.mData.data()); c_m_n_device_buf.FromDevice(c_m_n.mData.data());
Tensor<CDataType> host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); Tensor<CDataType> host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride));
host_broadcast2D<Tensor<ABDataType>, host_broadcast2D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add, 0>(
Tensor<ABDataType>, host_c_m_n, a_m_n, b_n, M, N, Add{});
Tensor<CDataType>,
EltwiseComputeDataType,
Add,
0>(host_c_m_n, a_m_n, b_n, M, N, Add{});
pass &= ck::utils::check_err( pass &= ck::utils::check_err(
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3); c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -16,29 +16,21 @@ ...@@ -16,29 +16,21 @@
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using ABDataType = F16; using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add; using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ABDataType, ck::Tuple<CDataType>,
CDataType, Add,
EltwiseComputeDataType, 3,
Add, 8,
3, ck::Sequence<1, 8>,
8, ck::Sequence<8>>;
1,
8, template <typename HostTensorA, typename HostTensorB, typename HostTensorC, typename Functor>
8>;
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor>
void host_broadcast3D_am_bmnk(HostTensorC& C, void host_broadcast3D_am_bmnk(HostTensorC& C,
const HostTensorA& A, const HostTensorA& A,
const HostTensorB& B, const HostTensorB& B,
...@@ -51,11 +43,11 @@ void host_broadcast3D_am_bmnk(HostTensorC& C, ...@@ -51,11 +43,11 @@ void host_broadcast3D_am_bmnk(HostTensorC& C,
for(std::size_t n = 0; n < shape[1]; ++n) for(std::size_t n = 0; n < shape[1]; ++n)
for(std::size_t k = 0; k < shape[2]; ++k) for(std::size_t k = 0; k < shape[2]; ++k)
{ {
ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(m)); auto a_val = A(m);
ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(m, n, k)); auto b_val = B(m, n, k);
ComputeDataType c_val = 0; ctype c_val = 0;
functor(c_val, a_val, b_val); functor(c_val, a_val, b_val);
C(m, n, k) = ck::type_convert<ctype>(c_val); C(m, n, k) = c_val;
} }
} }
...@@ -85,25 +77,25 @@ int main() ...@@ -85,25 +77,25 @@ int main()
b_m_n_k_device_buf.GetDeviceBuffer()}; b_m_n_k_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_n_k_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {c_m_n_k_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {1, 0, 0}; std::array<ck::index_t, 3> abc_lengths;
std::vector<ck::index_t> b_strides{b_m_n_k.mDesc.GetStrides().begin(), std::array<ck::index_t, 3> a_strides = {1, 0, 0};
b_m_n_k.mDesc.GetStrides().end()}; std::array<ck::index_t, 3> b_strides;
std::vector<ck::index_t> c_strides{c_m_n_k.mDesc.GetStrides().begin(), std::array<ck::index_t, 3> c_strides;
c_m_n_k.mDesc.GetStrides().end()};
std::copy(mnk.begin(), mnk.end(), abc_lengths.begin());
std::copy(
b_m_n_k.mDesc.GetStrides().begin(), b_m_n_k.mDesc.GetStrides().end(), b_strides.begin());
std::copy(
c_m_n_k.mDesc.GetStrides().begin(), c_m_n_k.mDesc.GetStrides().end(), c_strides.begin());
auto broadcastAdd = DeviceElementwiseAddInstance{}; auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = auto argument = broadcastAdd.MakeArgumentPointer(
broadcastAdd.MakeArgumentPointer(input, abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{});
output,
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
{a_strides, b_strides},
{c_strides},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error(
"DeviceBinaryElementwise instance, exiting!"); "The runtime parameters seems not supported by the device instance, exiting!");
}; };
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
...@@ -118,11 +110,8 @@ int main() ...@@ -118,11 +110,8 @@ int main()
c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data()); c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data());
Tensor<CDataType> host_c_m_n_k(mnk); Tensor<CDataType> host_c_m_n_k(mnk);
host_broadcast3D_am_bmnk<Tensor<ABDataType>, host_broadcast3D_am_bmnk<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
Tensor<ABDataType>, host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
pass &= ck::utils::check_err( pass &= ck::utils::check_err(
c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3); c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3);
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <cstdlib> #include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -15,29 +15,21 @@ ...@@ -15,29 +15,21 @@
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using ABDataType = F16; using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add; using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ABDataType, ck::Tuple<CDataType>,
CDataType, Add,
EltwiseComputeDataType, 1,
Add, 8,
1, ck::Sequence<8, 8>,
8, ck::Sequence<8>>;
8,
8, template <typename HostTensorA, typename HostTensorB, typename HostTensorC, typename Functor>
8>;
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor>
void host_elementwise1D( void host_elementwise1D(
HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor) HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor)
{ {
...@@ -45,11 +37,11 @@ void host_elementwise1D( ...@@ -45,11 +37,11 @@ void host_elementwise1D(
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
ComputeDataType Am = ck::type_convert<ComputeDataType>(A(m)); auto Am = A(m);
ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m)); auto Bm = B(m);
ComputeDataType Cm = 0; ctype Cm = 0;
functor(Cm, Am, Bm); functor(Cm, Am, Bm);
C(m) = ck::type_convert<ctype>(Cm); C(m) = Cm;
} }
} }
...@@ -83,18 +75,19 @@ int main() ...@@ -83,18 +75,19 @@ int main()
b_m_device_buf.GetDeviceBuffer()}; b_m_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides = {1}; std::array<ck::index_t, 1> abc_lengths = {M};
std::vector<ck::index_t> b_strides = {1}; std::array<ck::index_t, 1> a_strides = {1};
std::vector<ck::index_t> c_strides = {1}; std::array<ck::index_t, 1> b_strides = {1};
std::array<ck::index_t, 1> c_strides = {1};
auto broadcastAdd = DeviceElementwiseAddInstance{}; auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = broadcastAdd.MakeArgumentPointer( auto argument = broadcastAdd.MakeArgumentPointer(
input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{}); abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error(
"DeviceBinaryElementwise instance, exiting!"); "The runtime parameters seems not supported by the device instance, exiting!");
}; };
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
...@@ -109,11 +102,8 @@ int main() ...@@ -109,11 +102,8 @@ int main()
c_m_device_buf.FromDevice(c_m.mData.data()); c_m_device_buf.FromDevice(c_m.mData.data());
Tensor<CDataType> host_c_m(f_host_tensor_descriptor1d(M, 1)); Tensor<CDataType> host_c_m(f_host_tensor_descriptor1d(M, 1));
host_elementwise1D<Tensor<ABDataType>, host_elementwise1D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
Tensor<ABDataType>, host_c_m, a_m, b_m, M, Add{});
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c_m, a_m, b_m, M, Add{});
pass &= ck::utils::check_err( pass &= ck::utils::check_err(
c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3); c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_binary_elementwise.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -16,29 +16,21 @@ ...@@ -16,29 +16,21 @@
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using ABDataType = F16; using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::element_wise::Add; using Add = ck::tensor_operation::element_wise::Add;
using DeviceElementwiseAddInstance = using DeviceElementwiseAddInstance =
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType, ck::tensor_operation::device::DeviceElementwise<ck::Tuple<ABDataType, ABDataType>,
ABDataType, ck::Tuple<CDataType>,
CDataType, Add,
EltwiseComputeDataType, 4,
Add, 8,
4, ck::Sequence<8, 8>,
8, ck::Sequence<8>>;
8,
8, template <typename HostTensorA, typename HostTensorB, typename HostTensorC, typename Functor>
8>;
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor>
void host_elementwise4D(HostTensorC& C, void host_elementwise4D(HostTensorC& C,
const HostTensorA& A, const HostTensorA& A,
const HostTensorB& B, const HostTensorB& B,
...@@ -52,11 +44,11 @@ void host_elementwise4D(HostTensorC& C, ...@@ -52,11 +44,11 @@ void host_elementwise4D(HostTensorC& C,
for(std::size_t h = 0; h < shape[2]; ++h) for(std::size_t h = 0; h < shape[2]; ++h)
for(std::size_t w = 0; w < shape[3]; ++w) for(std::size_t w = 0; w < shape[3]; ++w)
{ {
ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(n, c, h, w)); auto a_val = A(n, c, h, w);
ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(n, c, h, w)); auto b_val = B(n, c, h, w);
ComputeDataType c_val = 0; ctype c_val = 0;
functor(c_val, a_val, b_val); functor(c_val, a_val, b_val);
C(n, c, h, w) = ck::type_convert<ctype>(c_val); C(n, c, h, w) = c_val;
} }
} }
...@@ -85,23 +77,24 @@ int main() ...@@ -85,23 +77,24 @@ int main()
b_device_buf.GetDeviceBuffer()}; b_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {c_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {c_device_buf.GetDeviceBuffer()};
std::vector<ck::index_t> a_strides{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()}; std::array<ck::index_t, 4> abc_lengths;
std::vector<ck::index_t> b_strides{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()}; std::array<ck::index_t, 4> a_strides;
std::vector<ck::index_t> c_strides{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()}; std::array<ck::index_t, 4> b_strides;
std::array<ck::index_t, 4> c_strides;
std::copy(nchw.begin(), nchw.end(), abc_lengths.begin());
std::copy(a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end(), a_strides.begin());
std::copy(b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end(), b_strides.begin());
std::copy(c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end(), c_strides.begin());
auto broadcastAdd = DeviceElementwiseAddInstance{}; auto broadcastAdd = DeviceElementwiseAddInstance{};
auto argument = auto argument = broadcastAdd.MakeArgumentPointer(
broadcastAdd.MakeArgumentPointer(input, abc_lengths, {a_strides, b_strides}, {c_strides}, input, output, Add{});
output,
std::vector<ck::index_t>{nchw.begin(), nchw.end()},
{{a_strides}, b_strides},
{c_strides},
Add{});
if(!broadcastAdd.IsSupportedArgument(argument.get())) if(!broadcastAdd.IsSupportedArgument(argument.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error(
"DeviceBinaryElementwise instance, exiting!"); "The runtime parameters seems not supported by the device instance, exiting!");
}; };
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
...@@ -116,11 +109,8 @@ int main() ...@@ -116,11 +109,8 @@ int main()
c_device_buf.FromDevice(c.mData.data()); c_device_buf.FromDevice(c.mData.data());
Tensor<CDataType> host_c(nchw); Tensor<CDataType> host_c(nchw);
host_elementwise4D<Tensor<ABDataType>, host_elementwise4D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
Tensor<ABDataType>, host_c, a, b, nchw, Add{});
Tensor<CDataType>,
EltwiseComputeDataType,
Add>(host_c, a, b, nchw, Add{});
pass &= pass &=
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3); ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3);
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -28,57 +28,64 @@ using F32 = float; ...@@ -28,57 +28,64 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using CDataType = F16;
using BiasDataType = F32;
using D0DataType = F16;
using GemmAccDataType = F32; using GemmAccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16;
using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using ReduceAccDataType = F32; using ReduceAccDataType = F32;
using ReduceDataType = F32; using R0DataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>; using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
using GammaDataType = F16; using GammaDataType = F16;
using BetaDataType = F16; using BetaDataType = F16;
using LayerNormOutDataType = F16; using LayerNormOutDataType = F16;
using NormalizeComputeDataType = F32; using NormalizeComputeDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor; // Layout
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using ALayout = Row;
using CLayout = ck::tensor_layout::gemm::RowMajor; using BLayout = Col;
using D1Layout = Row;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ELayout = D1Layout;
using AElementOp = PassThrough;
using BElementOp = PassThrough; // Elementwise op
using CElementOp = ck::tensor_operation::element_wise::Relu; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using D0ElementOp = PassThrough; using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using ReduceSumOp = ck::reduce::Add; using Square = ck::tensor_operation::element_wise::UnarySquare;
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>; using Div = ck::tensor_operation::element_wise::UnaryDivide;
using AElementOp = PassThrough;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; using CDEElementOp = AddReluAdd;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; using QsElementOp = ck::Tuple<PassThrough, Square>;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>; using RsElementOp = ck::Tuple<Div, Div>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
// ReduceOp
using ReduceGlobalMemOps = using R0ThreadReduceOp = ck::reduce::Add;
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, using R1ThreadReduceOp = ck::reduce::Add;
ck::InMemoryDataOperationEnum::AtomicAdd>; using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
static constexpr auto GemmSpecialization = static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, EDataType,
GemmAccDataType, GemmAccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
...@@ -87,23 +94,18 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp ...@@ -87,23 +94,18 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y // A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance = using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::tensor_operation::device::Device5AryElementwise<CDataType, ck::Tuple<EDataType,
ReduceDataType, R0DataType,
ReduceDataType, R1DataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType>, // x(gemm_out), mean, meansquare, gamma, beta
LayerNormOutDataType, ck::Tuple<LayerNormOutDataType>, // y
NormalizeComputeDataType, NormalizeFunctor,
NormalizeFunctor, 2,
2, 8, // MPerthread
8, ck::Sequence<8, 1, 1, 8, 8>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta
8, // scalarPerVector: gemm_out ck::Sequence<8>>; // scalarPerVector: y(layerNorm_out)
1, // scalarPerVector: reduce_mean
1, // scalarPerVector: reduce_mean_square
8, // scalarPerVector: Gamma
8, // scalarPerVector: Beta
8>; // scalarPerVector: LayerNorm_out
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}), return HostTensorDescriptor(std::vector<std::size_t>({len}),
...@@ -124,41 +126,31 @@ auto f_host_tensor_descriptor2d = ...@@ -124,41 +126,31 @@ auto f_host_tensor_descriptor2d =
} }
}; };
template <typename CDataType,
typename ReduceDataType,
typename AccDataType,
typename BiasDataType,
typename D0DataType,
typename A_functor,
typename B_functor,
typename C_functor,
typename C1_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n, const Tensor<BDataType>& b_k_n,
const Tensor<BiasDataType>& bias_n, const Tensor<D0DataType>& bias_n,
const Tensor<D0DataType>& c1_m_n, const Tensor<D1DataType>& d1_m_n,
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<GammaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
A_functor a_element_op, AElementOp a_element_op,
B_functor b_element_op, BElementOp b_element_op,
C_functor c_element_op, CDEElementOp cde_element_op,
C1_functor c1_element_op,
int M, int M,
int N) int N)
{ {
int StrideC = N; int StrideE = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R0DataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R1DataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N}; auto averageOpInst = Div{N};
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); ref_gemm.MakeArgument(a_m_k, b_k_n, e_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
...@@ -166,38 +158,32 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -166,38 +158,32 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
AccDataType acc = ck::type_convert<AccDataType>(c_m_n(m, n)) + auto acc = ck::type_convert<GemmAccDataType>(e_m_n(m, n));
ck::type_convert<AccDataType>(bias_n(n)); cde_element_op(e_m_n(m, n), acc, bias_n(n), d1_m_n(m, n));
AccDataType c1 = ck::type_convert<AccDataType>(c1_m_n(m, n));
c_element_op(acc, acc);
c1_element_op(c1, c1);
acc += c1;
c_m_n(m, n) = ck::type_convert<CDataType>(acc);
} }
// reduce_mean and reduce_square_mean // reduce_mean and reduce_square_mean
auto reduceSumOpInst = ReduceSumOp{}; auto r0Op = R0ThreadReduceOp{};
auto r1Op = R1ThreadReduceOp{};
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
auto mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>(); auto mean_acc = r0Op.GetIdentityValue<ReduceAccDataType>();
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>(); auto mean_square_acc = r1Op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
AccDataType c_val = ck::type_convert<AccDataType>(c_m_n(m, n)); auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n(m, n));
AccDataType square_c_val = 0; ReduceAccDataType square_e_val = 0;
UnarySquareElementOp{}(square_c_val, c_val); Square{}(square_e_val, e_val);
reduceSumOpInst(mean_acc, c_val); r0Op(mean_acc, e_val);
reduceSumOpInst(square_mean_acc, square_c_val); r1Op(mean_square_acc, square_e_val);
} }
averageOpInst(mean_acc, mean_acc); averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc); averageOpInst(mean_square_acc, mean_square_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc); mean_m(m) = ck::type_convert<R0DataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc); meanSquare_m(m) = ck::type_convert<R1DataType>(mean_square_acc);
} }
// LayerNorm // LayerNorm
...@@ -206,24 +192,20 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -206,24 +192,20 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
AccDataType out_acc = 0; LayerNormOutDataType out_val = 0;
layerNormInst(out_acc, layerNormInst(out_val, e_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
ck::type_convert<AccDataType>(c_m_n(m, n)), out_m_n(m, n) = out_val;
ck::type_convert<AccDataType>(mean_m(m)),
ck::type_convert<AccDataType>(meanSquare_m(m)),
ck::type_convert<AccDataType>(gamma_n(n)),
ck::type_convert<AccDataType>(beta_n(n)));
out_m_n(m, n) = ck::type_convert<ReduceDataType>(out_acc);
} }
} }
} }
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename EDataType,
typename BiasDataType,
typename D0DataType, typename D0DataType,
typename ReduceDataType, typename D1DataType,
typename R0DataType,
typename R1DataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename NormalizeDataType> typename NormalizeDataType>
...@@ -231,12 +213,12 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, ...@@ -231,12 +213,12 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{ {
std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N + sizeof(EDataType) * M * N + sizeof(D0DataType) * M * N +
sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M + sizeof(D0DataType) * M * N + sizeof(R0DataType) * M +
sizeof(ReduceDataType) * M; sizeof(R1DataType) * M;
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + std::size_t normalize_num_byte = sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N + sizeof(R1DataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time; float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
...@@ -259,37 +241,37 @@ int main() ...@@ -259,37 +241,37 @@ int main()
ck::index_t StrideA = 1024; ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024; ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024; ck::index_t StrideD0 = 0;
ck::index_t StrideD0 = 1024; ck::index_t StrideD1 = 1024;
ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<D0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1)); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, ELayout{}));
Tensor<D0DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R0DataType> r0_Mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R1DataType> r1_MeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1)); Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1)); Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n( Tensor<LayerNormOutDataType> layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
bias_n.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1, 1}); bias_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-1, 1});
c1_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-5, 5}); d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-5, 5});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1}); gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1}); beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize()); DeviceMem bias_device_buf(sizeof(D0DataType) * bias_n.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * DeviceMem r0_Mean_device_buf(sizeof(R0DataType) * r0_Mean_m.mDesc.GetElementSpaceSize());
reduceMean_m.mDesc.GetElementSpaceSize()); DeviceMem r1_MeanSquare_device_buf(sizeof(R1DataType) *
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * r1_MeanSquare_m.mDesc.GetElementSpaceSize());
reduceMeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize()); DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize()); DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
...@@ -298,106 +280,96 @@ int main() ...@@ -298,106 +280,96 @@ int main()
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
bias_device_buf.ToDevice(bias_n.mData.data()); bias_device_buf.ToDevice(bias_n.mData.data());
d0_device_buf.ToDevice(c1_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data()); gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data()); beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto cde_element_op = CDEElementOp{};
auto d_element_op = D0ElementOp{}; auto qs_element_op = QsElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; auto rs_element_op = RsElementOp{N, N};
auto passthrough = UnaryIdenticElementOp{};
auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(), // Prepare GEMM, mean, mean_square
reduceMeanSquare_device_buf.GetDeviceBuffer()}; auto gemmReduce = DeviceOpInstance{};
// Prepare GEMM, reduce_mean, reduce_mean_square
auto gemmReduce = DeviceGemmBiasAddReduceInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker(); auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(), auto gemmReduce_argument = gemmReduce.MakeArgument(
b_device_buf.GetDeviceBuffer(), a_device_buf.GetDeviceBuffer(),
bias_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer()}, {bias_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
c_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
p_reduces, {r0_Mean_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer()},
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, {StrideD0, StrideD1},
{StrideD0}, StrideE,
gemm_element_ops, a_element_op,
{&d_element_op}, b_element_op,
reduce_in_element_ops, cde_element_op,
reduce_out_element_ops); qs_element_op,
rs_element_op);
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
{ {
throw std::runtime_error( throw std::runtime_error("wrong! this device_op instance does not support this problem");
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
} }
reduceMean_device_buf.SetZero(); // init reducetion buffer to 0
reduceMeanSquare_device_buf.SetZero(); r0_Mean_device_buf.SetZero();
r1_MeanSquare_device_buf.SetZero();
// Prepare LayerNorm // Prepare LayerNorm
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(), std::array<const void*, 5> input = {e_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(), r0_Mean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer()}; beta_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
auto normalize = DeviceNormalizeInstance{}; std::array<ck::index_t, 2> xyLengths = {M, N};
auto normalize_invoker = normalize.MakeInvoker(); std::array<ck::index_t, 2> xyStrides = {StrideE, 1};
auto normalize_argument = normalize.MakeArgument(input,
output, auto normalize = DeviceNormalizeInstance{};
{M, N}, auto normalize_invoker = normalize.MakeInvoker();
{StrideC, 1}, auto normalize_argument_ptr =
{1, 0}, normalize.MakeArgumentPointer(xyLengths,
{1, 0}, {xyStrides, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
{0, 1}, {xyStrides},
{0, 1}, input,
{StrideC, 1}, output,
NormalizeFunctor{}); NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument)) if(!normalize.IsSupportedArgument(normalize_argument_ptr.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error(
"Device5AryElementwise instance, exiting!"); "The runtime parameters seems not supported by the device, exiting!");
} }
// run kernel // run kernel
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false}); gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false}); normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, false});
bool pass = true; bool pass = true;
{ {
// verification // verification
Tensor<LayerNormOutDataType> host_layerNorm_m_n( Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
host_gemm_layernorm<CDataType, ReduceDataType, ReduceAccDataType>(host_layerNorm_m_n, host_gemm_layernorm(host_layerNorm_m_n,
a_m_k, a_m_k,
b_k_n, b_k_n,
bias_n, bias_n,
c1_m_n, d1_m_n,
gamma_n, gamma_n,
beta_n, beta_n,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, cde_element_op,
d_element_op, M,
M, N);
N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n.mData, pass &= ck::utils::check_err(layerNorm_m_n.mData,
...@@ -414,15 +386,16 @@ int main() ...@@ -414,15 +386,16 @@ int main()
float gemm_reduce_mean_reduce_square_mean_ave_time = float gemm_reduce_mean_reduce_square_mean_ave_time =
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
float normalize_ave_time = float normalize_ave_time =
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel}); normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel) if(time_kernel)
DumpGemmLayerNormPerf<ADataType, DumpGemmLayerNormPerf<ADataType,
BDataType, BDataType,
CDataType, EDataType,
BiasDataType,
D0DataType, D0DataType,
ReduceDataType, D1DataType,
R0DataType,
R1DataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
LayerNormOutDataType>( LayerNormOutDataType>(
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -28,78 +28,83 @@ using F32 = float; ...@@ -28,78 +28,83 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
// DataType
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32; using GemmAccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ReduceAccDataType = F32; using ReduceAccDataType = F32;
using ReduceDataType = F32; using R0DataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>; using R1DataType = F32;
using RsDataType = ck::Tuple<R0DataType, R1DataType>;
using GammaDataType = F16; using GammaDataType = F16;
using BetaDataType = F16; using BetaDataType = F16;
using LayerNormOutDataType = F16; using LayerNormOutDataType = F16;
using NormalizeComputeDataType = F32; using NormalizeComputeDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor; // Layout
using BLayout = ck::tensor_layout::gemm::ColumnMajor; using ALayout = Row;
using CLayout = ck::tensor_layout::gemm::RowMajor; using BLayout = Col;
using D1Layout = Row;
using AElementOp = ck::tensor_operation::element_wise::PassThrough; using ELayout = D1Layout;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough; // Elementwise op
using ReduceSumOp = ck::reduce::Add; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>; using Square = ck::tensor_operation::element_wise::UnarySquare;
using Div = ck::tensor_operation::element_wise::UnaryDivide;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; using BElementOp = PassThrough;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; using CDEElementOp = PassThrough;
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>; using QsElementOp = ck::Tuple<PassThrough, Square>;
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>; using RsElementOp = ck::Tuple<Div, Div>;
using ReduceGlobalMemOps = // ReduceOp
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, using R0ThreadReduceOp = ck::reduce::Add;
ck::InMemoryDataOperationEnum::AtomicAdd>; using R1ThreadReduceOp = ck::reduce::Add;
using RsThreadReduceOp = ck::Tuple<R0ThreadReduceOp, R1ThreadReduceOp>;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto R0GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
static constexpr auto R1GlobalReduceOp = ck::InMemoryDataOperationEnum::AtomicAdd;
using RsGlobalReduceOp = ck::InMemoryDataOperationEnumSequence<R0GlobalReduceOp, R1GlobalReduceOp>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off // clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultipleR_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| ELayout| AData| BData| GemmAccData| CShuffle| DsData| EData| ReduceAccData| RsData| A| B| CDE| Qs| Rs| Thread| Global| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDRThreadTransfer| CDE| RThreadTransfer|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, EDataType,
GemmAccDataType, GemmAccDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CElementOp>; PassThrough>;
using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y // A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance = using DeviceNormalizeInstance = ck::tensor_operation::device::DeviceElementwise<
ck::tensor_operation::device::Device5AryElementwise<CDataType, ck::Tuple<EDataType,
ReduceDataType, R0DataType,
ReduceDataType, R1DataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType>, // x(gemm_out), mean,
LayerNormOutDataType, // meansquare,
NormalizeComputeDataType, // gamma, beta
NormalizeFunctor, ck::Tuple<LayerNormOutDataType>, // y
2, NormalizeFunctor,
8, 2,
8, // scalarPerVector: gemm_out 8, // MPerthread
1, // scalarPerVector: reduce_mean ck::Sequence<8, 1, 1, 8, 8>, // scalarPerVector: x(gemm_out), mean, meansquare, gamma, beta
1, // scalarPerVector: reduce_mean_square ck::Sequence<8>>; // scalarPerVector: y(layerNorm_out)
8, // scalarPerVector: Gamma
8, // scalarPerVector: Beta
8>; // scalarPerVector: LayerNorm_out
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}), return HostTensorDescriptor(std::vector<std::size_t>({len}),
...@@ -120,60 +125,53 @@ auto f_host_tensor_descriptor2d = ...@@ -120,60 +125,53 @@ auto f_host_tensor_descriptor2d =
} }
}; };
template <typename CDataType,
typename ReduceDataType,
typename A_functor,
typename B_functor,
typename C_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n, const Tensor<BDataType>& b_k_n,
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
A_functor a_element_op, AElementOp a_element_op,
B_functor b_element_op, BElementOp b_element_op,
C_functor c_element_op, CDEElementOp c_element_op,
int M, int M,
int N) int N)
{ {
using out_type = ck::remove_reference_t<decltype(out_m_n(0, 0))>; int StrideE = N;
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
int StrideC = N; Tensor<R0DataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<R1DataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1)); auto averageOpInst = Div{N};
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N};
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op); ref_gemm.MakeArgument(a_m_k, b_k_n, e_m_n, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
// reduce_mean and reduce_square_mean // reduce_mean and reduce_square_mean
auto reduceSumOpInst = ReduceSumOp{}; auto r0Op = R0ThreadReduceOp{};
auto r1Op = R1ThreadReduceOp{};
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
auto mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>(); auto mean_acc = r0Op.GetIdentityValue<ReduceAccDataType>();
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>(); auto mean_square_acc = r1Op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n)); auto e_val = ck::type_convert<ReduceAccDataType>(e_m_n(m, n));
auto square_c_val = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>(); ReduceAccDataType square_e_val = 0;
Square{}(square_e_val, e_val);
UnarySquareElementOp{}(square_c_val, c_val);
reduceSumOpInst(mean_acc, c_val); r0Op(mean_acc, e_val);
reduceSumOpInst(square_mean_acc, square_c_val); r1Op(mean_square_acc, square_e_val);
} }
averageOpInst(mean_acc, mean_acc); averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc); averageOpInst(mean_square_acc, mean_square_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc); mean_m(m) = ck::type_convert<R0DataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc); meanSquare_m(m) = ck::type_convert<R1DataType>(mean_square_acc);
} }
// LayerNorm // LayerNorm
...@@ -182,22 +180,18 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -182,22 +180,18 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
float out_f32 = 0; LayerNormOutDataType out_val = 0;
layerNormInst(out_f32, layerNormInst(out_val, e_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
static_cast<float>(c_m_n(m, n)), out_m_n(m, n) = out_val;
static_cast<float>(mean_m(m)),
static_cast<float>(meanSquare_m(m)),
static_cast<float>(gamma_n(n)),
static_cast<float>(beta_n(n)));
out_m_n(m, n) = static_cast<out_type>(out_f32);
} }
} }
} }
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename EDataType,
typename ReduceDataType, typename R0DataType,
typename R1DataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename NormalizeDataType> typename NormalizeDataType>
...@@ -205,11 +199,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, ...@@ -205,11 +199,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{ {
std::size_t gemm_flop = std::size_t(2) * M * N * K; std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(ReduceDataType) * M; sizeof(R1DataType) * M;
std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M + std::size_t normalize_num_btye = sizeof(EDataType) * M * N + sizeof(R0DataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N + sizeof(R1DataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time; float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
...@@ -232,17 +226,17 @@ int main() ...@@ -232,17 +226,17 @@ int main()
ck::index_t StrideA = 1024; ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024; ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024; ck::index_t StrideE = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R0DataType> r0_Mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); Tensor<R1DataType> r1_MeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1)); Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1)); Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n( Tensor<LayerNormOutDataType> layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
...@@ -251,11 +245,10 @@ int main() ...@@ -251,11 +245,10 @@ int main()
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * DeviceMem r0_Mean_device_buf(sizeof(R0DataType) * r0_Mean_m.mDesc.GetElementSpaceSize());
reduceMean_m.mDesc.GetElementSpaceSize()); DeviceMem r1_MeanSquare_device_buf(sizeof(R1DataType) *
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * r1_MeanSquare_m.mDesc.GetElementSpaceSize());
reduceMeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize()); DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize()); DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
...@@ -266,40 +259,33 @@ int main() ...@@ -266,40 +259,33 @@ int main()
gamma_device_buf.ToDevice(gamma_n.mData.data()); gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data()); beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto cde_element_op = CDEElementOp{};
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op}; auto qs_element_op = QsElementOp{};
auto rs_element_op = RsElementOp{N, N};
auto passthrough = UnaryIdenticElementOp{}; // Prepare GEMM, mean, mean_square
auto square = UnarySquareElementOp{}; auto gemmReduce = DeviceOpInstance{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer()};
// Prepare GEMM, reduce_mean, reduce_mean_square
auto gemmReduce = DeviceGemmReduceInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker(); auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(), auto gemmReduce_argument = gemmReduce.MakeArgument(
b_device_buf.GetDeviceBuffer(), a_device_buf.GetDeviceBuffer(),
nullptr, b_device_buf.GetDeviceBuffer(),
{}, {},
c_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
p_reduces, {r0_Mean_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer()},
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, {},
{}, StrideE,
gemm_element_ops, a_element_op,
{}, b_element_op,
reduce_in_element_ops, cde_element_op,
reduce_out_element_ops); qs_element_op,
rs_element_op);
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
{ {
...@@ -308,56 +294,56 @@ int main() ...@@ -308,56 +294,56 @@ int main()
"not support this GEMM problem"); "not support this GEMM problem");
} }
reduceMean_device_buf.SetZero(); r0_Mean_device_buf.SetZero();
reduceMeanSquare_device_buf.SetZero(); r1_MeanSquare_device_buf.SetZero();
// Prepare LayerNorm // Prepare LayerNorm
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(), std::array<const void*, 5> input = {e_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(), r0_Mean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(), r1_MeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer()}; beta_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
auto normalize = DeviceNormalizeInstance{}; std::array<ck::index_t, 2> xyLengths = {M, N};
auto normalize_invoker = normalize.MakeInvoker(); std::array<ck::index_t, 2> xyStrides = {StrideE, 1};
auto normalize_argument = normalize.MakeArgument(input,
output, auto normalize = DeviceNormalizeInstance{};
{M, N}, auto normalize_invoker = normalize.MakeInvoker();
{StrideC, 1}, auto normalize_argument_ptr =
{1, 0}, normalize.MakeArgumentPointer(xyLengths,
{1, 0}, {xyStrides, {1, 0}, {1, 0}, {0, 1}, {0, 1}},
{0, 1}, {xyStrides},
{0, 1}, input,
{StrideC, 1}, output,
NormalizeFunctor{}); NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument)) if(!normalize.IsSupportedArgument(normalize_argument_ptr.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error(
"Device5AryElementwise instance, exiting!"); "The runtime parameters seems not supported by the device, exiting");
} }
// run kernel // run kernel
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false}); gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false}); normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, false});
bool pass = true; bool pass = true;
{ {
// verification // verification
Tensor<LayerNormOutDataType> host_layerNorm_m_n( Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); f_host_tensor_descriptor2d(M, N, StrideE, ELayout{}));
host_gemm_layernorm<CDataType, ReduceDataType>(host_layerNorm_m_n, host_gemm_layernorm(host_layerNorm_m_n,
a_m_k, a_m_k,
b_k_n, b_k_n,
gamma_n, gamma_n,
beta_n, beta_n,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, cde_element_op,
M, M,
N); N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n.mData, pass &= ck::utils::check_err(layerNorm_m_n.mData,
...@@ -374,13 +360,14 @@ int main() ...@@ -374,13 +360,14 @@ int main()
float gemm_reduce_mean_reduce_square_mean_ave_time = float gemm_reduce_mean_reduce_square_mean_ave_time =
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
float normalize_ave_time = float normalize_ave_time =
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel}); normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel) if(time_kernel)
DumpGemmLayerNormPerf<ADataType, DumpGemmLayerNormPerf<ADataType,
BDataType, BDataType,
CDataType, EDataType,
ReduceDataType, R0DataType,
R1DataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
LayerNormOutDataType>( LayerNormOutDataType>(
......
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl
// clang-format off
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
const int M = 256;
const int N = 128;
const int K = 64;
const int stride_A = K;
const int stride_B = K;
const int batch_stride_A = M * K;
const int batch_stride_B = K * N;
const int G0 = 16;
const int G1 = 8;
const int batch_count = G0 * G1;
// output layout - [G0, M, G1, N]
const int stride_G0 = M * G1 * N;
const int stride_G1 = N;
const int stride_M = G1 * N;
const int stride_N = 1;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
// GEMM shape
ck::tensor_operation::device::BatchedGemmEPermuteDesc batched_gemm_e_permute_desc{
G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
Tensor<ADataType> a_g_m_k(
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
auto f_host_e_tensor_descriptor = [](std::size_t G0_,
std::size_t G1_,
std::size_t M_,
std::size_t N_,
std::size_t stride_G0_,
std::size_t stride_G1_,
std::size_t stride_M_,
std::size_t stride_N_) {
return HostTensorDescriptor(
std::vector<std::size_t>({G0_, G1_, M_, N_}),
std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
};
Tensor<EDataType> e_g0_g1_m_n_host_result(
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
Tensor<EDataType> e_g0_g1_m_n_device_result(
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "e_g0_g1_m_n: " << e_g0_g1_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
batch_stride_A,
batch_stride_B,
batched_gemm_e_permute_desc,
batch_count,
a_element_op,
b_element_op,
cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * K * N +
sizeof(EDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
e_device_buf.FromDevice(e_g0_g1_m_n_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor(
std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
for(int g0 = 0; g0 < G0; g0++)
{
for(int g1 = 0; g1 < G1; g1++)
{
for(int m = 0; m < M; m++)
{
for(int n = 0; n < N; n++)
{
int g = g0 * G1 + g1;
e_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
}
}
}
}
pass = ck::utils::check_err(e_g0_g1_m_n_host_result.mData,
e_g0_g1_m_n_device_result.mData,
"Error: Incorrect results c");
}
return pass ? 0 : 1;
}
add_example_executable(example_gemm_bias_e_permute_xdl_fp16 gemm_bias_e_permute_xdl_fp16.cpp) add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
static constexpr ck::index_t NumDimG = 1;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 3;
static constexpr ck::index_t NumDimK = 1;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
using DeviceOpInstanceKKNN = ck::tensor_operation::device::
//############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 1 && NumDimM == 2 && NumDimN == 3 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G1_M2_N3_K1::Argument;
float Run(const Argument& arg)
{
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a, ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, k0)));
arg.b_element_op_(
v_b,
ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, n2, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, m0, m1, n0, n1, n2) = v_c;
};
make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_M3_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t G0 = 1;
ck::index_t M0 = 4;
ck::index_t M1 = 256;
ck::index_t N0 = 4;
ck::index_t N1 = 16;
ck::index_t N2 = 32;
ck::index_t K0 = 256;
// A[M0, M1, M2, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{M0 * M1 * K0, M1 * K0, K0, 1};
// B[N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, N0, N1, N2, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{N0 * N1 * N2 * K0, N1 * N2 * K0, N2 * K0, K0, 1};
// D[N0, M0, N1, M1, N2]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2};
std::vector<ck::index_t> d_gs_ms_ns_strides{N0 * N1 * N2, 0, 0, N1 * N2, N2, 1};
// E[N0, M0, N1, M1, N2]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2};
std::vector<ck::index_t> e_gs_ms_ns_strides{
M0 * M1 * N0 * N1 * N2, N1 * M1 * N2, N2, M0 * N1 * M1 * N2, M1 * N2, 1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
Tensor<BDataType> b_gs_ns_ks(
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_gs_ms_ks,
b_gs_ns_ks,
c_gs_ms_ns_host_result,
a_element_op,
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1)
{
for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n2)
{
cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
d_gs_ms_ns(g0, m0, m1, n0, n1, n2));
}
}
}
}
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
static constexpr ck::index_t NumDimG = 1;
static constexpr ck::index_t NumDimM = 3;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 1;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
using DeviceOpInstanceKKNN = ck::tensor_operation::device::
//############################################| NumDimG| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| Gemm| A| B| DE| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// clang-format on
using DeviceOpInstance = DeviceOpInstanceKKNN;
template <ck::index_t NumDimG,
ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ck::enable_if_t<NumDimG == 1 && NumDimM == 3 && NumDimN == 2 && NumDimK == 1, bool> =
false>
struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::BaseOperator
{
// Argument
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public ck::tensor_operation::device::BaseInvoker
{
using Argument = ReferenceContraction_G1_M3_N2_K1::Argument;
float Run(const Argument& arg)
{
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
AccDataType v_acc = 0;
for(int k0 = 0; k0 < K0; ++k0)
{
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(
v_a,
ck::type_convert<const AccDataType>(arg.a_gs_ms_ks_(g0, m0, m1, m2, k0)));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_gs_ns_ks_(g0, n0, n1, k0)));
v_acc += v_a * v_b;
}
AccDataType v_c;
arg.cde_element_op_(v_c, v_acc);
arg.e_gs_ms_ns_(g0, m0, m1, m2, n0, n1) = v_c;
};
make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
{
return true;
}
static auto MakeArgument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{
a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceContraction_G1_M3_N2_K1"
<< std::endl;
// clang-format on
return str.str();
}
};
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t G0 = 1;
ck::index_t M0 = 4;
ck::index_t M1 = 8;
ck::index_t M2 = 256;
ck::index_t N0 = 32;
ck::index_t N1 = 128;
ck::index_t K0 = 1024;
// A[M0, M1, M2, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, M0, M1, M2, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{M0 * M1 * M2 * K0, M1 * M2 * K0, M2 * K0, K0, 1};
// B[N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{N0 * N1 * K0, N1 * K0, K0, 1};
// D[M0, N0, M1, N1, M2]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, M0, M1, M2, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{N0 * N1, 0, 0, 0, N1, 1};
// E[M1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, M0, M1, M2, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
M0 * M1 * M2 * N1 * N0, N0 * M1 * N1, N1, M0 * N0 * M1 * N1, M1 * N1, 1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
Tensor<BDataType> b_gs_ns_ks(
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin(),
e_gs_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1<NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceOpInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_gs_ms_ks,
b_gs_ns_ks,
c_gs_ms_ns_host_result,
a_element_op,
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1)
{
for(size_t m2 = 0; m2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m2)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n1)
{
cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),
c_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),
d_gs_ms_ns(g0, m0, m1, m2, n0, n1));
}
}
}
}
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
}
return 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasEPermute_Xdl
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::index_t M0 = 4;
ck::index_t M1 = 32;
ck::index_t M2 = 128;
ck::index_t N0 = 16;
ck::index_t N1 = 256;
// GEMM shape
ck::index_t M = M0 * M1 * M2;
ck::index_t N = N0 * N1;
ck::index_t K = 128;
ck::index_t stride_A = K;
ck::index_t stride_B = K;
#if 1
// E = [M0, N0, M1, N1, M2]
ck::index_t stride_E_M0 = N0 * M1 * N1 * M2;
ck::index_t stride_E_M1 = N1 * M2;
ck::index_t stride_E_M2 = 1;
ck::index_t stride_E_N0 = M1 * N1 * M2;
ck::index_t stride_E_N1 = M2;
// D = [0, N0, 0, N1, 0]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
#else
// D = [0, 0, 0, N0, N1]
ck::index_t stride_D_M0 = 0;
ck::index_t stride_D_M1 = 0;
ck::index_t stride_D_M2 = 0;
ck::index_t stride_D_N0 = N1;
ck::index_t stride_D_N1 = 1;
// E = [M0, M1, M2, N0, N1]
ck::index_t stride_E_M0 = M1 * M2 * N0 * N1;
ck::index_t stride_E_M1 = M2 * N0 * N1;
ck::index_t stride_E_M2 = N0 * N1;
ck::index_t stride_E_N0 = N1;
ck::index_t stride_E_N1 = 1;
#endif
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc{
M0, M1, M2, N0, N1, stride_D_M0, stride_D_M1, stride_D_M2, stride_D_N0, stride_D_N1};
const ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc{
M0, M1, M2, N0, N1, stride_E_M0, stride_E_M1, stride_E_M2, stride_E_N0, stride_E_N1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
auto f_host_de_tensor_descriptor =
[](ck::tensor_operation::device::DEGridDesc_M0_M1_M2_N0_N1 de_grid_desc) {
std::size_t m0 = de_grid_desc.M0_;
std::size_t m1 = de_grid_desc.M1_;
std::size_t m2 = de_grid_desc.M2_;
std::size_t n0 = de_grid_desc.N0_;
std::size_t n1 = de_grid_desc.N1_;
std::size_t stride_m0 = de_grid_desc.stride_M0_;
std::size_t stride_m1 = de_grid_desc.stride_M1_;
std::size_t stride_m2 = de_grid_desc.stride_M2_;
std::size_t stride_n0 = de_grid_desc.stride_N0_;
std::size_t stride_n1 = de_grid_desc.stride_N1_;
return HostTensorDescriptor(
std::vector<std::size_t>({m0, m1, m2, n0, n1}),
std::vector<std::size_t>({stride_m0, stride_m1, stride_m2, stride_n0, stride_n1}));
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
Tensor<DDataType> d_m0_m1_m2_n0_n1(f_host_de_tensor_descriptor(d_grid_desc));
Tensor<EDataType> e_m0_m1_m2_n0_n1_host_result(f_host_de_tensor_descriptor(e_grid_desc));
Tensor<EDataType> e_m0_m1_m2_n0_n1_device_result(f_host_de_tensor_descriptor(e_grid_desc));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m0_m1_m2_n0_n1: " << d_m0_m1_m2_n0_n1.mDesc << std::endl;
std::cout << "e_m0_m1_m2_n0_n1: " << e_m0_m1_m2_n0_n1_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) *
d_m0_m1_m2_n0_n1.mDesc.GetElementSpaceSize());
DeviceMem e_m0_m1_m2_n0_n1_device_buf(
sizeof(EDataType) * e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
d_m0_m1_m2_n0_n1_device_buf.ToDevice(d_m0_m1_m2_n0_n1.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
b_k_n_device_buf.GetDeviceBuffer(),
d_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
e_m0_m1_m2_n0_n1_device_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
d_grid_desc,
e_grid_desc,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< device_op.GetTypeString() << std::endl;
if(do_verification)
{
Tensor<AccDataType> c_m_n(HostTensorDescriptor(
std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m0 = 0; m0 < M0; ++m0)
for(int m1 = 0; m1 < M1; ++m1)
for(int m2 = 0; m2 < M2; ++m2)
for(int n0 = 0; n0 < N0; ++n0)
for(int n1 = 0; n1 < N1; ++n1)
{
int m = m0 * M1 * M2 + m1 * M2 + m2;
int n = n0 * N1 + n1;
cde_element_op(e_m0_m1_m2_n0_n1_host_result(m0, m1, m2, n0, n1),
ck::type_convert<EDataType>(c_m_n(m, n)),
d_m0_m1_m2_n0_n1(m0, m1, m2, n0, n1));
}
e_m0_m1_m2_n0_n1_device_buf.FromDevice(e_m0_m1_m2_n0_n1_device_result.mData.data());
return ck::utils::check_err(e_m0_m1_m2_n0_n1_device_result.mData,
e_m0_m1_m2_n0_n1_host_result.mData)
? 0
: 1;
}
return 0;
}
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm.hpp" #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -29,24 +29,24 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -29,24 +29,24 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType, using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
YDataType, YDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim, NumReduceDim,
256, // BlockSize 256, // BlockSize
8, // ClusterM 8, // ClusterM
32, // ClusterK 32, // ClusterK
1, // SliceM 1, // SliceM
8, // SliceK 8, // SliceK
1, // SrcVecDim (0=M, 1=K) 1, // SrcVecDim (0=M, 1=K)
8, // SrcScalarPerVector 8, // SrcScalarPerVector
8, // GammaScalarPerVector 8, // GammaScalarPerVector
8, // BetaScalarPerVector 8, // BetaScalarPerVector
1>; // OutScalarPerVector 8>; // OutScalarPerVector
int main() int main()
{ {
...@@ -90,6 +90,7 @@ int main() ...@@ -90,6 +90,7 @@ int main()
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()}, std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()}, std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
{1}, {1},
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
......
add_example_executable(example_grouped_gemm_bias_xdl_fp16 grouped_gemm_bias_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using DsLayout = ck::Tuple<DLayout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
int group_count = rand() % 16 + 1;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b;
std::vector<std::array<const void*, 1>> p_ds;
std::vector<void*> p_c;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
int M = 256 + 256 * i;
int N = 128 + 128 * i;
int K = 64 + 64 * i;
int stride_A = K;
int stride_B = K;
int stride_C = N;
std::vector<ck::index_t> stride_Ds = {0};
gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, stride_Ds});
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<DDataType>> d_tensors;
std::vector<Tensor<EDataType>> e_host_tensors;
std::vector<Tensor<EDataType>> e_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d_tensors.reserve(group_count);
e_host_tensors.reserve(group_count);
e_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d_tensors_device,
e_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d_tensors_device.reserve(group_count);
e_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
d_tensors.push_back(Tensor<DDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_Ds_[0], ELayout{})));
e_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
e_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << e_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize();
switch(init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
d_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
}
}
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
d_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(DDataType) * d_tensors[i].mDesc.GetElementSpaceSize()));
e_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSpaceSize()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
d_tensors_device[i]->ToDevice(d_tensors[i].mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_ds.push_back({d_tensors_device[i]->GetDeviceBuffer()});
p_c.push_back(e_tensors_device[i]->GetDeviceBuffer());
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, cde_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
bool pass = true;
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
e_host_tensors[i],
a_element_op,
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < gemm_descs[i].M_; ++m)
{
for(int n = 0; n < gemm_descs[i].N_; ++n)
{
cde_element_op(
e_host_tensors[i](m, n), e_host_tensors[i](m, n), d_tensors[i](m, n));
}
}
pass &= ck::utils::check_err(e_device_tensors[i].mData, e_host_tensors[i].mData);
}
}
return pass ? 0 : 1;
}
add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment