Commit e4e99a49 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use new utilities to shorten codes

parent 7acbf104
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <random>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include <random>
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -28,8 +29,6 @@ struct ExecutionConfig final
bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
static_assert(sizeof(ADataType) == sizeof(KernelADataType));
......@@ -48,6 +47,8 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
batch_stride_C,
batch_count] = problem_size;
using namespace ck::literals;
// GEMM shape
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
......@@ -55,15 +56,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz});
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride});
}
};
......@@ -79,9 +78,9 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));
#endif
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl;
std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.GetDesc() << std::endl;
std::cout << "e_g_m_n: " << e_g_m_n_device_result.GetDesc() << std::endl;
switch(config.init_method)
{
......@@ -96,19 +95,19 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b_device_buf(b_g_k_n.GetMemorySize());
DeviceMem c_device_buf(e_g_m_n_device_result.GetMemorySize());
#ifdef BUILD_INT4_EXAMPLE
const Tensor<KernelADataType> a_g_m_k_converted(a_g_m_k);
const Tensor<KernelBDataType> b_g_k_n_converted(b_g_k_n);
a_device_buf.ToDevice(a_g_m_k_converted.mData.data());
b_device_buf.ToDevice(b_g_k_n_converted.mData.data());
a_device_buf.ToDevice(a_g_m_k_converted.data());
b_device_buf.ToDevice(b_g_k_n_converted.data());
#else
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
a_device_buf.ToDevice(a_g_m_k.data());
b_device_buf.ToDevice(b_g_k_n.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
......@@ -150,7 +149,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
if(config.do_verification)
{
c_device_buf.FromDevice(e_g_m_n_device_result.mData.data());
c_device_buf.FromDevice(e_g_m_n_device_result.data());
using ReferenceBatchedGemmInstance =
ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......@@ -174,11 +173,11 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#ifdef BUILD_INT4_EXAMPLE
const Tensor<EDataType> e_device_result_converted(e_g_m_n_device_result);
pass &= ck::utils::check_err(e_device_result_converted.mData, e_g_m_n_host_result.mData);
pass &= ck::utils::check_err(e_device_result_converted, e_g_m_n_host_result);
#else
pass = ck::utils::check_err(
e_g_m_n_device_result.mData, e_g_m_n_host_result.mData, "Error: Incorrect results c");
e_g_m_n_device_result, e_g_m_n_host_result, "Error: Incorrect results c");
#endif
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -110,7 +111,7 @@ struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::B
float Run(const Argument& arg)
{
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto n0, auto n1, auto n2) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[3];
const int K0 = arg.a_gs_ms_ks_.GetLengths()[3];
AccDataType v_acc = 0;
......@@ -136,12 +137,12 @@ struct ReferenceContraction_G1_M2_N3_K1 : public ck::tensor_operation::device::B
};
make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
arg.e_gs_ms_ns_.GetLengths()[0],
arg.e_gs_ms_ns_.GetLengths()[1],
arg.e_gs_ms_ns_.GetLengths()[2],
arg.e_gs_ms_ns_.GetLengths()[3],
arg.e_gs_ms_ns_.GetLengths()[4],
arg.e_gs_ms_ns_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
......@@ -246,26 +247,16 @@ int main(int argc, char* argv[])
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
Tensor<BDataType> b_gs_ns_ks(
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.GetDesc() << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.GetDesc() << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.GetDesc() << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.GetDesc() << std::endl;
switch(init_method)
{
......@@ -282,15 +273,14 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(a_gs_ms_ks.GetMemorySize());
DeviceMem b_device_buf(b_gs_ns_ks.GetMemorySize());
DeviceMem d_device_buf(d_gs_ms_ns.GetMemorySize());
DeviceMem e_device_buf(e_gs_ms_ns_device_result.GetMemorySize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
a_device_buf.ToDevice(a_gs_ms_ks.data());
b_device_buf.ToDevice(b_gs_ns_ks.data());
d_device_buf.ToDevice(d_gs_ms_ns.data());
// set zero
e_device_buf.SetZero();
......@@ -299,19 +289,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
using ck::utils::to_array;
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
to_array({d_device_buf.GetDeviceBuffer()}),
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
to_array({d_gs_ms_ns_lengths}),
to_array({d_gs_ms_ns_strides}),
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
......@@ -327,20 +319,20 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t M = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG,
NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t N = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t K = ck::accumulate_n(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
......@@ -353,13 +345,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
e_device_buf.FromDevice(e_gs_ms_ns_device_result.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1<NumDimM,
NumDimN,
......@@ -384,18 +374,17 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.GetLengths()[0]; ++g0)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0)
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.GetLengths()[1]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1)
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.GetLengths()[2]; ++m1)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0)
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.GetLengths()[3]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1)
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.GetLengths()[4]; ++n1)
{
for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n2)
for(size_t n2 = 0; n2 < e_gs_ms_ns_host_result.GetLengths()[5]; ++n2)
{
cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
c_gs_ms_ns_host_result(g0, m0, m1, n0, n1, n2),
......@@ -407,9 +396,7 @@ int main(int argc, char* argv[])
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
}
return 0;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -108,7 +110,7 @@ struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::B
float Run(const Argument& arg)
{
auto f_gs_ms_ns = [&](auto g0, auto m0, auto m1, auto m2, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
const int K0 = arg.a_gs_ms_ks_.GetLengths()[4];
AccDataType v_acc = 0;
......@@ -134,12 +136,12 @@ struct ReferenceContraction_G1_M3_N2_K1 : public ck::tensor_operation::device::B
};
make_ParallelTensorFunctor(f_gs_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
arg.e_gs_ms_ns_.GetLengths()[0],
arg.e_gs_ms_ns_.GetLengths()[1],
arg.e_gs_ms_ns_.GetLengths()[2],
arg.e_gs_ms_ns_.GetLengths()[3],
arg.e_gs_ms_ns_.GetLengths()[4],
arg.e_gs_ms_ns_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
......@@ -246,26 +248,16 @@ int main(int argc, char* argv[])
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
Tensor<BDataType> b_gs_ns_ks(
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.GetDesc() << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.GetDesc() << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.GetDesc() << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.GetDesc() << std::endl;
switch(init_method)
{
......@@ -282,15 +274,14 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(a_gs_ms_ks.GetMemorySize());
DeviceMem b_device_buf(b_gs_ns_ks.GetMemorySize());
DeviceMem d_device_buf(d_gs_ms_ns.GetMemorySize());
DeviceMem e_device_buf(e_gs_ms_ns_device_result.GetMemorySize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
a_device_buf.ToDevice(a_gs_ms_ks.data());
b_device_buf.ToDevice(b_gs_ns_ks.data());
d_device_buf.ToDevice(d_gs_ms_ns.data());
// set zero
e_device_buf.SetZero();
......@@ -299,19 +290,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
using ck::utils::to_array;
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
to_array({d_device_buf.GetDeviceBuffer()}),
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
to_array({d_gs_ms_ns_lengths}),
to_array({d_gs_ms_ns_strides}),
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
......@@ -327,20 +320,18 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin(),
e_gs_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t M = ck::accumulate_n(
e_gs_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimM,
NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = ck::accumulate_n(a_gs_ms_ks_lengths.begin() + NumDimM,
NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
......@@ -353,13 +344,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
e_device_buf.FromDevice(e_gs_ms_ns_device_result.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<CShuffleDataType> c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1<NumDimG,
NumDimM,
......@@ -385,18 +374,17 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.GetLengths()[0]; ++g0)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++m0)
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.GetLengths()[1]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m1)
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.GetLengths()[2]; ++m1)
{
for(size_t m2 = 0; m2 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m2)
for(size_t m2 = 0; m2 < e_gs_ms_ns_host_result.GetLengths()[3]; ++m2)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.GetLengths()[4]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n1)
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.GetLengths()[5]; ++n1)
{
cde_element_op(e_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),
c_gs_ms_ns_host_result(g0, m0, m1, m2, n0, n1),
......@@ -408,9 +396,7 @@ int main(int argc, char* argv[])
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
}
return 0;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -122,8 +124,8 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
const int K0 = arg.a_ms_ks_.GetLengths()[2];
const int K1 = arg.a_ms_ks_.GetLengths()[3];
AccDataType v_acc = 0;
......@@ -151,10 +153,10 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
arg.e_ms_ns_.GetLengths()[0],
arg.e_ms_ns_.GetLengths()[1],
arg.e_ms_ns_.GetLengths()[2],
arg.e_ms_ns_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
......@@ -288,26 +290,16 @@ int main(int argc, char* argv[])
exit(0);
}
Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
Tensor<BDataType> b_ns_ks(
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
Tensor<EDataType> d_ms_ns(
std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
Tensor<EDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::cout << "a_ms_ks: " << a_ms_ks.GetDesc() << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.GetDesc() << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.GetDesc() << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.GetDesc() << std::endl;
switch(init_method)
{
......@@ -324,14 +316,14 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(a_ms_ks.GetMemorySize());
DeviceMem b_device_buf(b_ns_ks.GetMemorySize());
DeviceMem d_device_buf(d_ms_ns.GetMemorySize());
DeviceMem e_device_buf(e_ms_ns_device_result.GetMemorySize());
a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
d_device_buf.ToDevice(d_ms_ns.mData.data());
a_device_buf.ToDevice(a_ms_ks.data());
b_device_buf.ToDevice(b_ns_ks.data());
d_device_buf.ToDevice(d_ms_ns.data());
// set zero
e_device_buf.SetZero();
......@@ -340,19 +332,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
using ck::utils::to_array;
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
to_array({d_device_buf.GetDeviceBuffer()}),
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
to_array({d_ms_ns_lengths}),
to_array({d_ms_ns_strides}),
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
......@@ -368,20 +362,14 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t M = ck::accumulate_n(
e_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = ck::accumulate_n(
e_ms_ns_lengths.begin() + NumDimM, NumDimN, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
a_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = ck::accumulate_n(
a_ms_ks_lengths.begin() + NumDimM, NumDimK, ck::index_t{1}, std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
......@@ -394,13 +382,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
e_device_buf.FromDevice(e_ms_ns_device_result.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
......@@ -421,13 +407,13 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
for(size_t m0 = 0; m0 < e_ms_ns_host_result.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
for(size_t m1 = 0; m1 < e_ms_ns_host_result.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
for(size_t n0 = 0; n0 < e_ms_ns_host_result.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
for(size_t n1 = 0; n1 < e_ms_ns_host_result.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1),
......@@ -437,7 +423,7 @@ int main(int argc, char* argv[])
}
}
return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
return 0;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -121,8 +123,8 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
const int K0 = arg.a_ms_ks_.GetLengths()[2];
const int K1 = arg.a_ms_ks_.GetLengths()[3];
AccDataType v_acc = 0;
......@@ -150,10 +152,10 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3])(
arg.e_ms_ns_.GetLengths()[0],
arg.e_ms_ns_.GetLengths()[1],
arg.e_ms_ns_.GetLengths()[2],
arg.e_ms_ns_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
......@@ -277,22 +279,14 @@ int main(int argc, char* argv[])
exit(0);
}
Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
Tensor<BDataType> b_ns_ks(
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
Tensor<EDataType> e_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl;
Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::cout << "a_ms_ks: " << a_ms_ks.GetDesc() << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.GetDesc() << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result.GetDesc() << std::endl;
switch(init_method)
{
......@@ -307,12 +301,12 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(a_ms_ks.GetMemorySize());
DeviceMem b_device_buf(b_ns_ks.GetMemorySize());
DeviceMem e_device_buf(e_ms_ns_device_result.GetMemorySize());
a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
a_device_buf.ToDevice(a_ms_ks.data());
b_device_buf.ToDevice(b_ns_ks.data());
// set zero
e_device_buf.SetZero();
......@@ -321,19 +315,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{scale};
using ck::utils::empty_array;
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
empty_array(),
e_device_buf.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 0>{},
std::array<std::vector<ck::index_t>, 0>{},
empty_array(),
empty_array(),
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
......@@ -349,20 +345,14 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t M = ck::accumulate_n(
e_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = ck::accumulate_n(
e_ms_ns_lengths.begin() + NumDimM, NumDimN, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
a_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = ck::accumulate_n(
a_ms_ks_lengths.begin() + NumDimM, NumDimK, ck::index_t{1}, std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
......@@ -375,13 +365,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_ms_ns_device_result.mData.data());
e_device_buf.FromDevice(e_ms_ns_device_result.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
......@@ -402,13 +390,13 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
for(size_t m0 = 0; m0 < e_ms_ns_host_result.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
for(size_t m1 = 0; m1 < e_ms_ns_host_result.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0)
for(size_t n0 = 0; n0 < e_ms_ns_host_result.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1)
for(size_t n1 = 0; n1 < e_ms_ns_host_result.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1),
c_ms_ns_host_result(m0, m1, n0, n1));
......@@ -417,7 +405,7 @@ int main(int argc, char* argv[])
}
}
return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1;
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
}
return 0;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/ranges.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
......@@ -59,14 +62,14 @@ int main()
ck::index_t N = 1024;
ck::index_t Stride = N;
using namespace ck::literals;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
return HostTensorDescriptor({len}, {stride});
};
auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
return HostTensorDescriptor({row, col}, {stride, 1_uz});
};
Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
......@@ -78,29 +81,30 @@ int main()
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
DeviceMem x_dev(x.GetMemorySize());
DeviceMem gamma_dev(gamma.GetMemorySize());
DeviceMem beta_dev(beta.GetMemorySize());
DeviceMem y_dev(y.GetMemorySize());
x_dev.ToDevice(x.data());
gamma_dev.ToDevice(gamma.data());
beta_dev.ToDevice(beta.data());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
beta_dev.ToDevice(beta.mData.data());
using Indices = std::vector<ck::index_t>;
auto device_instance = DeviceInstance{};
auto argument_ptr = device_instance.MakeArgumentPointer(
{M, N},
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
{0, 1},
{0, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
{1},
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
PassThrough{});
auto argument_ptr = device_instance.MakeArgumentPointer({M, N},
ck::ranges::to<Indices>(x.GetStrides()),
{0, 1},
{0, 1},
ck::ranges::to<Indices>(y.GetStrides()),
{1},
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
PassThrough{});
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{
......@@ -129,9 +133,8 @@ int main()
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data());
pass &=
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
y_dev.FromDevice(y.data());
pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results d1", 1e-3, 1e-3);
}
return (pass ? 0 : 1);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -104,7 +106,7 @@ struct ReferenceContraction_M3_N2_K1 : public ck::tensor_operation::device::Base
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto m0, auto m1, auto m2, auto n0, auto n1) {
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[3];
const int K0 = arg.a_ms_ks_.GetLengths()[3];
AccDataType v_acc = 0;
......@@ -129,11 +131,11 @@ struct ReferenceContraction_M3_N2_K1 : public ck::tensor_operation::device::Base
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_ms_ns_.mDesc.GetLengths()[0],
arg.e_ms_ns_.mDesc.GetLengths()[1],
arg.e_ms_ns_.mDesc.GetLengths()[2],
arg.e_ms_ns_.mDesc.GetLengths()[3],
arg.e_ms_ns_.mDesc.GetLengths()[4])(
arg.e_ms_ns_.GetLengths()[0],
arg.e_ms_ns_.GetLengths()[1],
arg.e_ms_ns_.GetLengths()[2],
arg.e_ms_ns_.GetLengths()[3],
arg.e_ms_ns_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
......@@ -297,33 +299,23 @@ int main(int argc, char* argv[])
const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths;
const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides;
Tensor<ADataType> a_ms_ks(
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()),
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end()));
Tensor<BDataType> b_ns_ks(
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()),
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
Tensor<DDataType> d_ms_ns(
std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
ck::index_t M_ = std::accumulate(e_ms_ns_lengths.begin(),
e_ms_ns_lengths.begin() + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N_ = std::accumulate(e_ms_ns_lengths.begin() + NumDimM,
e_ms_ns_lengths.begin() + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K_ = std::accumulate(a_ms_ks_lengths.begin() + NumDimM,
a_ms_ks_lengths.begin() + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
Tensor<DDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
ck::index_t M_ = ck::accumulate_n(
e_ms_ns_lengths.begin(), NumDimM, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t N_ = ck::accumulate_n(e_ms_ns_lengths.begin() + NumDimM,
NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K_ = ck::accumulate_n(a_ms_ks_lengths.begin() + NumDimM,
NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
a_tensors.push_back(a_ms_ks);
b_tensors.push_back(b_ns_ks);
......@@ -334,13 +326,13 @@ int main(int argc, char* argv[])
flop += std::size_t(2) * M_ * K_ * N_;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize();
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
sizeof(BDataType) * b_tensors[i].GetElementSize() +
sizeof(EDataType) * e_device_tensors[i].GetElementSize();
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_n_k: " << b_tensors[i].mDesc << " c_m_n: " << e_device_tensors[i].mDesc
<< std::endl;
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].GetDesc()
<< " b_n_k: " << b_tensors[i].GetDesc()
<< " c_m_n: " << e_device_tensors[i].GetDesc() << std::endl;
switch(init_method)
{
......@@ -364,18 +356,15 @@ int main(int argc, char* argv[])
for(std::size_t i = 0; i < contraction_descs.size(); i++)
{
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
d_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(DDataType) * d_tensors[i].mDesc.GetElementSpaceSize()));
e_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSpaceSize()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
d_tensors_device[i]->ToDevice(d_tensors[i].mData.data());
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(a_tensors[i].GetMemorySize()));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(b_tensors[i].GetMemorySize()));
d_tensors_device.emplace_back(std::make_unique<DeviceMem>(d_tensors[i].GetMemorySize()));
e_tensors_device.emplace_back(
std::make_unique<DeviceMem>(e_device_tensors[i].GetMemorySize()));
a_tensors_device[i]->ToDevice(a_tensors[i].data());
b_tensors_device[i]->ToDevice(b_tensors[i].data());
d_tensors_device[i]->ToDevice(d_tensors[i].data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
......@@ -423,15 +412,11 @@ int main(int argc, char* argv[])
const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths;
const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides;
Tensor<EDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data());
e_tensors_device[i]->FromDevice(e_device_tensors[i].data());
using ReferenceOpInstance = ReferenceContraction_M3_N2_K1<NumDimM,
NumDimN,
......@@ -456,15 +441,15 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0)
for(size_t m0 = 0; m0 < e_ms_ns_host_result.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1)
for(size_t m1 = 0; m1 < e_ms_ns_host_result.GetLengths()[1]; ++m1)
{
for(size_t m2 = 0; m2 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++m2)
for(size_t m2 = 0; m2 < e_ms_ns_host_result.GetLengths()[2]; ++m2)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n0)
for(size_t n0 = 0; n0 < e_ms_ns_host_result.GetLengths()[3]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[4]; ++n1)
for(size_t n1 = 0; n1 < e_ms_ns_host_result.GetLengths()[4]; ++n1)
{
cde_element_op(e_ms_ns_host_result(m0, m1, m2, n0, n1),
c_ms_ns_host_result(m0, m1, m2, n0, n1),
......@@ -475,7 +460,7 @@ int main(int argc, char* argv[])
}
}
pass &= ck::utils::check_err(e_device_tensors[i].mData, e_ms_ns_host_result.mData);
pass &= ck::utils::check_err(e_device_tensors[i], e_ms_ns_host_result);
}
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -109,7 +111,7 @@ struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::B
float Run(const Argument& arg)
{
auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) {
const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4];
const int K0 = arg.a_gs_ms_ks_.GetLengths()[4];
AccDataType v_acc = 0;
......@@ -136,12 +138,12 @@ struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::B
};
make_ParallelTensorFunctor(f_ms_ns,
arg.e_gs_ms_ns_.mDesc.GetLengths()[0],
arg.e_gs_ms_ns_.mDesc.GetLengths()[1],
arg.e_gs_ms_ns_.mDesc.GetLengths()[2],
arg.e_gs_ms_ns_.mDesc.GetLengths()[3],
arg.e_gs_ms_ns_.mDesc.GetLengths()[4],
arg.e_gs_ms_ns_.mDesc.GetLengths()[5])(
arg.e_gs_ms_ns_.GetLengths()[0],
arg.e_gs_ms_ns_.GetLengths()[1],
arg.e_gs_ms_ns_.GetLengths()[2],
arg.e_gs_ms_ns_.GetLengths()[3],
arg.e_gs_ms_ns_.GetLengths()[4],
arg.e_gs_ms_ns_.GetLengths()[5])(
std::thread::hardware_concurrency());
return 0;
......@@ -246,26 +248,16 @@ int main(int argc, char* argv[])
exit(0);
}
Tensor<ADataType> a_gs_ms_ks(
std::vector<std::size_t>(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()),
std::vector<std::size_t>(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()));
Tensor<BDataType> b_gs_ns_ks(
std::vector<std::size_t>(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()),
std::vector<std::size_t>(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()));
Tensor<DDataType> d_gs_ms_ns(
std::vector<std::size_t>(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<EDataType> e_gs_ms_ns_device_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl;
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<EDataType> e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
Tensor<EDataType> e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.GetDesc() << std::endl;
std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.GetDesc() << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.GetDesc() << std::endl;
std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.GetDesc() << std::endl;
switch(init_method)
{
......@@ -282,15 +274,14 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(a_gs_ms_ks.GetMemorySize());
DeviceMem b_device_buf(b_gs_ns_ks.GetMemorySize());
DeviceMem d_device_buf(d_gs_ms_ns.GetMemorySize());
DeviceMem e_device_buf(e_gs_ms_ns_device_result.GetMemorySize());
a_device_buf.ToDevice(a_gs_ms_ks.mData.data());
b_device_buf.ToDevice(b_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
a_device_buf.ToDevice(a_gs_ms_ks.data());
b_device_buf.ToDevice(b_gs_ns_ks.data());
d_device_buf.ToDevice(d_gs_ms_ns.data());
// set zero
e_device_buf.SetZero();
......@@ -299,19 +290,21 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
using ck::utils::to_array;
// device operation
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
to_array({d_device_buf.GetDeviceBuffer()}),
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
to_array({d_gs_ms_ns_lengths}),
to_array({d_gs_ms_ns_strides}),
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
......@@ -327,25 +320,23 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t G = std::accumulate(e_gs_ms_ns_lengths.begin(),
e_gs_ms_ns_lengths.begin() + NumDimG,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t G = ck::accumulate_n(
e_gs_ms_ns_lengths.begin(), NumDimG, ck::index_t{1}, std::multiplies<ck::index_t>{});
ck::index_t M = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t M = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG,
NumDimM,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM + NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = ck::accumulate_n(e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM,
NumDimN,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM + NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = ck::accumulate_n(a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM,
NumDimK,
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * G * M * N * K;
std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N +
......@@ -358,13 +349,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data());
e_device_buf.FromDevice(e_gs_ms_ns_device_result.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_ms_ns_host_result(
std::vector<std::size_t>(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()),
std::vector<std::size_t>(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()));
Tensor<CShuffleDataType> c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1<NumDimG,
NumDimM,
......@@ -386,18 +375,17 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0)
for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.GetLengths()[0]; ++g0)
{
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1)
for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.GetLengths()[1]; ++g1)
{
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0)
for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.GetLengths()[2]; ++m0)
{
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1)
for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.GetLengths()[3]; ++m1)
{
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0)
for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.GetLengths()[4]; ++n0)
{
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5];
++n1)
for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.GetLengths()[5]; ++n1)
{
cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
c_ms_ns_host_result(g0, g1, m0, m1, n0, n1),
......@@ -409,9 +397,7 @@ int main(int argc, char* argv[])
}
}
return ck::utils::check_err(e_gs_ms_ns_device_result.mData, e_gs_ms_ns_host_result.mData)
? 0
: 1;
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1;
}
return 0;
......
......@@ -10,12 +10,14 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/array.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
void print_helper_msg()
{
......@@ -57,11 +59,11 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
Tensor<OutUserDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutKernelDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "bias: " << bias.mDesc << std::endl;
std::cout << "residual: " << residual.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
std::cout << "in: " << in.GetDesc() << std::endl;
std::cout << "wei: " << wei.GetDesc() << std::endl;
std::cout << "bias: " << bias.GetDesc() << std::endl;
std::cout << "residual: " << residual.GetDesc() << std::endl;
std::cout << "out: " << out_host.GetDesc() << std::endl;
switch(init_method)
{
......@@ -77,11 +79,11 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
bias.GenerateTensorValue(GeneratorTensor_3<OutUserDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InKernelDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiKernelDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutKernelDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem residual_device_buf(sizeof(OutKernelDataType) * residual.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutKernelDataType) * out_device.mDesc.GetElementSpaceSize());
DeviceMem in_device_buf(in.GetMemorySize());
DeviceMem wei_device_buf(wei.GetMemorySize());
DeviceMem bias_device_buf(bias.GetMemorySize());
DeviceMem residual_device_buf(residual.GetMemorySize());
DeviceMem out_device_buf(out_device.GetMemorySize());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
const Tensor<InKernelDataType> in_converted(in);
......@@ -89,75 +91,54 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
const Tensor<OutKernelDataType> bias_converted(bias);
const Tensor<OutKernelDataType> residual_converted(residual);
in_device_buf.ToDevice(in_converted.mData.data());
wei_device_buf.ToDevice(wei_converted.mData.data());
bias_device_buf.ToDevice(bias_converted.mData.data());
residual_device_buf.ToDevice(residual_converted.mData.data());
in_device_buf.ToDevice(in_converted.data());
wei_device_buf.ToDevice(wei_converted.data());
bias_device_buf.ToDevice(bias_converted.data());
residual_device_buf.ToDevice(residual_converted.data());
#else // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data());
residual_device_buf.ToDevice(residual.mData.data());
in_device_buf.ToDevice(in.data());
wei_device_buf.ToDevice(wei.data());
bias_device_buf.ToDevice(bias.data());
residual_device_buf.ToDevice(residual.data());
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d0_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> d1_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(bias_g_n_k_wos_desc.GetLengths(), d0_g_n_k_wos_lengths);
copy(bias_g_n_k_wos_desc.GetStrides(), d0_g_n_k_wos_strides);
copy(residual_g_n_k_wos_desc.GetLengths(), d1_g_n_k_wos_lengths);
copy(residual_g_n_k_wos_desc.GetStrides(), d1_g_n_k_wos_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
using ck::utils::to_array;
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument =
conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 2>{bias_device_buf.GetDeviceBuffer(),
residual_device_buf.GetDeviceBuffer()},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 2>{
{d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 2>{
{d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(
in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
to_array({bias_device_buf.GetDeviceBuffer(), residual_device_buf.GetDeviceBuffer()}),
out_device_buf.GetDeviceBuffer(),
to_array(in_g_n_c_wis_desc.GetLengths()),
to_array(in_g_n_c_wis_desc.GetStrides()),
to_array(wei_g_k_c_xs_desc.GetLengths()),
to_array(wei_g_k_c_xs_desc.GetStrides()),
to_array({d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}),
to_array({d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}),
to_array(out_g_n_k_wos_desc.GetLengths()),
to_array(out_g_n_k_wos_desc.GetStrides()),
to_array(conv_param.conv_filter_strides_),
to_array(conv_param.conv_filter_dilations_),
to_array(conv_param.input_left_pads_),
to_array(conv_param.input_right_pads_),
in_element_op,
wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
{
......@@ -209,21 +190,17 @@ int run_grouped_conv_fwd_bias_relu_add(bool do_verification,
out_element_op(out_host(idx), c_host(idx), bias(idx), residual(idx));
});
out_device_buf.FromDevice(out_device.mData.data());
out_device_buf.FromDevice(out_device.data());
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
const Tensor<OutUserDataType> out_device_converted(out_device);
return ck::utils::check_err(out_device_converted.mData,
out_host.mData,
"Error: incorrect results!",
1e-5f,
1e-4f)
return ck::utils::check_err(
out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f)
? 0
: 1;
#else // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
return ck::utils::check_err(
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
return ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f)
? 0
: 1;
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
......
......@@ -9,32 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"
using ADataType = BF16;
using B0DataType = BF16;
......
......@@ -9,32 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"
using ADataType = F16;
using B0DataType = F16;
......
......@@ -9,31 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"
using ADataType = F32;
using B0DataType = F32;
......
......@@ -13,29 +13,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
#error Should compile this file with ck::int4_t support
#endif
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"
using ADataType = ck::int4_t;
using B0DataType = ck::int4_t;
......
......@@ -9,29 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#include "common.hpp"
using ADataType = int8_t;
using B0DataType = int8_t;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
......@@ -100,21 +100,21 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
using namespace ck::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
if constexpr(std::is_same_v<decltype(layout), Row>)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
}
};
......@@ -130,10 +130,10 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
Tensor<CDataType> c_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.GetDesc() << std::endl;
switch(init_method)
{
......@@ -155,29 +155,27 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
}
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_g_m_k_device_buf(sizeof(KernelADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(KernelB0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(KernelB1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_g_m_o_device_buf(sizeof(KernelCDataType) *
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
DeviceMem c_g_m_o_device_buf(c_g_m_o_device_result.GetMemorySize());
const Tensor<KernelADataType> a_g_m_k_converted(a_g_m_k);
const Tensor<KernelB0DataType> b0_g_k_n_converted(b0_g_k_n);
const Tensor<KernelB1DataType> b1_g_n_o_converted(b1_g_n_o);
a_g_m_k_device_buf.ToDevice(a_g_m_k_converted.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n_converted.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o_converted.mData.data());
a_g_m_k_device_buf.ToDevice(a_g_m_k_converted.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n_converted.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o_converted.data());
#else
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) *
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
DeviceMem c_g_m_o_device_buf(c_g_m_o_device_result.GetMemorySize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());
#endif
auto a_element_op = AElementOp{};
......@@ -189,36 +187,28 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelB0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelB1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
auto argument = gemm.MakeArgument(a_g_m_k_device_buf.GetDeviceBuffer(),
b0_g_k_n_device_buf.GetDeviceBuffer(),
b1_g_n_o_device_buf.GetDeviceBuffer(),
c_g_m_o_device_buf.GetDeviceBuffer(),
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......@@ -261,16 +251,16 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
ref_gemm1_invoker.Run(ref_gemm1_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<KernelCDataType> c_g_m_o_device_result_converted(c_g_m_o_host_result.mDesc);
Tensor<KernelCDataType> c_g_m_o_device_result_converted(c_g_m_o_host_result.GetDesc());
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result_converted.mData.data());
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result_converted.data());
c_g_m_o_device_result = c_g_m_o_device_result_converted.CopyAsType<CDataType>();
#else
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.data());
#endif
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData);
return ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result);
}
return true;
......
......@@ -9,23 +9,24 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -222,7 +223,9 @@ int main(int argc, char* argv[])
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
const int BatchCount = G0 * G1;
const ck::index_t BatchCount = G0 * G1;
using namespace ck::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
......@@ -230,15 +233,13 @@ int main(int argc, char* argv[])
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
if constexpr(std::is_same_v<decltype(layout), Row>)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
}
};
......@@ -249,17 +250,13 @@ int main(int argc, char* argv[])
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_gs_ms_os_host_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
Tensor<CDataType> c_gs_ms_os_device_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.GetDesc() << std::endl;
switch(init_method)
{
......@@ -285,15 +282,14 @@ int main(int argc, char* argv[])
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
DeviceMem c_gs_ms_os_device_buf(c_gs_ms_os_device_result.GetMemorySize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
......@@ -302,31 +298,30 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
StrideA,
StrideB0,
StrideB1,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(a_g_m_k_device_buf.GetDeviceBuffer(),
b0_g_k_n_device_buf.GetDeviceBuffer(),
b1_g_n_o_device_buf.GetDeviceBuffer(),
c_gs_ms_os_device_buf.GetDeviceBuffer(),
M,
N,
K,
O,
BatchCount,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
StrideA,
StrideB0,
StrideB1,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......@@ -351,15 +346,14 @@ int main(int argc, char* argv[])
if(do_verification)
{
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.data());
// Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
std::vector<int>{M * O, O, 1});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}, {M * O, O, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
......@@ -400,9 +394,7 @@ int main(int argc, char* argv[])
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
? 0
: 1;
return ck::utils::check_err(c_gs_ms_os_device_result, c_gs_ms_os_host_result) ? 0 : 1;
}
return 0;
......
......@@ -9,23 +9,24 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -222,7 +223,9 @@ int main(int argc, char* argv[])
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
const int BatchCount = G0 * G1;
const ck::index_t BatchCount = G0 * G1;
using namespace ck::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
......@@ -230,15 +233,13 @@ int main(int argc, char* argv[])
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
if constexpr(std::is_same_v<decltype(layout), Row>)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
}
};
......@@ -249,17 +250,13 @@ int main(int argc, char* argv[])
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_gs_ms_os_host_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
Tensor<CDataType> c_gs_ms_os_device_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.GetDesc() << std::endl;
switch(init_method)
{
......@@ -285,15 +282,14 @@ int main(int argc, char* argv[])
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_gs_ms_os_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
DeviceMem c_gs_ms_os_device_buf(c_gs_ms_os_device_result.GetMemorySize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
......@@ -302,31 +298,30 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_gs_ms_os_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
StrideA,
StrideB0,
StrideB1,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(a_g_m_k_device_buf.GetDeviceBuffer(),
b0_g_k_n_device_buf.GetDeviceBuffer(),
b1_g_n_o_device_buf.GetDeviceBuffer(),
c_gs_ms_os_device_buf.GetDeviceBuffer(),
M,
N,
K,
O,
BatchCount,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
StrideA,
StrideB0,
StrideB1,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......@@ -351,15 +346,14 @@ int main(int argc, char* argv[])
if(do_verification)
{
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.data());
// Output of Gemm0 is input A of Gemm1
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{BatchCount, M, O},
std::vector<int>{M * O, O, 1});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}, {M * O, O, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
......@@ -390,9 +384,7 @@ int main(int argc, char* argv[])
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
return ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData)
? 0
: 1;
return ck::utils::check_err(c_gs_ms_os_device_result, c_gs_ms_os_host_result) ? 0 : 1;
}
return 0;
......
......@@ -9,22 +9,23 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
*/
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -239,21 +240,21 @@ int main(int argc, char* argv[])
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
using namespace ck::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
if constexpr(std::is_same_v<decltype(layout), Row>)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
}
};
......@@ -269,10 +270,10 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
std::cout << "a_g_m_k: " << a_g_m_k.GetDesc() << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.GetDesc() << std::endl;
switch(init_method)
{
......@@ -298,15 +299,14 @@ int main(int argc, char* argv[])
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) *
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_g_m_k_device_buf(a_g_m_k.GetMemorySize());
DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
DeviceMem c_g_m_o_device_buf(c_g_m_o_device_result.GetMemorySize());
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
a_g_m_k_device_buf.ToDevice(a_g_m_k.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
......@@ -315,31 +315,30 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(a_g_m_k_device_buf.GetDeviceBuffer(),
b0_g_k_n_device_buf.GetDeviceBuffer(),
b1_g_n_o_device_buf.GetDeviceBuffer(),
c_g_m_o_device_buf.GetDeviceBuffer(),
M,
N,
K,
O,
BatchCount,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......@@ -362,7 +361,7 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.data());
if(do_verification)
{
......@@ -391,7 +390,7 @@ int main(int argc, char* argv[])
ref_gemm1_invoker.Run(ref_gemm1_argument);
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1;
return ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result) ? 0 : 1;
}
return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment