Unverified Commit ced5af16 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Extend support for contraction 6D (#1207)

* Extend support for contraction up to 5D

* Extend contraction bilinear instances

* Fix interface test

* Add 6d support, remove 3d,4d,5d

* Fixes

* Fix readme

* Make default dim for contraction instances
parent 366592b0
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -22,6 +22,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/host_utility/io.hpp"
......@@ -34,7 +35,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using F32 = float;
using F64 = double;
template <typename ALayout,
template <index_t NumDimMNK,
typename ALayout,
typename BLayout,
typename CDELayout,
typename DataType,
......@@ -104,18 +106,24 @@ int profile_contraction_impl(ck::index_t do_verification,
e_device_buf.SetZero();
d_device_buf.ToDevice(d_m_n.mData.data());
const std::vector<index_t> a_ms_ks_lengths = {M[0], M[1], K[0], K[1]};
const std::vector<index_t> b_ns_ks_lengths = {N[0], N[1], K[0], K[1]};
const std::vector<index_t> e_ms_ns_lengths = {M[0], M[1], N[0], N[1]};
const std::vector<index_t> d_m_n_lengths = {M[0], M[1], N[0], N[1]};
auto merge_dims = [](const std::vector<ck::index_t>& dims01,
const std::vector<ck::index_t>& dims23) {
std::vector<ck::index_t> dims_szt(dims01.begin(), dims01.end());
dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end());
return dims_szt;
};
const std::vector<index_t> a_ms_ks_lengths = merge_dims(M, K);
const std::vector<index_t> b_ns_ks_lengths = merge_dims(N, K);
const std::vector<index_t> e_ms_ns_lengths = merge_dims(M, N);
const std::vector<index_t> d_m_n_lengths = merge_dims(M, N);
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
constexpr ck::index_t NumDim = 2;
using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDim,
NumDim,
NumDim,
using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDimMNK,
NumDimMNK,
NumDimMNK,
DataType,
DataType,
DTupleDataType,
......@@ -138,9 +146,9 @@ int profile_contraction_impl(ck::index_t do_verification,
if(do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDim,
NumDim,
NumDim,
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimMNK,
NumDimMNK,
NumDimMNK,
DataType,
DataType,
DataType,
......@@ -159,33 +167,20 @@ int profile_contraction_impl(ck::index_t do_verification,
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_m_n_host_result.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_m_n_host_result.mDesc.GetLengths()[1]; ++m1)
e_m_n_host_result.ForEach([&](auto& self, auto idx) {
if constexpr(is_same<CDElementOp, Bilinear>::value)
{
for(size_t n0 = 0; n0 < e_m_n_host_result.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_m_n_host_result.mDesc.GetLengths()[3]; ++n1)
{
if constexpr(is_same<CDElementOp, Bilinear>::value)
{
cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
c_m_n_host_result(m0, m1, n0, n1),
d_m_n(m0, m1, n0, n1));
}
else if constexpr(is_same<CDElementOp, Scale>::value)
{
cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
c_m_n_host_result(m0, m1, n0, n1));
}
else
{
static_assert("Unsupported CDElementOp in contraction profiler.");
}
}
}
cde_element_op(self(idx), c_m_n_host_result(idx), d_m_n(idx));
}
}
else if constexpr(is_same<CDElementOp, Scale>::value)
{
cde_element_op(self(idx), c_m_n_host_result(idx));
}
else
{
static_assert("Unsupported CDElementOp in contraction profiler.");
}
});
}
std::string best_op_name;
......@@ -242,9 +237,12 @@ int profile_contraction_impl(ck::index_t do_verification,
auto invoker_ptr = op_ptr->MakeInvokerPointer();
auto nelems_m = M[0] * M[1];
auto nelems_n = N[0] * N[1];
auto nelems_k = K[0] * K[1];
auto nelems_m = ck::accumulate_n<ck::index_t>(
a_ms_ks_lengths.begin(), NumDimMNK, 1, std::multiplies<>{});
auto nelems_n = ck::accumulate_n<ck::index_t>(
b_ns_ks_lengths.begin(), NumDimMNK, 1, std::multiplies<>{});
auto nelems_k = ck::accumulate_n<ck::index_t>(
a_ms_ks_lengths.begin() + NumDimMNK, NumDimMNK, 1, std::multiplies<>{});
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -48,14 +48,36 @@ inline void collect_index_params(char* argv[],
// Default strides (4D example) for row-major: {Dim1 * Dim2 * Dim3, Dim2 * Dim3, Dim3, 1}
// Default strides (4D example) for column-major: {Dim1, 1, Dim0 * Dim1 * Dim3, Dim0 * Dim1}
// e.g. for A with lengths [M0, M1, K0, K1] in column-major the strides are
// {M1, 1, M0 * M1 * K1, M0 * M1}.
inline void
assign_default_strides(Row, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
{
strides = {dims[1] * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1};
ck::index_t stride = 1;
for(ck::index_t s = strides.size() - 1; s >= 0; s--)
{
strides[s] = stride;
stride *= dims[s];
}
}
inline void
assign_default_strides(Col, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
{
strides = {dims[1], 1, dims[0] * dims[1] * dims[3], dims[0] * dims[1]};
// Assign second half of strides
ck::index_t stride = 1;
for(ck::index_t s = strides.size() / 2 - 1; s >= 0; s--)
{
strides[s] = stride;
stride *= dims[s];
}
// Assign first half of strides
for(ck::index_t s = strides.size() - 1; s > static_cast<ck::index_t>(strides.size()) / 2 - 1;
s--)
{
strides[s] = stride;
stride *= dims[s];
}
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
......@@ -19,7 +19,8 @@ static void print_helper_msg()
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
<< "arg4: Number of dimension for M, N and K (one for all)\n"
<< "arg5: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
......@@ -27,23 +28,23 @@ static void print_helper_msg()
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg5: verification (0: no; 1: yes)\n"
<< "arg6: initialization (0: no init; 1: integer value; 2: decimal "
<< "arg6: verification (0: no; 1: yes)\n"
<< "arg7: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n"
<< "arg9: alpha\n"
<< "arg10: beta\n"
<< "arg11 to 16: M0, M1, N0, N1, K0, K1\n"
<< "arg17 to 32: Strides for A, B, D and E (skip for default)\n"
<< "arg8: print tensor value (0: no; 1: yes)\n"
<< "arg9: time kernel (0: no, 1: yes)\n"
<< "arg10: alpha\n"
<< "arg11: beta\n"
<< "arg12 to 17/29: M0, M1, N0, N1, K0, K1\n"
<< "arg18/30 to 33/77: Strides for A, B, D and E (skip for default)\n"
<< std::endl;
}
int profile_contraction_bilinear(int argc, char* argv[])
{
    // With default strides the user passes only the dims
    // (12 fixed args + 3 * NumDimMNK dims):
    //   2D (NumDimMNK == 2): 18 args; 6D (NumDimMNK == 6): 30 args.
    // BUG FIX: this was `argc == 18 || 30`, which is always true because the
    // literal 30 converts to true; argc must be compared against 30 as well.
    const bool default_strides = argc == 18 || argc == 30;

    // Explicit strides for A, B, D and E add 8 * NumDimMNK more args:
    //   2D: 34 args; 6D: 78 args.
    if(argc != 34 && argc != 78 && !default_strides)
    {
        print_helper_msg();
        exit(1);
......@@ -51,32 +52,33 @@ int profile_contraction_bilinear(int argc, char* argv[])
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto compute_data_type = static_cast<ContractionComputeDataType>(std::stoi(argv[3]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[4]));
const bool do_verification = std::stoi(argv[5]);
const ck::index_t init_method = std::stoi(argv[6]);
const bool do_log = std::stoi(argv[7]);
const bool time_kernel = std::stoi(argv[8]);
const float alpha = std::stof(argv[9]);
const float beta = std::stof(argv[10]);
const ck::index_t NumDimMNK = std::stoi(argv[4]);
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const ck::index_t init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
const bool time_kernel = std::stoi(argv[9]);
const float alpha = std::stof(argv[10]);
const float beta = std::stof(argv[11]);
std::vector<ck::index_t> M;
std::vector<ck::index_t> N;
std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 11;
collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2);
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesE;
std::vector<ck::index_t> StridesD;
const ck::index_t dims_arg_num = 12;
collect_index_params(argv, M, dims_arg_num, NumDimMNK);
collect_index_params(argv, N, dims_arg_num + NumDimMNK, NumDimMNK);
collect_index_params(argv, K, dims_arg_num + NumDimMNK * 2, NumDimMNK);
std::vector<ck::index_t> StridesA(NumDimMNK * 2);
std::vector<ck::index_t> StridesB(NumDimMNK * 2);
std::vector<ck::index_t> StridesE(NumDimMNK * 2);
std::vector<ck::index_t> StridesD(NumDimMNK * 2);
if(!default_strides)
{
collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
collect_index_params(argv, StridesE, dims_arg_num + 14, 4);
collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
collect_index_params(argv, StridesA, dims_arg_num + NumDimMNK * 3, NumDimMNK * 2);
collect_index_params(argv, StridesB, dims_arg_num + NumDimMNK * 5, NumDimMNK * 2);
collect_index_params(argv, StridesE, dims_arg_num + NumDimMNK * 7, NumDimMNK * 2);
collect_index_params(argv, StridesD, dims_arg_num + NumDimMNK * 9, NumDimMNK * 2);
}
using F16 = ck::half_t;
......@@ -95,31 +97,71 @@ int profile_contraction_bilinear(int argc, char* argv[])
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
auto merge_dims = [](const std::vector<ck::index_t>& dims01,
const std::vector<ck::index_t>& dims23) {
std::vector<ck::index_t> dims_szt(dims01.begin(), dims01.end());
dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end());
return dims_szt;
};
assign_default_strides(a_layout, StridesA, merge_dims(M, K));
assign_default_strides(b_layout, StridesB, merge_dims(N, K));
assign_default_strides(cde_layout, StridesE, merge_dims(M, N));
assign_default_strides(cde_layout, StridesD, merge_dims(M, N));
}
if(NumDimMNK == 2)
{
bool pass = ck::profiler::profile_contraction_impl<2,
ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<DataType>,
Bilinear>(do_verification,
init_method,
do_log,
time_kernel,
Bilinear{alpha, beta},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
}
else if(NumDimMNK == 6)
{
bool pass = ck::profiler::profile_contraction_impl<6,
ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<DataType>,
Bilinear>(do_verification,
init_method,
do_log,
time_kernel,
Bilinear{alpha, beta},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
}
else
{
throw std::runtime_error("Not supported NumDimMNK");
return false;
}
bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<DataType>,
Bilinear>(do_verification,
init_method,
do_log,
time_kernel,
Bilinear{alpha, beta},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
};
auto run_profile_for_datatype = [&](auto type, auto compute_type) {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
......@@ -19,7 +19,8 @@ static void print_helper_msg()
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n"
<< "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
<< "arg4: Number of dimension for M, N and K (one for all)\n"
<< "arg5: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
......@@ -27,22 +28,22 @@ static void print_helper_msg()
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
<< "arg5: verification (0: no; 1: yes)\n"
<< "arg6: initialization (0: no init; 1: integer value; 2: decimal "
<< "arg6: verification (0: no; 1: yes)\n"
<< "arg7: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n"
<< "arg9: alpha\n"
<< "arg10 to 15: M0, M1, N0, N1, K0, K1\n"
<< "arg16 to 31: Strides for A, B, D and E (skip for default)\n"
<< "arg8: print tensor value (0: no; 1: yes)\n"
<< "arg9: time kernel (0: no, 1: yes)\n"
<< "arg10: alpha\n"
<< "arg11 to 16/28: M0, M1, N0, N1, K0, K1\n"
<< "arg17/29 to 32/63: Strides for A, B, E (skip for default)\n"
<< std::endl;
}
int profile_contraction_scale(int argc, char* argv[])
{
    // With default strides the user passes only the dims
    // (11 fixed args + 3 * NumDimMNK dims):
    //   2D (NumDimMNK == 2): 17 args; 6D (NumDimMNK == 6): 29 args.
    // NOTE(review): argc == 29 is ambiguous — it also matches the 2D case
    // with explicit strides (17 + 12). Disambiguating would require parsing
    // NumDimMNK (argv[4]) first; current behavior treats 29 as
    // default-strides. TODO confirm intended precedence.
    const bool default_strides = argc == 17 || argc == 29;

    // Explicit strides for A, B and E add 6 * NumDimMNK more args:
    //   2D: 29 args; 6D: 65 args.
    if(argc != 29 && argc != 65 && !default_strides)
    {
        print_helper_msg();
        exit(1);
......@@ -50,31 +51,30 @@ int profile_contraction_scale(int argc, char* argv[])
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto compute_data_type = static_cast<ContractionComputeDataType>(std::stoi(argv[3]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[4]));
const bool do_verification = std::stoi(argv[5]);
const ck::index_t init_method = std::stoi(argv[6]);
const bool do_log = std::stoi(argv[7]);
const bool time_kernel = std::stoi(argv[8]);
const float alpha = std::stof(argv[9]);
const ck::index_t NumDimMNK = std::stoi(argv[4]);
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const ck::index_t init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
const bool time_kernel = std::stoi(argv[9]);
const float alpha = std::stof(argv[10]);
std::vector<ck::index_t> M;
std::vector<ck::index_t> N;
std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 10;
collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2);
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesE;
std::vector<ck::index_t> StridesD;
const ck::index_t dims_arg_num = 11;
collect_index_params(argv, M, dims_arg_num, NumDimMNK);
collect_index_params(argv, N, dims_arg_num + NumDimMNK, NumDimMNK);
collect_index_params(argv, K, dims_arg_num + NumDimMNK * 2, NumDimMNK);
std::vector<ck::index_t> StridesA(NumDimMNK * 2);
std::vector<ck::index_t> StridesB(NumDimMNK * 2);
std::vector<ck::index_t> StridesE(NumDimMNK * 2);
if(!default_strides)
{
collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
collect_index_params(argv, StridesE, dims_arg_num + 14, 4);
collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
collect_index_params(argv, StridesA, dims_arg_num + NumDimMNK * 3, NumDimMNK * 2);
collect_index_params(argv, StridesB, dims_arg_num + NumDimMNK * 5, NumDimMNK * 2);
collect_index_params(argv, StridesE, dims_arg_num + NumDimMNK * 7, NumDimMNK * 2);
}
using F16 = ck::half_t;
......@@ -93,32 +93,71 @@ int profile_contraction_scale(int argc, char* argv[])
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
auto merge_dims = [](const std::vector<ck::index_t>& dims01,
const std::vector<ck::index_t>& dims23) {
std::vector<ck::index_t> dims_szt(dims01.begin(), dims01.end());
dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end());
return dims_szt;
};
assign_default_strides(a_layout, StridesA, merge_dims(M, K));
assign_default_strides(b_layout, StridesB, merge_dims(N, K));
assign_default_strides(cde_layout, StridesE, merge_dims(M, N));
}
bool pass = ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<>,
Scale>(do_verification,
init_method,
do_log,
time_kernel,
Scale{alpha},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesD);
return pass;
if(NumDimMNK == 2)
{
bool pass = ck::profiler::profile_contraction_impl<2,
ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<>,
Scale>(do_verification,
init_method,
do_log,
time_kernel,
Scale{alpha},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesE);
return pass;
}
else if(NumDimMNK == 6)
{
bool pass = ck::profiler::profile_contraction_impl<6,
ALayout,
BLayout,
CDELayout,
DataType,
ComputeDataType,
ck::Tuple<>,
Scale>(do_verification,
init_method,
do_log,
time_kernel,
Scale{alpha},
M,
N,
K,
StridesA,
StridesB,
StridesE,
StridesE);
return pass;
}
else
{
throw std::runtime_error("Not supported NumDimMNK");
return false;
}
};
auto run_profile_for_datatype = [&](auto type, auto compute_type) {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <stdexcept>
#include <vector>
......@@ -125,18 +125,6 @@ class ContractionDeviceOpWrapper
}
};
// Presumably only 2-dims-per-axis contraction instances are registered, so a
// 4-element dim/stride set is supported while the 1D (2-element) and 3D
// (6-element) sets are rejected — TODO confirm against the instance factory.
TEST(TestContractionInterface, IncorrectNumDims)
{
    const std::vector<ck::index_t> dims_1d(2, 4);
    const std::vector<ck::index_t> dims_2d(4, 4);
    const std::vector<ck::index_t> dims_3d(6, 4);
    const std::vector<ck::index_t> strides_1d(2, 1);
    const std::vector<ck::index_t> strides_2d(4, 1);
    const std::vector<ck::index_t> strides_3d(6, 1);

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> wrapper_1d;
    ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> wrapper_2d;
    ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> wrapper_3d;

    EXPECT_FALSE(wrapper_1d.IsSupportedInstance(dims_1d, strides_1d));
    EXPECT_TRUE(wrapper_2d.IsSupportedInstance(dims_2d, strides_2d));
    EXPECT_FALSE(wrapper_3d.IsSupportedInstance(dims_3d, strides_3d));
}
TEST(TestContractionInterface, IncorrectDataTypes)
{
std::vector<ck::index_t> Dims = {4, 4, 4, 4};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
......@@ -23,8 +23,11 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale;
template <ck::index_t NDims>
struct Dimensions
{
constexpr static ck::index_t NumDimMNK = NDims;
std::vector<ck::index_t> M;
std::vector<ck::index_t> N;
std::vector<ck::index_t> K;
......@@ -42,53 +45,58 @@ class TestContraction : public ::testing::Test
using ComputeDataType = std::tuple_element_t<5, Tuple>;
using CDElementOp = std::tuple_element_t<6, Tuple>;
std::vector<Dimensions> dimension_list = {{{32, 32}, {32, 32}, {32, 32}},
{{16, 16}, {32, 32}, {16, 16}}};
std::vector<ck::index_t> init_methods = {1, 2};
std::unique_ptr<CDElementOp> p_cd_element_op;
void Run()
template <ck::index_t NumDim>
void Run(Dimensions<NumDim> dimension_params)
{
for(auto& dimension_params : dimension_list)
constexpr ck::index_t NumDimMNK = ck::remove_cvref_t<decltype(dimension_params)>::NumDimMNK;
std::vector<ck::index_t> StridesA(2 * NumDim);
std::vector<ck::index_t> StridesB(2 * NumDim);
std::vector<ck::index_t> StridesC(2 * NumDim);
std::vector<ck::index_t> StridesD(2 * NumDim);
const auto& M = dimension_params.M;
const auto& N = dimension_params.N;
const auto& K = dimension_params.K;
auto merge_dims = [](const std::vector<ck::index_t>& dims01,
const std::vector<ck::index_t>& dims23) {
std::vector<ck::index_t> dims_szt(dims01.begin(), dims01.end());
dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end());
return dims_szt;
};
assign_default_strides(ALayout{}, StridesA, merge_dims(M, K));
assign_default_strides(BLayout{}, StridesB, merge_dims(N, K));
assign_default_strides(CDLayout{}, StridesC, merge_dims(M, N));
assign_default_strides(CDLayout{}, StridesD, merge_dims(M, N));
for(const ck::index_t init_method : init_methods)
{
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
const auto& M = dimension_params.M;
const auto& N = dimension_params.N;
const auto& K = dimension_params.K;
assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]});
assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]});
assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]});
for(const ck::index_t init_method : init_methods)
{
bool pass =
ck::profiler::profile_contraction_impl<ALayout,
BLayout,
CDLayout,
DataType,
ComputeDataType,
DTupleDataType,
CDElementOp>(true /*do_verification*/,
init_method,
false /*do_logs*/,
false /*time_kernel*/,
*p_cd_element_op,
dimension_params.M,
dimension_params.N,
dimension_params.K,
StridesA,
StridesB,
StridesC,
StridesD);
EXPECT_TRUE(pass);
}
bool pass =
ck::profiler::profile_contraction_impl<NumDimMNK,
ALayout,
BLayout,
CDLayout,
DataType,
ComputeDataType,
DTupleDataType,
CDElementOp>(true /*do_verification*/,
init_method,
false /*do_logs*/,
false /*time_kernel*/,
*p_cd_element_op,
dimension_params.M,
dimension_params.N,
dimension_params.K,
StridesA,
StridesB,
StridesC,
StridesD);
EXPECT_TRUE(pass);
}
}
};
......@@ -122,17 +130,31 @@ TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
// Bilinear contraction (E = alpha * contraction(A, B) + beta * D), exercised
// for 6D and 2D problems with two (alpha, beta) settings.
// FIX: removed stale `this->Run();` residue — the zero-arg Run() overload no
// longer exists; Run is now templated on the dimension count.
TYPED_TEST(TestContractionBilinear, bilinear)
{
    this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
    this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}
// Scale contraction (E = alpha * contraction(A, B)), exercised for 6D and 2D
// problems with two alpha settings.
// FIX: removed stale `this->Run();` residue — the zero-arg Run() overload no
// longer exists; Run is now templated on the dimension count.
TYPED_TEST(TestContractionScale, scale)
{
    this->p_cd_element_op = std::make_unique<Scale>(1.f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
    this->p_cd_element_op = std::make_unique<Scale>(0.5f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}
template <typename Tuple>
......@@ -165,15 +187,29 @@ TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecis
// Mixed-precision variant of the bilinear test: same 6D/2D problems and
// (alpha, beta) settings, with compute type differing from the data type.
// FIX: removed stale `this->Run();` residue — the zero-arg Run() overload no
// longer exists; Run is now templated on the dimension count.
TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear)
{
    this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
    this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}
// Mixed-precision variant of the scale test: same 6D/2D problems and alpha
// settings, with compute type differing from the data type.
// FIX: removed stale `this->Run();` residue — the zero-arg Run() overload no
// longer exists; Run is now templated on the dimension count.
TYPED_TEST(TestContractionScaleMixedPrecision, scale)
{
    this->p_cd_element_op = std::make_unique<Scale>(1.f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
    this->p_cd_element_op = std::make_unique<Scale>(0.5f);
    this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}});
    this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
    this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
    this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment