Commit fdddc8f4 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/composable_kernel...

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/composable_kernel into barkocot/fix-cmake-tensor-op-instance
parents 8f48018d f7331c60
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
struct AddScale
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto I1 = ck::Number<1>{};
static constexpr auto I2 = ck::Number<2>{};
static constexpr auto I3 = ck::Number<3>{};
__host__ __device__ constexpr void
operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const
{
const auto a0_v_t = ck::vector_type<ck::half_t, 4>{a0};
const auto a1_v_t = ck::vector_type<ck::half_t, 4>{a1};
auto r_v_t = ck::vector_type<ck::half_t, 4>{};
r_v_t.AsType<ck::half_t>()(I0) =
scale * (a0_v_t.AsType<ck::half_t>()[I0] + a1_v_t.AsType<ck::half_t>()[I0]);
r_v_t.AsType<ck::half_t>()(I1) =
scale * (a0_v_t.AsType<ck::half_t>()[I1] + a1_v_t.AsType<ck::half_t>()[I1]);
r_v_t.AsType<ck::half_t>()(I2) =
scale * (a0_v_t.AsType<ck::half_t>()[I2] + a1_v_t.AsType<ck::half_t>()[I2]);
r_v_t.AsType<ck::half_t>()(I3) =
scale * (a0_v_t.AsType<ck::half_t>()[I3] + a1_v_t.AsType<ck::half_t>()[I3]);
a = r_v_t.AsType<ck::half4_t>()[I0];
}
__host__ __device__ constexpr void
operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const
{
a = scale * (a0 + a1);
}
// this attribute controls the copy_function applying element_wise_op with
// pack4_data
constexpr const static bool is_pack4_invocable = true;
float scale = 1.0;
};
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
ck::half_t& e, const float& c, const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
};
float alpha_;
float beta_;
};
using AElementOp = AddScale;
using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle<
ck::Tuple<ALayout, ALayout>,
ck::Tuple<BLayout>,
ck::Tuple<DLayout>,
ELayout,
ck::Tuple<ADataType, ADataType>,
ck::Tuple<BDataType>,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
1,
256,
256,
128,
32,
8,
8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
1,
1,
S<1, 32, 1, 8>,
8>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideD = 4096;
ck::index_t StrideE = 4096;
float alpha = 1.0f;
float beta = 1.0f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
alpha = std::stof(argv[4]);
beta = std::stof(argv[5]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
alpha = std::stof(argv[11]);
beta = std::stof(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
"beta\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
a1_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
a1_device_buf.ToDevice(a1_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{0.2};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, 2>{StrideA, StrideA},
std::array<ck::index_t, 1>{StrideB},
std::array<ck::index_t, 1>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<ADataType> a_m_k({M, K});
for(int m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
{
a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
PassThrough,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using A0DataType = F16;
using A1DataType = F32;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
struct AlphaBetaAdd
{
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
ck::half_t& e, const float& c, const ck::half_t& d) const
{
e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
};
float alpha_;
float beta_;
};
struct Multiply
{
__host__ __device__ constexpr void
operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const
{
a = ck::type_convert<ck::half_t>(ck::type_convert<float>(a0) * a1);
}
};
using AElementOp = Multiply;
using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle<
NumDimM,
NumDimN,
NumDimK,
ck::Tuple<A0DataType, A1DataType>,
ck::Tuple<BDataType>,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
1,
256,
256,
128,
32,
8,
8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
1,
1,
S<1, 32, 1, 8>,
8>;
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
float alpha = 1.0f;
float beta = 1.0f;
// A0[M0, M1, K0, K1]
std::vector<ck::index_t> a0_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1};
// A1[M1, K1] -> A1[M0, M1, K0, K1]
std::vector<ck::index_t> a1_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a1_ms_ks_strides{0, 64, 0, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1};
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
exit(0);
}
Tensor<A0DataType> a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides);
Tensor<A1DataType> a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides);
Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
Tensor<EDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl;
std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_ms_ks.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
a1_ms_ks.GenerateTensorValue(GeneratorTensor_2<A1DataType>{-5, 5});
b_ns_ks.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a0_ms_ks.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
a1_ms_ks.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b_ns_ks.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_ms_ks.mData.data());
a1_device_buf.ToDevice(a1_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
d_device_buf.ToDevice(d_ms_ns.mData.data());
// set zero
e_device_buf.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
std::array<std::vector<ck::index_t>, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths},
std::array<std::vector<ck::index_t>, 2>{a0_ms_ks_strides, a1_ms_ks_strides},
std::array<std::vector<ck::index_t>, 1>{b_ns_ks_lengths},
std::array<std::vector<ck::index_t>, 1>{b_ns_ks_strides},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_contraction with the specified compilation parameters does "
"not support this problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
if(time_kernel)
{
ck::index_t M =
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << std::endl;
}
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<A0DataType> a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides);
for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1)
{
for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0)
{
for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1)
{
a_element_op(a_ms_ks(m0, m1, k0, k1),
a0_ms_ks(m0, m1, k0, k1),
a1_ms_ks(m0, m1, k0, k1));
}
}
}
}
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
A0DataType,
BDataType,
CShuffleDataType,
AccDataType,
PassThrough,
BElementOp>;
auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_op.MakeInvoker();
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
auto ref_argument =
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op);
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1),
d_ms_ns(m0, m1, n0, n1));
}
}
}
}
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
return 0;
}
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_activ_xdl)
# Sigmoid
add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_sigmoid_fp16)
# Tanh
add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_tanh_fp16)
# Relu
add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_relu_fp16)
# SoftRelu
add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_softrelu_fp16)
# Abs
add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_abs_fp16)
# Pow
add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_pow_fp16)
# Clipped Relu
add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_clippedrelu_fp16)
# Leaky Relu
add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_leakyrelu_fp16)
# Elu
add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp)
add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16)
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
constexpr ck::index_t NDimSpatial = 3;
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using OutDataType = ck::half_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InLayout = ck::tensor_layout::convolution::GNDHWC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::GNDHWK;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename OutElementOp>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8>;
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op)
{
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-1.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.05, 0.05});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
if(do_verification)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
out_host,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
out_element_op);
ref_invoker.Run(ref_argument);
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(out_device, out_host, "Error: incorrect results!");
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Elu;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Power;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Relu;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::Sigmoid;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::SoftRelu;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_activ_common.hpp"
using OutElementOp = ck::tensor_operation::element_wise::TanH;
using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance<OutElementOp>;
#include "run_convnd_fwd_activ_example.inc"
int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
bool run_convnd_fwd_example(int argc, char* argv[])
{
print_helper_msg();
bool do_verification = true;
// Use floats for SoftRelu by default to avoid overflow after e^x.
int init_method =
std::is_same_v<OutElementOp, ck::tensor_operation::element_wise::SoftRelu> ? 2 : 1;
bool time_kernel = false;
// Following shapes are selected to avoid overflow. Expect inf in case of
// size increase for some elementwise ops.
ck::utils::conv::ConvParam conv_param{
3, 1, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
}
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
const auto run = [&]() {
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_grouped_conv_fwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdActivInstance>(do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op,
out_element_op);
};
if(conv_param.num_dim_spatial_ == 3)
{
return run();
}
return false;
}
......@@ -58,6 +58,12 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
set(result ${result} PARENT_SCOPE)
endfunction(add_example_executable EXAMPLE_NAME)
function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
if(result EQUAL 0)
add_dependencies(${EXAMPLE_NAME} ${FILE_NAME})
endif()
endfunction(add_example_dependencies EXAMPLE_NAME)
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
set(result 1)
......
......@@ -66,6 +66,10 @@
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8_GFX11
#endif
// MFMA instruction
......
......@@ -3,8 +3,10 @@
#pragma once
#include <sstream>
#include <hip/hip_runtime.h>
// To be removed, which really does not tell the location of failed HIP functional call
inline void hip_check_error(hipError_t x)
{
if(x != hipSuccess)
......@@ -15,3 +17,16 @@ inline void hip_check_error(hipError_t x)
throw std::runtime_error(ss.str());
}
}
#define HIP_CHECK_ERROR(retval_or_funcall) \
do \
{ \
hipError_t _tmpVal = retval_or_funcall; \
if(_tmpVal != hipSuccess) \
{ \
std::ostringstream ostr; \
ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
<< hipGetErrorString(_tmpVal); \
throw std::runtime_error(ostr.str()); \
} \
} while(0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// GEMM:
// input : A0[M0, M1, ... K0, K1, ...], ...
// input : B0[N0, N1, ... K0, K1, ...], ...
// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ...
// output : E[M0, M1, ... N0, N1, ...]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceContractionMultipleABD : public BaseOperator
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
std::array<const void*, NumBTensor> p_bs,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
const std::vector<index_t>& e_ms_ns_length,
const std::vector<index_t>& e_ms_ns_stride,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -20,7 +20,8 @@ template <typename ALayout,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
typename ComputeType = CDataType>
struct DeviceGemmSplitK : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
......@@ -48,7 +49,8 @@ template <typename ALayout,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
typename ComputeType = CDataType>
using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
BLayout,
CLayout,
......@@ -57,7 +59,8 @@ using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
CElementwiseOperation,
ComputeType>>;
} // namespace device
} // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment