"...composable_kernel_rocm.git" did not exist on "0ef27d537d855311099e0863a47b2564898fca7d"
Commit e4e99a49 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use new utilities to shorten codes

parent 7acbf104
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include <random> // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -28,8 +29,6 @@ struct ExecutionConfig final ...@@ -28,8 +29,6 @@ struct ExecutionConfig final
bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{ {
using namespace ck::literals;
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) #if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
static_assert(sizeof(ADataType) == sizeof(KernelADataType)); static_assert(sizeof(ADataType) == sizeof(KernelADataType));
...@@ -48,6 +47,8 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -48,6 +47,8 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
batch_stride_C, batch_stride_C,
batch_count] = problem_size; batch_count] = problem_size;
using namespace ck::literals;
// GEMM shape // GEMM shape
auto f_host_tensor_descriptor = [](std::size_t batch_count_, auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row, std::size_t row,
...@@ -55,15 +56,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -55,15 +56,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
std::size_t stride, std::size_t stride,
std::size_t batch_stride, std::size_t batch_stride,
auto layout) { auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}), return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz});
std::vector<std::size_t>({batch_stride, stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}), return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride});
std::vector<std::size_t>({batch_stride, 1, stride}));
} }
}; };
...@@ -79,9 +78,9 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -79,9 +78,9 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));
#endif #endif
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.GetDesc() << std::endl;
std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl; std::cout << "e_g_m_n: " << e_g_m_n_device_result.GetDesc() << std::endl;
switch(config.init_method) switch(config.init_method)
{ {
...@@ -96,19 +95,19 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -96,19 +95,19 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(b_g_k_n.GetMemorySize());
DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(e_g_m_n_device_result.GetMemorySize());
#ifdef BUILD_INT4_EXAMPLE #ifdef BUILD_INT4_EXAMPLE
const Tensor<KernelADataType> a_g_m_k_converted(a_g_m_k); const Tensor<KernelADataType> a_g_m_k_converted(a_g_m_k);
const Tensor<KernelBDataType> b_g_k_n_converted(b_g_k_n); const Tensor<KernelBDataType> b_g_k_n_converted(b_g_k_n);
a_device_buf.ToDevice(a_g_m_k_converted.mData.data()); a_device_buf.ToDevice(a_g_m_k_converted.data());
b_device_buf.ToDevice(b_g_k_n_converted.mData.data()); b_device_buf.ToDevice(b_g_k_n_converted.data());
#else #else
a_device_buf.ToDevice(a_g_m_k.mData.data()); a_device_buf.ToDevice(a_g_m_k.data());
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.data());
#endif #endif
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
...@@ -150,7 +149,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -150,7 +149,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
if(config.do_verification) if(config.do_verification)
{ {
c_device_buf.FromDevice(e_g_m_n_device_result.mData.data()); c_device_buf.FromDevice(e_g_m_n_device_result.data());
using ReferenceBatchedGemmInstance = using ReferenceBatchedGemmInstance =
ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
...@@ -174,11 +173,11 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -174,11 +173,11 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#ifdef BUILD_INT4_EXAMPLE #ifdef BUILD_INT4_EXAMPLE
const Tensor<EDataType> e_device_result_converted(e_g_m_n_device_result); const Tensor<EDataType> e_device_result_converted(e_g_m_n_device_result);
pass &= ck::utils::check_err(e_device_result_converted.mData, e_g_m_n_host_result.mData); pass &= ck::utils::check_err(e_device_result_converted, e_g_m_n_host_result);
#else #else
pass = ck::utils::check_err( pass = ck::utils::check_err(
e_g_m_n_device_result.mData, e_g_m_n_host_result.mData, "Error: Incorrect results c"); e_g_m_n_device_result, e_g_m_n_host_result, "Error: Incorrect results c");
#endif #endif
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -110,7 +111,7 @@ struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::B ...@@ -110,7 +111,7 @@ struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::B
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) { auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3]; const int K0 = arg.a_gs_ms_ks_.GetLengths()[3];
AccDataType v_acc = 0; AccDataType v_acc = 0;
...@@ -136,12 +137,12 @@ struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::B ...@@ -136,12 +137,12 @@ struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::B
}; };
make_ParallelTensorFunctor(f_gs_ms_ns, make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0], arg.e_gs_ms_ns_.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1], arg.e_gs_ms_ns_.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2], arg.e_gs_ms_ns_.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3], arg.e_gs_ms_ns_.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4], arg.e_gs_ms_ns_.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( arg.e_gs_ms_ns_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -246,26 +247,16 @@ int main(int argc, char* argv[]) ...@@ -246,26 +247,16 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
Tensor<ADataType> a_gs_ms_ks( Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<BDataType> b_gs_ns_ks( Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns( std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.GetDesc() << std::endl;
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.GetDesc() << std::endl;
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.GetDesc() << std::endl;
Tensor<EDataType> e_gs_ms_ns_host_result( std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.GetDesc() << std::endl;
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -282,15 +273,14 @@ int main(int argc, char* argv[]) ...@@ -282,15 +273,14 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(a_gs_ms_ks.GetMemorySize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(b_gs_ns_ks.GetMemorySize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(d_gs_ms_ns.GetMemorySize());
DeviceMem e_device_buf(sizeof(EDataType) * DeviceMem e_device_buf(e_gs_ms_ns_device_result.GetMemorySize());
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); a_device_buf.ToDevice(a_gs_ms_ks.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); b_device_buf.ToDevice(b_gs_ns_ks.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); d_device_buf.ToDevice(d_gs_ms_ns.data());
// set zero // set zero
e_device_buf.SetZero(); e_device_buf.SetZero();
...@@ -299,19 +289,21 @@ int main(int argc, char* argv[]) ...@@ -299,19 +289,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{}; auto cde_element_op = CDEElementOp{};
using ck::utils::to_array;
// device operation // device operation
auto op = DeviceOpInstance{}; auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker(); auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()}, to_array({d_device_buf.GetDeviceBuffer()}),
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
b_gs_ns_ks_strides, b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, to_array({d_gs_ms_ns_lengths}),
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, to_array({d_gs_ms_ns_strides}),
e_gs_ms_ns_lengths, e_gs_ms_ns_lengths,
e_gs_ms_ns_strides, e_gs_ms_ns_strides,
a_element_op, a_element_op,
...@@ -327,18 +319,18 @@ int main(int argc, char* argv[]) ...@@ -327,18 +319,18 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG, std::size_t M = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimM,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
std::size_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, std::size_t N = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN, NumDimN,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
std::size_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, std::size_t K = ck::accumulate_n(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK, NumDimK,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
...@@ -353,13 +345,11 @@ int main(int argc, char* argv[]) ...@@ -353,13 +345,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl; << op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); e_device_buf.FromDevice(e_gs_ms_ns_device_result.data());
if(do_verification) if(do_verification)
{ {
Tensor<CShuffleDataType> c_gs_ms_ns_host_result( Tensor<CShuffleDataType> c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM, using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM,
NumDimN, NumDimN,
...@@ -384,18 +374,17 @@ int main(int argc, char* argv[]) ...@@ -384,18 +374,17 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.GetLengths()[0]; ++g0)
{ {
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0) for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.GetLengths()[1]; ++m0)
{ {
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1) for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.GetLengths()[2]; ++m1)
{ {
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0) for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.GetLengths()[3]; ++n0)
{ {
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1) for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.GetLengths()[4]; ++n1)
{ {
for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.GetLengths()[5]; ++n2)
++n2)
{ {
cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2), cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2), c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
...@@ -407,9 +396,7 @@ int main(int argc, char* argv[]) ...@@ -407,9 +396,7 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData) return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
? 0
: 1;
} }
return 0; return 0;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -108,7 +110,7 @@ struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::B ...@@ -108,7 +110,7 @@ struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::B
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) { auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; const int K0 = arg.a_gs_ms_ks_.GetLengths()[4];
AccDataType v_acc = 0; AccDataType v_acc = 0;
...@@ -134,12 +136,12 @@ struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::B ...@@ -134,12 +136,12 @@ struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::B
}; };
make_ParallelTensorFunctor(f_gs_ms_ns, make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0], arg.e_gs_ms_ns_.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1], arg.e_gs_ms_ns_.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2], arg.e_gs_ms_ns_.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3], arg.e_gs_ms_ns_.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4], arg.e_gs_ms_ns_.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( arg.e_gs_ms_ns_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -246,26 +248,16 @@ int main(int argc, char* argv[]) ...@@ -246,26 +248,16 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
Tensor<ADataType> a_gs_ms_ks( Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<BDataType> b_gs_ns_ks( Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns( std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.GetDesc() << std::endl;
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.GetDesc() << std::endl;
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.GetDesc() << std::endl;
Tensor<EDataType> e_gs_ms_ns_host_result( std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.GetDesc() << std::endl;
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -282,15 +274,14 @@ int main(int argc, char* argv[]) ...@@ -282,15 +274,14 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(a_gs_ms_ks.GetMemorySize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(b_gs_ns_ks.GetMemorySize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(d_gs_ms_ns.GetMemorySize());
DeviceMem e_device_buf(sizeof(EDataType) * DeviceMem e_device_buf(e_gs_ms_ns_device_result.GetMemorySize());
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); a_device_buf.ToDevice(a_gs_ms_ks.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); b_device_buf.ToDevice(b_gs_ns_ks.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); d_device_buf.ToDevice(d_gs_ms_ns.data());
// set zero // set zero
e_device_buf.SetZero(); e_device_buf.SetZero();
...@@ -299,19 +290,21 @@ int main(int argc, char* argv[]) ...@@ -299,19 +290,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{}; auto cde_element_op = CDEElementOp{};
using ck::utils::to_array;
// device operation // device operation
auto op = DeviceOpInstance{}; auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker(); auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()}, to_array({d_device_buf.GetDeviceBuffer()}),
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
b_gs_ns_ks_strides, b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, to_array({d_gs_ms_ns_lengths}),
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, to_array({d_gs_ms_ns_strides}),
e_gs_ms_ns_lengths, e_gs_ms_ns_lengths,
e_gs_ms_ns_strides, e_gs_ms_ns_strides,
a_element_op, a_element_op,
...@@ -327,18 +320,16 @@ int main(int argc, char* argv[]) ...@@ -327,18 +320,16 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin(), ck::index_t M = ck::accumulate_n(
e_gs_ms_ns_lengths.begin() + NumDimM, e_gs_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimM, ck::index_t N = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimM + NumDimN, NumDimN,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimM, ck::index_t K = ck::accumulate_n(a_gs_ms_ks_lengths.begin() + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimM + NumDimK, NumDimK,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
...@@ -353,13 +344,11 @@ int main(int argc, char* argv[]) ...@@ -353,13 +344,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl; << op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); e_device_buf.FromDevice(e_gs_ms_ns_device_result.data());
if(do_verification) if(do_verification)
{ {
Tensor<CShuffleDataType> c_gs_ms_ns_host_result( Tensor<CShuffleDataType> c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1<NumDimG, using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1<NumDimG,
NumDimM, NumDimM,
...@@ -385,18 +374,17 @@ int main(int argc, char* argv[]) ...@@ -385,18 +374,17 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.GetLengths()[0]; ++g0)
{ {
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0) for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.GetLengths()[1]; ++m0)
{ {
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1) for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.GetLengths()[2]; ++m1)
{ {
for(size_t m2 = 0; m2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m2) for(size_t m2 = 0; m2 < e_gs_ms_ns_host_result.GetLengths()[3]; ++m2)
{ {
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.GetLengths()[4]; ++n0)
{ {
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.GetLengths()[5]; ++n1)
++n1)
{ {
cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1), cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),
c_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1), c_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),
...@@ -408,9 +396,7 @@ int main(int argc, char* argv[]) ...@@ -408,9 +396,7 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData) return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
? 0
: 1;
} }
return 0; return 0;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -122,8 +124,8 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -122,8 +124,8 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; const int K0 = arg.a_ms_ks_.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; const int K1 = arg.a_ms_ks_.GetLengths()[3];
AccDataType v_acc = 0; AccDataType v_acc = 0;
...@@ -151,10 +153,10 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -151,10 +153,10 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}; };
make_ParallelTensorFunctor(f_ms_ns, make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0], arg.e_ms_ns_.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1], arg.e_ms_ns_.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2], arg.e_ms_ns_.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])( arg.e_ms_ns_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -288,26 +290,16 @@ int main(int argc, char* argv[]) ...@@ -288,26 +290,16 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
Tensor<ADataType> a_ms_ks( Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end())); Tensor<EDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<BDataType> b_ns_ks( Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()), Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
Tensor<EDataType> d_ms_ns( std::cout << "a_ms_ks: " << a_ms_ks.GetDesc() << std::endl;
std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()), std::cout << "b_ns_ks: " << b_ns_ks.GetDesc() << std::endl;
std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end())); std::cout << "d_ms_ns: " << d_ms_ns.GetDesc() << std::endl;
Tensor<EDataType> e_ms_ns_host_result( std::cout << "e_ms_ns: " << e_ms_ns_host_result.GetDesc() << std::endl;
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -324,14 +316,14 @@ int main(int argc, char* argv[]) ...@@ -324,14 +316,14 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(a_ms_ks.GetMemorySize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(b_ns_ks.GetMemorySize());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(d_ms_ns.GetMemorySize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(e_ms_ns_device_result.GetMemorySize());
a_device_buf.ToDevice(a_ms_ks.mData.data()); a_device_buf.ToDevice(a_ms_ks.data());
b_device_buf.ToDevice(b_ns_ks.mData.data()); b_device_buf.ToDevice(b_ns_ks.data());
d_device_buf.ToDevice(d_ms_ns.mData.data()); d_device_buf.ToDevice(d_ms_ns.data());
// set zero // set zero
e_device_buf.SetZero(); e_device_buf.SetZero();
...@@ -340,19 +332,21 @@ int main(int argc, char* argv[]) ...@@ -340,19 +332,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta}; auto cde_element_op = CDEElementOp{alpha, beta};
using ck::utils::to_array;
// device operation // device operation
auto op = DeviceOpInstance{}; auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker(); auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()}, to_array({d_device_buf.GetDeviceBuffer()}),
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
b_ns_ks_lengths, b_ns_ks_lengths,
b_ns_ks_strides, b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths}, to_array({d_ms_ns_lengths}),
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides}, to_array({d_ms_ns_strides}),
e_ms_ns_lengths, e_ms_ns_lengths,
e_ms_ns_strides, e_ms_ns_strides,
a_element_op, a_element_op,
...@@ -368,20 +362,14 @@ int main(int argc, char* argv[]) ...@@ -368,20 +362,14 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(), ck::index_t M = ck::accumulate_n(
e_ms_ns_lengths.begin() + NumDimM, e_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, ck::index_t N = ck::accumulate_n(
e_ms_ns_lengths.begin() + NumDimM + NumDimN, e_ms_ns_lengths.begin() + NumDimM, NumDimN, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, ck::index_t K = ck::accumulate_n(
a_ms_ks_lengths.begin() + NumDimM + NumDimK, a_ms_ks_lengths.begin() + NumDimM, NumDimK, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
...@@ -394,13 +382,11 @@ int main(int argc, char* argv[]) ...@@ -394,13 +382,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl; << op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); e_device_buf.FromDevice(e_ms_ns_device_result.data());
if(do_verification) if(do_verification)
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result( Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN, NumDimN,
...@@ -421,13 +407,13 @@ int main(int argc, char* argv[]) ...@@ -421,13 +407,13 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) for(size_t m0 = 0; m0 < e_ms_ns_host_result.GetLengths()[0]; ++m0)
{ {
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) for(size_t m1 = 0; m1 < e_ms_ns_host_result.GetLengths()[1]; ++m1)
{ {
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) for(size_t n0 = 0; n0 < e_ms_ns_host_result.GetLengths()[2]; ++n0)
{ {
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) for(size_t n1 = 0; n1 < e_ms_ns_host_result.GetLengths()[3]; ++n1)
{ {
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1), c_ms_ns_host_result(m0, m1, n0, n1),
...@@ -437,7 +423,7 @@ int main(int argc, char* argv[]) ...@@ -437,7 +423,7 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1; return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
} }
return 0; return 0;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -121,8 +123,8 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -121,8 +123,8 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; const int K0 = arg.a_ms_ks_.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; const int K1 = arg.a_ms_ks_.GetLengths()[3];
AccDataType v_acc = 0; AccDataType v_acc = 0;
...@@ -150,10 +152,10 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -150,10 +152,10 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
}; };
make_ParallelTensorFunctor(f_ms_ns, make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0], arg.e_ms_ns_.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1], arg.e_ms_ns_.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2], arg.e_ms_ns_.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])( arg.e_ms_ns_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -277,22 +279,14 @@ int main(int argc, char* argv[]) ...@@ -277,22 +279,14 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
Tensor<ADataType> a_ms_ks( Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end())); Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<BDataType> b_ns_ks( Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end())); std::cout << "a_ms_ks: " << a_ms_ks.GetDesc() << std::endl;
Tensor<EDataType> e_ms_ns_host_result( std::cout << "b_ns_ks: " << b_ns_ks.GetDesc() << std::endl;
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()), std::cout << "e_ms_ns: " << e_ms_ns_host_result.GetDesc() << std::endl;
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -307,12 +301,12 @@ int main(int argc, char* argv[]) ...@@ -307,12 +301,12 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(a_ms_ks.GetMemorySize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(b_ns_ks.GetMemorySize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(e_ms_ns_device_result.GetMemorySize());
a_device_buf.ToDevice(a_ms_ks.mData.data()); a_device_buf.ToDevice(a_ms_ks.data());
b_device_buf.ToDevice(b_ns_ks.mData.data()); b_device_buf.ToDevice(b_ns_ks.data());
// set zero // set zero
e_device_buf.SetZero(); e_device_buf.SetZero();
...@@ -321,19 +315,21 @@ int main(int argc, char* argv[]) ...@@ -321,19 +315,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{scale}; auto cde_element_op = CDEElementOp{scale};
using ck::utils::empty_array;
// device operation // device operation
auto op = DeviceOpInstance{}; auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker(); auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{}, empty_array(),
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
b_ns_ks_lengths, b_ns_ks_lengths,
b_ns_ks_strides, b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 0>{}, empty_array(),
std::array<std::vector<ck::index_t>, 0>{}, empty_array(),
e_ms_ns_lengths, e_ms_ns_lengths,
e_ms_ns_strides, e_ms_ns_strides,
a_element_op, a_element_op,
...@@ -349,20 +345,14 @@ int main(int argc, char* argv[]) ...@@ -349,20 +345,14 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(), ck::index_t M = ck::accumulate_n(
e_ms_ns_lengths.begin() + NumDimM, e_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, ck::index_t N = ck::accumulate_n(
e_ms_ns_lengths.begin() + NumDimM + NumDimN, e_ms_ns_lengths.begin() + NumDimM, NumDimN, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, ck::index_t K = ck::accumulate_n(
a_ms_ks_lengths.begin() + NumDimM + NumDimK, a_ms_ks_lengths.begin() + NumDimM, NumDimK, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -375,13 +365,11 @@ int main(int argc, char* argv[]) ...@@ -375,13 +365,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl; << op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); e_device_buf.FromDevice(e_ms_ns_device_result.data());
if(do_verification) if(do_verification)
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result( Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN, NumDimN,
...@@ -402,13 +390,13 @@ int main(int argc, char* argv[]) ...@@ -402,13 +390,13 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) for(size_t m0 = 0; m0 < e_ms_ns_host_result.GetLengths()[0]; ++m0)
{ {
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) for(size_t m1 = 0; m1 < e_ms_ns_host_result.GetLengths()[1]; ++m1)
{ {
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) for(size_t n0 = 0; n0 < e_ms_ns_host_result.GetLengths()[2]; ++n0)
{ {
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) for(size_t n1 = 0; n1 < e_ms_ns_host_result.GetLengths()[3]; ++n1)
{ {
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1)); c_ms_ns_host_result(m0, m1, n0, n1));
...@@ -417,7 +405,7 @@ int main(int argc, char* argv[]) ...@@ -417,7 +405,7 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1; return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
} }
return 0; return 0;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h> #include <getopt.h>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp" #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/library/utility/ranges.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
...@@ -59,14 +62,14 @@ int main() ...@@ -59,14 +62,14 @@ int main()
ck::index_t N = 1024; ck::index_t N = 1024;
ck::index_t Stride = N; ck::index_t Stride = N;
using namespace ck::literals;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}), return HostTensorDescriptor({len}, {stride});
std::vector<std::size_t>({stride}));
}; };
auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), return HostTensorDescriptor({row, col}, {stride, 1_uz});
std::vector<std::size_t>({stride, 1}));
}; };
Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride)); Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
...@@ -78,22 +81,23 @@ int main() ...@@ -78,22 +81,23 @@ int main()
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0}); gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0}); beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); DeviceMem x_dev(x.GetMemorySize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(gamma.GetMemorySize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(beta.GetMemorySize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); DeviceMem y_dev(y.GetMemorySize());
x_dev.ToDevice(x.data());
gamma_dev.ToDevice(gamma.data());
beta_dev.ToDevice(beta.data());
x_dev.ToDevice(x.mData.data()); using Indices = std::vector<ck::index_t>;
gamma_dev.ToDevice(gamma.mData.data());
beta_dev.ToDevice(beta.mData.data());
auto device_instance = DeviceInstance{}; auto device_instance = DeviceInstance{};
auto argument_ptr = device_instance.MakeArgumentPointer( auto argument_ptr = device_instance.MakeArgumentPointer({M, N},
{M, N}, ck::ranges::to<Indices>(x.GetStrides()),
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
{0, 1}, {0, 1},
{0, 1}, {0, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, ck::ranges::to<Indices>(y.GetStrides()),
{1}, {1},
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
...@@ -129,9 +133,8 @@ int main() ...@@ -129,9 +133,8 @@ int main()
auto ref_invoker = ref.MakeInvoker(); auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.data());
pass &= pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results d1", 1e-3, 1e-3);
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
} }
return (pass ? 0 : 1); return (pass ? 0 : 1);
} }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -104,7 +106,7 @@ struct ReferenceContraction_M3_N2_K1 : public ck::tensor_operation::device::Base ...@@ -104,7 +106,7 @@ struct ReferenceContraction_M3_N2_K1 : public ck::tensor_operation::device::Base
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto f_ms_ns = [&](auto m0, auto m1, auto m2, auto n0, auto n1) { auto f_ms_ns = [&](auto m0, auto m1, auto m2, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[3]; const int K0 = arg.a_ms_ks_.GetLengths()[3];
AccDataType v_acc = 0; AccDataType v_acc = 0;
...@@ -129,11 +131,11 @@ struct ReferenceContraction_M3_N2_K1 : public ck::tensor_operation::device::Base ...@@ -129,11 +131,11 @@ struct ReferenceContraction_M3_N2_K1 : public ck::tensor_operation::device::Base
}; };
make_ParallelTensorFunctor(f_ms_ns, make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0], arg.e_ms_ns_.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1], arg.e_ms_ns_.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2], arg.e_ms_ns_.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3], arg.e_ms_ns_.GetLengths()[3],
arg.e_ms_ns_.mDesc.GetLengths()[4])( arg.e_ms_ns_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -297,31 +299,21 @@ int main(int argc, char* argv[]) ...@@ -297,31 +299,21 @@ int main(int argc, char* argv[])
const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths;
const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides;
Tensor<ADataType> a_ms_ks( Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end())); Tensor<DDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<BDataType> b_ns_ks( Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end())); ck::index_t M_ = ck::accumulate_n(
Tensor<DDataType> d_ms_ns( e_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
ck::index_t M_ = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N_ = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, ck::index_t N_ = ck::accumulate_n(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN, NumDimN,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
ck::index_t K_ = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, ck::index_t K_ = ck::accumulate_n(a_ms_ks_lengths.begin() + NumDimM,
a_ms_ks_lengths.begin() + NumDimM + NumDimK, NumDimK,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
...@@ -334,13 +326,13 @@ int main(int argc, char* argv[]) ...@@ -334,13 +326,13 @@ int main(int argc, char* argv[])
flop += std::size_t(2) * M_ * K_ * N_; flop += std::size_t(2) * M_ * K_ * N_;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + sizeof(BDataType) * b_tensors[i].GetElementSize() +
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); sizeof(EDataType) * e_device_tensors[i].GetElementSize();
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].GetDesc()
<< " b_n_k: " << b_tensors[i].mDesc << " c_m_n: " << e_device_tensors[i].mDesc << " b_n_k: " << b_tensors[i].GetDesc()
<< std::endl; << " c_m_n: " << e_device_tensors[i].GetDesc() << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -364,18 +356,15 @@ int main(int argc, char* argv[]) ...@@ -364,18 +356,15 @@ int main(int argc, char* argv[])
for(std::size_t i = 0; i < contraction_descs.size(); i++) for(std::size_t i = 0; i < contraction_descs.size(); i++)
{ {
a_tensors_device.emplace_back(std::make_unique<DeviceMem>( a_tensors_device.emplace_back(std::make_unique<DeviceMem>(a_tensors[i].GetMemorySize()));
sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize())); b_tensors_device.emplace_back(std::make_unique<DeviceMem>(b_tensors[i].GetMemorySize()));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>( d_tensors_device.emplace_back(std::make_unique<DeviceMem>(d_tensors[i].GetMemorySize()));
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize())); e_tensors_device.emplace_back(
d_tensors_device.emplace_back(std::make_unique<DeviceMem>( std::make_unique<DeviceMem>(e_device_tensors[i].GetMemorySize()));
sizeof(DDataType) * d_tensors[i].mDesc.GetElementSpaceSize()));
e_tensors_device.emplace_back(std::make_unique<DeviceMem>( a_tensors_device[i]->ToDevice(a_tensors[i].data());
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSpaceSize())); b_tensors_device[i]->ToDevice(b_tensors[i].data());
d_tensors_device[i]->ToDevice(d_tensors[i].data());
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
d_tensors_device[i]->ToDevice(d_tensors[i].mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
...@@ -423,15 +412,11 @@ int main(int argc, char* argv[]) ...@@ -423,15 +412,11 @@ int main(int argc, char* argv[])
const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths;
const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides;
Tensor<EDataType> c_ms_ns_host_result( Tensor<EDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_host_result( Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data()); e_tensors_device[i]->FromDevice(e_device_tensors[i].data());
using ReferenceOpInstance = ReferenceContraction_M3_N2_K1<NumDimM, using ReferenceOpInstance = ReferenceContraction_M3_N2_K1<NumDimM,
NumDimN, NumDimN,
...@@ -456,15 +441,15 @@ int main(int argc, char* argv[]) ...@@ -456,15 +441,15 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) for(size_t m0 = 0; m0 < e_ms_ns_host_result.GetLengths()[0]; ++m0)
{ {
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) for(size_t m1 = 0; m1 < e_ms_ns_host_result.GetLengths()[1]; ++m1)
{ {
for(size_t m2 = 0; m2 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++m2) for(size_t m2 = 0; m2 < e_ms_ns_host_result.GetLengths()[2]; ++m2)
{ {
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0) for(size_t n0 = 0; n0 < e_ms_ns_host_result.GetLengths()[3]; ++n0)
{ {
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1) for(size_t n1 = 0; n1 < e_ms_ns_host_result.GetLengths()[4]; ++n1)
{ {
cde_element_op(e_ms_ns_host_result(m0, m1, m2, n0, n1), cde_element_op(e_ms_ns_host_result(m0, m1, m2, n0, n1),
c_ms_ns_host_result(m0, m1, m2, n0, n1), c_ms_ns_host_result(m0, m1, m2, n0, n1),
...@@ -475,7 +460,7 @@ int main(int argc, char* argv[]) ...@@ -475,7 +460,7 @@ int main(int argc, char* argv[])
} }
} }
pass &= ck::utils::check_err(e_device_tensors[i].mData, e_ms_ns_host_result.mData); pass &= ck::utils::check_err(e_device_tensors[i], e_ms_ns_host_result);
} }
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -109,7 +111,7 @@ struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::B ...@@ -109,7 +111,7 @@ struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::B
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; const int K0 = arg.a_gs_ms_ks_.GetLengths()[4];
AccDataType v_acc = 0; AccDataType v_acc = 0;
...@@ -136,12 +138,12 @@ struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::B ...@@ -136,12 +138,12 @@ struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::B
}; };
make_ParallelTensorFunctor(f_ms_ns, make_ParallelTensorFunctor(f_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0], arg.e_gs_ms_ns_.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1], arg.e_gs_ms_ns_.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2], arg.e_gs_ms_ns_.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3], arg.e_gs_ms_ns_.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4], arg.e_gs_ms_ns_.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( arg.e_gs_ms_ns_.GetLengths()[5])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -246,26 +248,16 @@ int main(int argc, char* argv[]) ...@@ -246,26 +248,16 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
Tensor<ADataType> a_gs_ms_ks( Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<BDataType> b_gs_ns_ks( Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns( std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.GetDesc() << std::endl;
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.GetDesc() << std::endl;
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.GetDesc() << std::endl;
Tensor<EDataType> e_gs_ms_ns_host_result( std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.GetDesc() << std::endl;
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -282,15 +274,14 @@ int main(int argc, char* argv[]) ...@@ -282,15 +274,14 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(a_gs_ms_ks.GetMemorySize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(b_gs_ns_ks.GetMemorySize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(d_gs_ms_ns.GetMemorySize());
DeviceMem e_device_buf(sizeof(EDataType) * DeviceMem e_device_buf(e_gs_ms_ns_device_result.GetMemorySize());
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); a_device_buf.ToDevice(a_gs_ms_ks.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); b_device_buf.ToDevice(b_gs_ns_ks.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); d_device_buf.ToDevice(d_gs_ms_ns.data());
// set zero // set zero
e_device_buf.SetZero(); e_device_buf.SetZero();
...@@ -299,19 +290,21 @@ int main(int argc, char* argv[]) ...@@ -299,19 +290,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{}; auto cde_element_op = CDEElementOp{};
using ck::utils::to_array;
// device operation // device operation
auto op = DeviceOpInstance{}; auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker(); auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()}, to_array({d_device_buf.GetDeviceBuffer()}),
e_device_buf.GetDeviceBuffer(), e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
b_gs_ns_ks_strides, b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, to_array({d_gs_ms_ns_lengths}),
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, to_array({d_gs_ms_ns_strides}),
e_gs_ms_ns_lengths, e_gs_ms_ns_lengths,
e_gs_ms_ns_strides, e_gs_ms_ns_strides,
a_element_op, a_element_op,
...@@ -327,23 +320,21 @@ int main(int argc, char* argv[]) ...@@ -327,23 +320,21 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t G = std::accumulate(e_gs_ms_ns_lengths.begin(), ck::index_t G = ck::accumulate_n(
e_gs_ms_ns_lengths.begin() + NumDimG, e_gs_ms_ns_lengths.begin(), NumDimG, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG, ck::index_t M = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimM,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, ck::index_t N = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN, NumDimN,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, ck::index_t K = ck::accumulate_n(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK, NumDimK,
ck::index_t{1}, ck::index_t{1},
std::multiplies<ck::index_t>{}); std::multiplies<ck::index_t>{});
...@@ -358,13 +349,11 @@ int main(int argc, char* argv[]) ...@@ -358,13 +349,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl; << op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); e_device_buf.FromDevice(e_gs_ms_ns_device_result.data());
if(do_verification) if(do_verification)
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result( Tensor<CShuffleDataType> c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG, using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
NumDimM, NumDimM,
...@@ -386,18 +375,17 @@ int main(int argc, char* argv[]) ...@@ -386,18 +375,17 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.GetLengths()[0]; ++g0)
{ {
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.GetLengths()[1]; ++g1)
{ {
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.GetLengths()[2]; ++m0)
{ {
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.GetLengths()[3]; ++m1)
{ {
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.GetLengths()[4]; ++n0)
{ {
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.GetLengths()[5]; ++n1)
++n1)
{ {
cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
...@@ -409,9 +397,7 @@ int main(int argc, char* argv[]) ...@@ -409,9 +397,7 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData) return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
? 0
: 1;
} }
return 0; return 0;
......
...@@ -9,32 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o ...@@ -9,32 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1 Gemm1
*/ */
#include <iostream> #include "common.hpp"
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = BF16; using ADataType = BF16;
using B0DataType = BF16; using B0DataType = BF16;
......
...@@ -13,29 +13,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o ...@@ -13,29 +13,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
#error Should compile this file with ck::int4_t support #error Should compile this file with ck::int4_t support
#endif #endif
#include <iostream> #include "common.hpp"
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::int4_t; using ADataType = ck::int4_t;
using B0DataType = ck::int4_t; using B0DataType = ck::int4_t;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment