Commit e4e99a49 authored by Po-Yen, Chen

Use new utilities to shorten code

parent 7acbf104
@@ -9,23 +9,24 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 Gemm1
 */
+#include <cstdlib>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>

 #include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/utility/literals.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -211,21 +212,21 @@ int main(int argc, char* argv[])
             c_gs_ms_os_strides});
     }

+    using namespace ck::literals;
+
     auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                        std::size_t row,
                                        std::size_t col,
                                        std::size_t stride,
                                        std::size_t batch_stride,
                                        auto layout) {
-        if(std::is_same<decltype(layout), Row>::value)
+        if constexpr(std::is_same_v<decltype(layout), Row>)
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
         }
         else
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
        }
    };
@@ -267,9 +268,7 @@ int main(int argc, char* argv[])
             f_host_tensor_descriptor(Batch, K, N, StrideB0, BatchStrideB0, B0Layout{}));
         Tensor<B1DataType> b1_g_n_o(
             f_host_tensor_descriptor(Batch, N, O, StrideB1, BatchStrideB1, B1Layout{}));
-        Tensor<CDataType> c_gs_ms_os_device_result(
-            std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
-            std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
+        Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

         flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
         num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
@@ -278,10 +277,11 @@ int main(int argc, char* argv[])
         if(i < 4)
         {
-            std::cout << "a_g_m_k[" << i << "]: " << a_g_m_k.mDesc << ", "
-                      << "b0_g_k_n[" << i << "]: " << b0_g_k_n.mDesc << ", "
-                      << "b1_g_n_o[" << i << "]: " << b1_g_n_o.mDesc << ", "
-                      << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
+            std::cout << "a_g_m_k[" << i << "]: " << a_g_m_k.GetDesc() << ", "
+                      << "b0_g_k_n[" << i << "]: " << b0_g_k_n.GetDesc() << ", "
+                      << "b1_g_n_o[" << i << "]: " << b1_g_n_o.GetDesc() << ", "
+                      << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.GetDesc()
+                      << std::endl;
         }

         switch(init_method)
@@ -313,18 +313,15 @@ int main(int argc, char* argv[])
         b1_tensors.push_back(b1_g_n_o);
         c_tensors.push_back(c_gs_ms_os_device_result);

-        a_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()));
-        b0_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()));
-        b1_tensors_device.emplace_back(
-            std::make_unique<DeviceMem>(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()));
-        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
+        a_tensors_device.emplace_back(std::make_unique<DeviceMem>(a_g_m_k.GetMemorySize()));
+        b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(b0_g_k_n.GetMemorySize()));
+        b1_tensors_device.emplace_back(std::make_unique<DeviceMem>(b1_g_n_o.GetMemorySize()));
+        c_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(c_gs_ms_os_device_result.GetMemorySize()));

-        a_tensors_device[i]->ToDevice(a_g_m_k.mData.data());
-        b0_tensors_device[i]->ToDevice(b0_g_k_n.mData.data());
-        b1_tensors_device[i]->ToDevice(b1_g_n_o.mData.data());
+        a_tensors_device[i]->ToDevice(a_g_m_k.data());
+        b0_tensors_device[i]->ToDevice(b0_g_k_n.data());
+        b1_tensors_device[i]->ToDevice(b1_g_n_o.data());

         p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
         p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
@@ -391,11 +388,9 @@ int main(int argc, char* argv[])
            auto& c_gs_ms_os_device_result = c_tensors[i];
            auto& c_gs_ms_os_device_buf    = *c_tensors_device[i];

-            Tensor<CDataType> c_gs_ms_os_host_result(
-                std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
-                std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
+            Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

-            c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
+            c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.data());

            // Output of Gemm0 is input A of Gemm1
            Tensor<AccDataType> acc0_m_n(f_host_tensor_descriptor(Batch, M, N, N, M * N, Row{}));
@@ -434,8 +429,7 @@ int main(int argc, char* argv[])
            c_gs_ms_os_host_result.ForEach(
                [&](auto& self, auto idx) { self(idx) = c_g_m_o_host_result(idx); });

-            bool pass_ =
-                ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
+            bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result, c_gs_ms_os_host_result);
            pass &= pass_;
        }
    }
......
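The hunks above switch the host descriptors over to ck::literals' `_uz` suffix and a brace-initializable HostTensorDescriptor. A self-contained sketch of how such a size_t literal works and why it keeps a deduced braced list homogeneous (every name here is an illustrative stand-in, not the CK API):

#include <cstddef>
#include <initializer_list>
#include <vector>

// `_uz` turns an integer literal into std::size_t (suffixes without a leading
// underscore are reserved for the standard, so user code spells it `_uz`)
namespace literals_sketch {
constexpr std::size_t operator""_uz(unsigned long long v) { return static_cast<std::size_t>(v); }
} // namespace literals_sketch

// stand-in for a brace-constructible descriptor: lengths/strides per dimension
struct DescriptorSketch
{
    std::vector<std::size_t> lengths;
    std::vector<std::size_t> strides;

    DescriptorSketch(std::initializer_list<std::size_t> ls, std::initializer_list<std::size_t> ss)
        : lengths(ls), strides(ss)
    {
    }
};

int main()
{
    using namespace literals_sketch;

    std::size_t stride = 4, batch_stride = 12;

    // with a deduced element type, {batch_stride, stride, 1} would mix
    // std::size_t and int and fail to deduce; 1_uz keeps the list homogeneous
    auto strides = {batch_stride, stride, 1_uz};

    DescriptorSketch desc({2, 3, 4}, strides);
    return desc.strides.back() == 1 ? 0 : 1;
}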
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#include <iostream>
+#include <algorithm>
+#include <array>
 #include <cstdlib>
+#include <iostream>
 #include <vector>
-#include <array>
-#include <algorithm>
 #include <getopt.h>

 #include "ck/ck.hpp"
-#include "ck/utility/reduction_enums.hpp"
 #include "ck/utility/data_type.hpp"
+#include "ck/utility/reduction_enums.hpp"

+#include "ck/library/utility/algorithm.hpp"
+#include "ck/library/utility/array.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
@@ -210,8 +213,8 @@ int mean_meansquare_dual_reduce_test(size_t n,
     Tensor<OutDataType> meansquare_ref(outLengths);
     Tensor<OutDataType> meansquare(outLengths);

-    auto inStrides  = in.mDesc.GetStrides();
-    auto outStrides = mean.mDesc.GetStrides();
+    auto inStrides  = in.GetStrides();
+    auto outStrides = mean.GetStrides();

     size_t invariant_total_length = n;
     size_t reduce_total_length    = h * w * c;
@@ -233,11 +236,11 @@ int mean_meansquare_dual_reduce_test(size_t n,
     };

     // these buffers are usually provided by the user application
-    DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
-    DeviceMem mean_dev(sizeof(OutDataType) * mean.mDesc.GetElementSpaceSize());
-    DeviceMem meansquare_dev(sizeof(OutDataType) * meansquare.mDesc.GetElementSpaceSize());
+    DeviceMem in_dev(in.GetMemorySize());
+    DeviceMem mean_dev(mean.GetMemorySize());
+    DeviceMem meansquare_dev(meansquare.GetMemorySize());

-    in_dev.ToDevice(in.mData.data());
+    in_dev.ToDevice(in.data());

     if(do_verification)
     {
@@ -245,25 +248,19 @@ int mean_meansquare_dual_reduce_test(size_t n,
             in, mean_ref, meansquare_ref, n, h, w, c);
     };

-    constexpr ck::index_t NumInputDim  = Rank;
     constexpr ck::index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1;

-    std::array<ck::index_t, NumInputDim> i_inLengths;
-    std::array<ck::index_t, NumInputDim> i_inStrides;
-    std::array<ck::index_t, NumOutputDim> i_outLengths;
     std::array<ck::index_t, NumOutputDim> i_outStrides;

-    std::copy(inLengths.begin(), inLengths.end(), i_inLengths.begin());
-    std::copy(inStrides.begin(), inStrides.end(), i_inStrides.begin());
-    std::copy(outLengths.begin(), outLengths.end(), i_outLengths.begin());
-    std::copy(outStrides.begin(), outStrides.end(), i_outStrides.begin());
+    ck::ranges::copy(outStrides, i_outStrides.begin());

-    auto dual_reduce_op = DeviceDualReduce{};
+    using ck::utils::to_array;

+    auto dual_reduce_op = DeviceDualReduce{};
     auto argument_ptr = dual_reduce_op.MakeArgumentPointer(
-        i_inLengths,
-        i_inStrides,
-        i_outLengths,
+        to_array(inLengths),
+        to_array(inStrides),
+        to_array(outLengths),
         {i_outStrides, i_outStrides},
         reduceDims,
         {&alpha, &alpha},
@@ -303,10 +300,10 @@ int mean_meansquare_dual_reduce_test(size_t n,
     if(do_verification)
     {
-        mean_dev.FromDevice(mean.mData.data());
-        meansquare_dev.FromDevice(meansquare.mData.data());
+        mean_dev.FromDevice(mean.data());
+        meansquare_dev.FromDevice(meansquare.data());

-        pass = pass && ck::utils::check_err(mean.mData, mean_ref.mData);
-        pass = pass && ck::utils::check_err(meansquare.mData, meansquare_ref.mData);
+        pass = pass && ck::utils::check_err(mean, mean_ref);
+        pass = pass && ck::utils::check_err(meansquare, meansquare_ref);
     };

     return (pass ? 0 : 1);
......
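This file replaces the hand-rolled `std::array` staging (declare, then `std::copy` at each call site) with `ck::utils::to_array(...)` applied directly at the argument list. A rough sketch of a to_array-style adaptor under assumed semantics (the real helper likely deduces the destination size and element type itself, and also accepts braced lists, as later hunks show):

#include <array>
#include <cstddef>
#include <vector>

// to_array-style adaptor (assumed semantics): copy a runtime container into a
// fixed-size std::array staging buffer in one expression
template <typename T, std::size_t N, typename Range>
std::array<T, N> to_array_sketch(const Range& r)
{
    std::array<T, N> out{};
    std::size_t i = 0;
    for(const auto& v : r)
    {
        if(i == N)
            break; // destination rank is fixed; ignore any excess
        out[i++] = static_cast<T>(v);
    }
    return out;
}

int main()
{
    std::vector<std::size_t> inLengths{4, 16, 16, 8};
    // one call replaces: declare std::array, then std::copy(begin, end, ...)
    auto staged = to_array_sketch<int, 4>(inLengths);
    return staged[0] == 4 ? 0 : 1;
}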
@@ -9,12 +9,14 @@
 #include <getopt.h>

 #include "ck/ck.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"

 #include "batchnorm_forward_impl.hpp"
@@ -159,8 +161,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
     Tensor<AccDataType> resultRunningMean_ref(scaleBiasMeanVarLengths);
     Tensor<AccDataType> resultRunningVariance_ref(scaleBiasMeanVarLengths);

-    auto inOutStrides            = x.mDesc.GetStrides();
-    auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides();
+    auto inOutStrides            = x.GetStrides();
+    auto scaleBiasMeanVarStrides = bnScale.GetStrides();

     std::size_t num_thread = std::thread::hardware_concurrency();
@@ -231,32 +233,28 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
     };

     // these buffers are usually provided by the user application
-    DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
-    DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize());
-    DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
-    DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize());
+    DeviceMem x_dev(x.GetMemorySize());
+    DeviceMem y_dev(y.GetMemorySize());
+    DeviceMem bnScale_dev(bnScale.GetMemorySize());
+    DeviceMem bnBias_dev(bnBias.GetMemorySize());

     // mean_dev or resultSaveMean_dev
-    DeviceMem resultSaveMean_dev(sizeof(AccDataType) *
-                                 resultSaveMean_ref.mDesc.GetElementSpaceSize());
+    DeviceMem resultSaveMean_dev(resultSaveMean_ref.GetMemorySize());
     // meansquare_dev or resultSaveInvVariance_dev
-    DeviceMem resultSaveInvVariance_dev(sizeof(AccDataType) *
-                                        resultSaveInvVariance_ref.mDesc.GetElementSpaceSize());
+    DeviceMem resultSaveInvVariance_dev(resultSaveInvVariance_ref.GetMemorySize());
     // resultRunningMean_dev
-    DeviceMem resultRunningMean_dev(sizeof(AccDataType) *
-                                    resultRunningMean_ref.mDesc.GetElementSpaceSize());
+    DeviceMem resultRunningMean_dev(resultRunningMean_ref.GetMemorySize());
     // resultRunningVariance_dev
-    DeviceMem resultRunningVariance_dev(sizeof(AccDataType) *
-                                        resultRunningVariance_ref.mDesc.GetElementSpaceSize());
+    DeviceMem resultRunningVariance_dev(resultRunningVariance_ref.GetMemorySize());

-    x_dev.ToDevice(x.mData.data());
-    bnScale_dev.ToDevice(bnScale.mData.data());
-    bnBias_dev.ToDevice(bnBias.mData.data());
+    x_dev.ToDevice(x.data());
+    bnScale_dev.ToDevice(bnScale.data());
+    bnBias_dev.ToDevice(bnBias.data());

     if(updateMovingAverage)
     {
-        resultRunningMean_dev.ToDevice(resultRunningMean_ref.mData.data());
-        resultRunningVariance_dev.ToDevice(resultRunningVariance_ref.mData.data());
+        resultRunningMean_dev.ToDevice(resultRunningMean_ref.data());
+        resultRunningVariance_dev.ToDevice(resultRunningVariance_ref.data());
     };

     std::array<index_t, Rank> i_inOutLengths;
@@ -264,25 +262,21 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
     std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths;
     std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides;

-    std::copy(inOutLengths.begin(), inOutLengths.end(), i_inOutLengths.begin());
-    std::copy(inOutStrides.begin(), inOutStrides.end(), i_inOutStrides.begin());
-    std::copy(scaleBiasMeanVarLengths.begin(),
-              scaleBiasMeanVarLengths.end(),
-              i_scaleBiasMeanVarLengths.begin());
-    std::copy(scaleBiasMeanVarStrides.begin(),
-              scaleBiasMeanVarStrides.end(),
-              i_scaleBiasMeanVarStrides.begin());
+    using ck::ranges::copy;
+
+    copy(inOutLengths, i_inOutLengths.begin());
+    copy(inOutStrides, i_inOutStrides.begin());
+    copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin());
+    copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin());

     int result = 0;

     // used for saving meansquare
-    DeviceMem workspace(sizeof(AccDataType) * 2 * resultSaveMean_ref.mDesc.GetElementSpaceSize() +
-                        128);
+    DeviceMem workspace(resultSaveMean_ref.GetMemorySize() * 2 + 128);

     void* p_tmp_mean = workspace.GetDeviceBuffer();
     void* p_tmp_meansquare =
-        static_cast<char*>(p_tmp_mean) +
-        (sizeof(AccDataType) * resultSaveMean_ref.mDesc.GetElementSpaceSize() + 63) / 64 * 64;
+        static_cast<char*>(p_tmp_mean) + (resultSaveMean_ref.GetMemorySize() + 63) / 64 * 64;

     result = bnorm_fwd<InOutDataType, AccDataType, Rank, NumReduceDim, false>(
         time_kernel,
@@ -322,17 +316,17 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
         i_inOutStrides,
         i_scaleBiasMeanVarLengths,
         i_scaleBiasMeanVarStrides,
-        x.mData.data(),
-        bnScale.mData.data(),
-        bnBias.mData.data(),
-        y_ref.mData.data(),
+        x.data(),
+        bnScale.data(),
+        bnBias.data(),
+        y_ref.data(),
         0.1, // exponentialAverageFactor
-        updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, // resultRunningMean
-        updateMovingAverage ? resultRunningVariance_ref.mData.data()
-                            : nullptr, // resultRunningVariance
+        updateMovingAverage ? resultRunningMean_ref.data() : nullptr, // resultRunningMean
+        updateMovingAverage ? resultRunningVariance_ref.data()
+                            : nullptr, // resultRunningVariance
         epsilon,
-        saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
-        saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr);
+        saveMeanAndInvVariance ? resultSaveMean_ref.data() : nullptr,
+        saveMeanAndInvVariance ? resultSaveInvVariance_ref.data() : nullptr);

     if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
     {
@@ -346,21 +340,19 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
         (void)invoker_ptr_ref->Run(argument_ptr_ref.get());

-        y_dev.FromDevice(y.mData.data());
+        y_dev.FromDevice(y.data());

-        pass = pass && ck::utils::check_err(y.mData, y_ref.mData);
+        pass = pass && ck::utils::check_err(y, y_ref);

         if(updateMovingAverage)
         {
             Tensor<AccDataType> resultRunningMean(scaleBiasMeanVarLengths);
             Tensor<AccDataType> resultRunningVariance(scaleBiasMeanVarLengths);

-            resultRunningMean_dev.FromDevice(resultRunningMean.mData.data());
-            resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
+            resultRunningMean_dev.FromDevice(resultRunningMean.data());
+            resultRunningVariance_dev.FromDevice(resultRunningVariance.data());

-            pass =
-                pass && ck::utils::check_err(resultRunningMean.mData, resultRunningMean_ref.mData);
-            pass = pass && ck::utils::check_err(resultRunningVariance.mData,
-                                                resultRunningVariance_ref.mData);
+            pass = pass && ck::utils::check_err(resultRunningMean, resultRunningMean_ref);
+            pass = pass && ck::utils::check_err(resultRunningVariance, resultRunningVariance_ref);
        };

        if(saveMeanAndInvVariance)
@@ -368,12 +360,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
            Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
            Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);

-            resultSaveMean_dev.FromDevice(resultSaveMean.mData.data());
-            resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
+            resultSaveMean_dev.FromDevice(resultSaveMean.data());
+            resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.data());

-            pass = pass && ck::utils::check_err(resultSaveMean.mData, resultSaveMean_ref.mData);
-            pass = pass && ck::utils::check_err(resultSaveInvVariance.mData,
-                                                resultSaveInvVariance_ref.mData);
+            pass = pass && ck::utils::check_err(resultSaveMean, resultSaveMean_ref);
+            pass = pass && ck::utils::check_err(resultSaveInvVariance, resultSaveInvVariance_ref);
        };
    };
......
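The batchnorm hunks swap iterator-pair `std::copy(c.begin(), c.end(), out)` for a range-based `copy(c, out)`. Assuming `ck::ranges::copy` mirrors the shape of C++20 `std::ranges::copy`, it is essentially this thin wrapper:

#include <algorithm>
#include <array>
#include <iterator>
#include <vector>

namespace ranges_sketch {
// range-based copy: forwards to iterator-pair std::copy
template <typename Range, typename OutputIt>
OutputIt copy(const Range& r, OutputIt out)
{
    return std::copy(std::begin(r), std::end(r), out);
}
} // namespace ranges_sketch

int main()
{
    std::vector<int> inOutLengths{1, 16, 16, 8};
    std::array<int, 4> i_inOutLengths{};
    // one line instead of the begin()/end() boilerplate removed above
    ranges_sketch::copy(inOutLengths, i_inOutLengths.begin());
    return i_inOutLengths[3] == 8 ? 0 : 1;
}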
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#include <limits>
+#include <algorithm>
+#include <array>
 #include <iostream>
+#include <limits>
 #include <vector>
-#include <array>
-#include <algorithm>
 #include <getopt.h>

 #include "ck/ck.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp"
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp"

 #include "batchnorm_infer_impl.hpp"
@@ -142,8 +145,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
     Tensor<AccDataType> estimatedMean(scaleBiasMeanVarLengths);
     Tensor<AccDataType> estimatedVariance(scaleBiasMeanVarLengths);

-    auto inOutStrides            = x.mDesc.GetStrides();
-    auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides();
+    auto inOutStrides            = x.GetStrides();
+    auto scaleBiasMeanVarStrides = bnScale.GetStrides();

     std::size_t num_thread = std::thread::hardware_concurrency();
@@ -201,22 +204,21 @@ bool bnorm_infer_nhwc_test(bool do_verification,
     };

     // these buffers are usually provided by the user application
-    DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
-    DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize());
-    DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
-    DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize());
+    DeviceMem x_dev(x.GetMemorySize());
+    DeviceMem y_dev(y.GetMemorySize());
+    DeviceMem bnScale_dev(bnScale.GetMemorySize());
+    DeviceMem bnBias_dev(bnBias.GetMemorySize());

     // mean_dev or resultSaveMean_dev
-    DeviceMem estimatedMean_dev(sizeof(AccDataType) * estimatedMean.mDesc.GetElementSpaceSize());
+    DeviceMem estimatedMean_dev(estimatedMean.GetMemorySize());
     // meansquare_dev or resultSaveInvVariance_dev
-    DeviceMem estimatedVariance_dev(sizeof(AccDataType) *
-                                    estimatedVariance.mDesc.GetElementSpaceSize());
+    DeviceMem estimatedVariance_dev(estimatedVariance.GetMemorySize());

-    x_dev.ToDevice(x.mData.data());
-    bnScale_dev.ToDevice(bnScale.mData.data());
-    bnBias_dev.ToDevice(bnBias.mData.data());
-    estimatedMean_dev.ToDevice(estimatedMean.mData.data());
-    estimatedVariance_dev.ToDevice(estimatedVariance.mData.data());
+    x_dev.ToDevice(x.data());
+    bnScale_dev.ToDevice(bnScale.data());
+    bnBias_dev.ToDevice(bnBias.data());
+    estimatedMean_dev.ToDevice(estimatedMean.data());
+    estimatedVariance_dev.ToDevice(estimatedVariance.data());

     using ck::index_t;
@@ -225,14 +227,12 @@ bool bnorm_infer_nhwc_test(bool do_verification,
     std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths;
     std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides;

-    std::copy(inOutLengths.begin(), inOutLengths.end(), i_inOutLengths.begin());
-    std::copy(inOutStrides.begin(), inOutStrides.end(), i_inOutStrides.begin());
-    std::copy(scaleBiasMeanVarLengths.begin(),
-              scaleBiasMeanVarLengths.end(),
-              i_scaleBiasMeanVarLengths.begin());
-    std::copy(scaleBiasMeanVarStrides.begin(),
-              scaleBiasMeanVarStrides.end(),
-              i_scaleBiasMeanVarStrides.begin());
+    using ck::ranges::copy;
+
+    copy(inOutLengths, i_inOutLengths.begin());
+    copy(inOutStrides, i_inOutStrides.begin());
+    copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin());
+    copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin());

     int result = 0;
@@ -261,19 +261,18 @@ bool bnorm_infer_nhwc_test(bool do_verification,
     {
         auto batchNormInfer_ref = ReferenceBatchNormInferInstance<InOutDataType, AccDataType>{};

-        auto argument_ptr_ref =
-            batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
-                                                   i_inOutStrides,
-                                                   i_inOutStrides,
-                                                   i_scaleBiasMeanVarLengths,
-                                                   i_scaleBiasMeanVarStrides,
-                                                   x.mData.data(),
-                                                   bnScale.mData.data(),
-                                                   bnBias.mData.data(),
-                                                   epsilon,
-                                                   estimatedMean.mData.data(),
-                                                   estimatedVariance.mData.data(),
-                                                   y_ref.mData.data());
+        auto argument_ptr_ref = batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
+                                                                       i_inOutStrides,
+                                                                       i_inOutStrides,
+                                                                       i_scaleBiasMeanVarLengths,
+                                                                       i_scaleBiasMeanVarStrides,
+                                                                       x.data(),
+                                                                       bnScale.data(),
+                                                                       bnBias.data(),
+                                                                       epsilon,
+                                                                       estimatedMean.data(),
+                                                                       estimatedVariance.data(),
+                                                                       y_ref.data());

         if(!batchNormInfer_ref.IsSupportedArgument(argument_ptr_ref.get()))
         {
@@ -287,8 +286,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
         (void)invoker_ptr_ref->Run(argument_ptr_ref.get());

-        y_dev.FromDevice(y.mData.data());
+        y_dev.FromDevice(y.data());

-        pass = pass && ck::utils::check_err(y.mData, y_ref.mData);
+        pass = pass && ck::utils::check_err(y, y_ref);
     };

     return (pass);
......
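Across all these files, `sizeof(T) * x.mDesc.GetElementSpaceSize()` collapses into `x.GetMemorySize()` and `x.mData.data()` into `x.data()`. A minimal sketch of what such Tensor accessors amount to in the dense case (the real CK Tensor also accounts for strided/padded element space):

#include <cstddef>
#include <vector>

// dense-tensor stand-in for the accessors the diff switches to
template <typename T>
struct TensorSketch
{
    std::vector<T> mData;

    // element space of a dense tensor is just the element count
    std::size_t GetElementSpaceSize() const { return mData.size(); }

    // bytes needed by the matching device buffer; folds the repeated
    // sizeof(T) * GetElementSpaceSize() arithmetic into one place
    std::size_t GetMemorySize() const { return sizeof(T) * GetElementSpaceSize(); }

    // replaces x.mData.data() at call sites
    const T* data() const { return mData.data(); }
};

int main()
{
    TensorSketch<float> x{std::vector<float>(256)};
    return x.GetMemorySize() == 256 * sizeof(float) ? 0 : 1;
}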
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <cstdlib>
+#include <iostream>
+#include <numeric>
+#include <initializer_list>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using F16  = ck::half_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -32,17 +32,17 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
     auto& [M, N, K, StrideA, StrideB, StrideC, KBatch] = problem_size;

+    using namespace ck::literals;
+
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
             if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
             {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({stride, 1}));
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
             }
             else
             {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({1, stride}));
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
             }
         };
@@ -50,9 +50,9 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

-    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
-    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
-    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
+    std::cout << "a_m_k: " << a_m_k.GetDesc() << std::endl;
+    std::cout << "b_k_n: " << b_k_n.GetDesc() << std::endl;
+    std::cout << "c_m_n: " << c_m_n_device_result.GetDesc() << std::endl;

     switch(config.init_method)
     {
@@ -70,19 +70,19 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
         b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
     }

-    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
-    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
-    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem a_m_k_device_buf(a_m_k.GetMemorySize());
+    DeviceMem b_k_n_device_buf(b_k_n.GetMemorySize());
+    DeviceMem c_m_n_device_buf(c_m_n_device_result.GetMemorySize());

 #ifdef BUILD_INT4_EXAMPLE
     const Tensor<KernelADataType> a_m_k_converted(a_m_k);
     const Tensor<KernelBDataType> b_k_n_converted(b_k_n);

-    a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
-    b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
+    a_m_k_device_buf.ToDevice(a_m_k_converted.data());
+    b_k_n_device_buf.ToDevice(b_k_n_converted.data());
 #else
-    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-    b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+    a_m_k_device_buf.ToDevice(a_m_k.data());
+    b_k_n_device_buf.ToDevice(b_k_n.data());
 #endif

     c_m_n_device_buf.SetZero();
@@ -93,15 +93,9 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
     // do GEMM
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(
-#ifdef BUILD_INT4_EXAMPLE
-        static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
-        static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
-#else
-        static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
-        static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
-#endif
-        static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+    auto argument = gemm.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
+                                      b_k_n_device_buf.GetDeviceBuffer(),
+                                      c_m_n_device_buf.GetDeviceBuffer(),
         M,
         N,
         K,
@@ -125,7 +119,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
     if(config.do_verification)
     {
-        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
+        c_m_n_device_buf.FromDevice(c_m_n_device_result.data());

         using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                 BDataType,
                                                                                 CDataType,
@@ -146,15 +140,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
         if(std::is_same<CDataType, ck::half_t>::value)
         {
-            pass &= ck::utils::check_err(c_m_n_device_result.mData,
-                                         c_m_n_host_result.mData,
-                                         "fp16 incorrect result",
-                                         3e-3,
-                                         1e-3);
+            pass &= ck::utils::check_err(
+                c_m_n_device_result, c_m_n_host_result, "fp16 incorrect result", 3e-3, 1e-3);
         }
         else
         {
-            pass &= ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
+            pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
         }
     }
......
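The verification hunks now hand whole tensors to `ck::utils::check_err` instead of their `.mData` vectors, along with an optional message and tolerances. A hedged sketch of a range-based checker with that interface shape (the real one supports many data types, error metrics, and size checking):

#include <cmath>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// range-based checker: message plus relative/absolute tolerances,
// compared element-wise (assumed semantics, not the CK implementation)
template <typename Range>
bool check_err_sketch(const Range& out,
                      const Range& ref,
                      const std::string& msg = "Error: incorrect results!",
                      double rtol            = 1e-5,
                      double atol            = 3e-6)
{
    auto o = std::begin(out);
    auto r = std::begin(ref);
    for(; o != std::end(out) && r != std::end(ref); ++o, ++r)
    {
        const double err = std::abs(static_cast<double>(*o) - static_cast<double>(*r));
        if(err > atol + rtol * std::abs(static_cast<double>(*r)))
        {
            std::cerr << msg << std::endl;
            return false; // a real checker would also report index and magnitude
        }
    }
    return true;
}

int main()
{
    std::vector<float> device{1.0f, 2.0f}, host{1.0f, 2.00001f};
    return check_err_sketch(device, host, "fp16 incorrect result", 3e-3, 1e-3) ? 0 : 1;
}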
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-#include "ck/library/utility/literals.hpp"
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using BF16 = ck::bhalf_t;
-using F32  = float;
-
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+#include "common.hpp"

 using ADataType = BF16;
 using BDataType = BF16;
......
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-#include "ck/library/utility/literals.hpp"
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using F16 = ck::half_t;
-using F32 = float;
-
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+#include "common.hpp"

 using ADataType = F16;
 using BDataType = F16;
......
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-#include "ck/library/utility/literals.hpp"
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using F16 = ck::half_t;
-using F32 = float;
-
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+#include "common.hpp"

 using ADataType = F32;
 using BDataType = F32;
......
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-#include "ck/library/utility/literals.hpp"
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+#include "common.hpp"

 using ADataType = ck::int4_t;
 using BDataType = ck::int4_t;
......
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
-#include "ck/library/utility/literals.hpp"
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+#include "common.hpp"

 using ADataType = int8_t;
 using BDataType = int8_t;
......
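The five split-K GEMM examples above shed their duplicated include blocks and aliases in favor of a shared `common.hpp`, leaving each translation unit as little more than a type configuration. A compilable miniature of that consolidation pattern (all names illustrative; the real shared header is the one shown earlier in this diff):

#include <cstdio>

// --- what would live in the shared common.hpp ---
using F32 = float;

template <typename ADataType, typename BDataType>
int run_gemm_example()
{
    // the shared driver (device setup, verification, timing) lives here once
    std::printf("element widths: %zu and %zu bytes\n", sizeof(ADataType), sizeof(BDataType));
    return 0;
}

// --- what remains in each per-precision .cpp ---
using ADataType = F32;
using BDataType = F32;

int main() { return run_gemm_example<ADataType, BDataType>(); }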
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

+#include <cstdlib>
+#include <ctime>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>
 #include <getopt.h>
-#include <ctime>

 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/device_sparse_embedding3_forward_layernorm.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_common_util.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp"

 // using EmbType = float;
 // using IndexType = int64_t;
@@ -86,12 +87,10 @@ int main()
     constexpr auto index_length   = 2048;
     constexpr AccDataType epsilon = 1e-4;

-    auto f_host_tensor_desc_1d = [](std::size_t len_) {
-        return HostTensorDescriptor(std::vector<std::size_t>({len_}));
-    };
+    auto f_host_tensor_desc_1d = [](std::size_t len_) { return HostTensorDescriptor({len_}); };

     auto f_host_tensor_desc_2d = [](std::size_t rows_, std::size_t cols_) {
-        return HostTensorDescriptor(std::vector<std::size_t>({rows_, cols_}));
+        return HostTensorDescriptor({rows_, cols_});
     };

     using ReferenceInstance =
@@ -129,29 +128,29 @@ int main()
     gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
     beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});

-    DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize());
-    DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize());
-    DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize());
+    DeviceMem emb_a_dev(emb_a.GetMemorySize());
+    DeviceMem emb_b_dev(emb_b.GetMemorySize());
+    DeviceMem emb_c_dev(emb_c.GetMemorySize());

-    DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize());
-    DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize());
-    DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize());
+    DeviceMem index_a_dev(index_a.GetMemorySize());
+    DeviceMem index_b_dev(index_b.GetMemorySize());
+    DeviceMem index_c_dev(index_c.GetMemorySize());

-    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
-    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
-    DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize());
+    DeviceMem gamma_dev(gamma.GetMemorySize());
+    DeviceMem beta_dev(beta.GetMemorySize());
+    DeviceMem out_dev(out.GetMemorySize());

-    emb_a_dev.ToDevice(emb_a.mData.data());
-    emb_b_dev.ToDevice(emb_b.mData.data());
-    emb_c_dev.ToDevice(emb_c.mData.data());
-    index_a_dev.ToDevice(index_a.mData.data());
-    index_b_dev.ToDevice(index_b.mData.data());
-    index_c_dev.ToDevice(index_c.mData.data());
-    gamma_dev.ToDevice(gamma.mData.data());
-    beta_dev.ToDevice(beta.mData.data());
+    emb_a_dev.ToDevice(emb_a.data());
+    emb_b_dev.ToDevice(emb_b.data());
+    emb_c_dev.ToDevice(emb_c.data());
+    index_a_dev.ToDevice(index_a.data());
+    index_b_dev.ToDevice(index_b.data());
+    index_c_dev.ToDevice(index_c.data());
+    gamma_dev.ToDevice(gamma.data());
+    beta_dev.ToDevice(beta.data());

     auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
     auto argument_ptr = device_instance.MakeArgumentPointer(out_dev.GetDeviceBuffer(),
@@ -202,9 +201,8 @@ int main()
         auto ref_invoker = ref.MakeInvoker();
         ref_invoker.Run(ref_argument);
-        out_dev.FromDevice(out_from_dev.mData.data());
-        pass &= ck::utils::check_err(
-            out_from_dev.mData, out.mData, "Error: Incorrect results", 1e-3, 1e-3);
+        out_dev.FromDevice(out_from_dev.data());
+        pass &= ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3);
     }

     double total_read = current_dim * index_length * 3 * sizeof(EmbType) +
......
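Every example follows the same host-device staging flow: size the DeviceMem from the host tensor, ToDevice before the kernel, FromDevice before checking. A host-only stand-in for that flow (the real DeviceMem wraps HIP allocations and copies; this sketch only simulates the round trip):

#include <cstddef>
#include <cstring>
#include <vector>

// host-only stand-in for a device buffer; `bytes` simulates device memory
struct DeviceMemSketch
{
    std::vector<unsigned char> bytes;

    explicit DeviceMemSketch(std::size_t byte_count) : bytes(byte_count) {}

    void ToDevice(const void* src) { std::memcpy(bytes.data(), src, bytes.size()); }
    void FromDevice(void* dst) const { std::memcpy(dst, bytes.data(), bytes.size()); }
};

int main()
{
    std::vector<float> x{1.f, 2.f, 3.f};

    DeviceMemSketch x_dev(x.size() * sizeof(float)); // as DeviceMem x_dev(x.GetMemorySize())
    x_dev.ToDevice(x.data());                        // upload, as x_dev.ToDevice(x.data())

    std::vector<float> y(x.size());
    x_dev.FromDevice(y.data());                      // download, as y_dev.FromDevice(y.data())

    return y[2] == 3.f ? 0 : 1;
}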
@@ -5,21 +5,23 @@
 Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o]
 */

+#include <cstdlib>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>

 #include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"

+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/utility/array.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/utility/literals.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -308,21 +310,21 @@ int main(int argc, char* argv[])
     BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
     BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;

+    using namespace ck::literals;
+
     auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                        std::size_t row,
                                        std::size_t col,
                                        std::size_t stride,
                                        std::size_t batch_stride,
                                        auto layout) {
-        if(std::is_same<decltype(layout), Row>::value)
+        if constexpr(std::is_same_v<decltype(layout), Row>)
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz});
         }
         else
         {
-            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+            return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride});
        }
    };
@@ -344,14 +346,14 @@ int main(int argc, char* argv[])
     Tensor<E1DataType> e1_g_m_o_device_result(
         f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));

-    std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
-    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
-    std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc
-              << " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
-    std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc
-              << " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
-    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
-    std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
+    std::cout << "a0_g_m_k: " << a0_g_m_k.GetDesc() << std::endl;
+    std::cout << "b0_g_k_n: " << b0_g_k_n.GetDesc() << std::endl;
+    std::cout << "d00_g_m_n: " << d00_g_m_n.GetDesc()
+              << " size: " << d00_g_m_n.GetElementSpaceSize() << std::endl;
+    std::cout << "d01_g_m_n: " << d01_g_m_n.GetDesc()
+              << " size: " << d01_g_m_n.GetElementSpaceSize() << std::endl;
+    std::cout << "b1_g_n_o: " << b1_g_n_o.GetDesc() << std::endl;
+    std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.GetDesc() << std::endl;

     switch(init_method)
     {
@@ -381,21 +383,20 @@ int main(int argc, char* argv[])
         d1_g_m_o.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
     }

-    DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
-    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
-    DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize());
-    DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize());
-    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
-    DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
-                                  e1_g_m_o_device_result.mDesc.GetElementSize());
-    DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
+    DeviceMem a0_g_m_k_device_buf(a0_g_m_k.GetMemorySize());
+    DeviceMem b0_g_k_n_device_buf(b0_g_k_n.GetMemorySize());
+    DeviceMem d00_g_m_n_device_buf(d00_g_m_n.GetMemorySize());
+    DeviceMem d01_g_m_n_device_buf(d01_g_m_n.GetMemorySize());
+    DeviceMem b1_g_n_o_device_buf(b1_g_n_o.GetMemorySize());
+    DeviceMem e1_g_m_o_device_buf(e1_g_m_o_device_result.GetMemorySize());
+    DeviceMem d1_g_m_o_device_buf(d1_g_m_o.GetMemorySize());

-    a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
-    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
-    d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data());
-    d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data());
-    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
-    d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
+    a0_g_m_k_device_buf.ToDevice(a0_g_m_k.data());
+    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.data());
+    d00_g_m_n_device_buf.ToDevice(d00_g_m_n.data());
+    d01_g_m_n_device_buf.ToDevice(d01_g_m_n.data());
+    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.data());
+    d1_g_m_o_device_buf.ToDevice(d1_g_m_o.data());

     auto a0_element_op   = A0ElementOp{};
     auto b0_element_op   = B0ElementOp{};
@@ -403,17 +404,18 @@ int main(int argc, char* argv[])
     auto b1_element_op   = B1ElementOp{};
     auto cde1_element_op = CDE1ElementOp{};

+    using ck::utils::to_array;
+
     // do GEMM
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
-    auto argument =
-        gemm.MakeArgument(static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
-                          static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
-                          std::array<const void*, 2>{d00_g_m_n_device_buf.GetDeviceBuffer(),
-                                                     d01_g_m_n_device_buf.GetDeviceBuffer()},
-                          static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
-                          std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
-                          static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
+    auto argument = gemm.MakeArgument(
+        a0_g_m_k_device_buf.GetDeviceBuffer(),
+        b0_g_k_n_device_buf.GetDeviceBuffer(),
+        to_array({d00_g_m_n_device_buf.GetDeviceBuffer(), d01_g_m_n_device_buf.GetDeviceBuffer()}),
+        b1_g_n_o_device_buf.GetDeviceBuffer(),
+        to_array({d1_g_m_o_device_buf.GetDeviceBuffer()}),
+        e1_g_m_o_device_buf.GetDeviceBuffer(),
         M,
         N,
         K,
@@ -421,15 +423,15 @@ int main(int argc, char* argv[])
         BatchCount,
         StrideA0,
         StrideB0,
-        std::array<ck::index_t, 2>{StrideD00, StrideD01},
+        to_array({StrideD00, StrideD01}),
         StrideB1,
-        std::array<ck::index_t, 1>{StrideD1},
+        to_array({StrideD1}),
         StrideE1,
         BatchStrideA0,
         BatchStrideB0,
-        std::array<ck::index_t, 2>{BatchStrideD00, BatchStrideD01},
+        to_array({BatchStrideD00, BatchStrideD01}),
         BatchStrideB1,
-        std::array<ck::index_t, 1>{BatchStrideD1},
+        to_array({BatchStrideD1}),
         BatchStrideE1,
         a0_element_op,
         b0_element_op,
@@ -460,7 +462,7 @@ int main(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;

-    e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
+    e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.data());

     if(do_verification)
     {
@@ -511,8 +513,7 @@ int main(int argc, char* argv[])
             cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
         });

-        return ck::utils::check_err(e1_g_m_o_device_result.mData, e1_g_m_o_host_result.mData) ? 0
-                                                                                              : 1;
+        return ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result) ? 0 : 1;
     }

     return 0;
......
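The layout-dispatch lambdas in this file and the first one also move from a runtime `if` on `std::is_same<...>::value` to `if constexpr` with `std::is_same_v`, so only the branch matching the layout type is instantiated. A minimal stand-alone illustration with placeholder layout tags:

#include <cstddef>
#include <type_traits>
#include <vector>

// placeholder layout tags standing in for ck::tensor_layout::gemm::*
struct RowMajor
{
};
struct ColumnMajor
{
};

template <typename Layout>
std::vector<std::size_t> make_strides(std::size_t rows, std::size_t cols, Layout)
{
    if constexpr(std::is_same_v<Layout, RowMajor>)
    {
        return {cols, 1}; // contiguous along a row
    }
    else
    {
        return {1, rows}; // contiguous along a column
    }
}

int main()
{
    auto s = make_strides(3, 4, RowMajor{});
    return s[0] == 4 ? 0 : 1;
}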
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
+#include <cstdlib>
+#include <initializer_list>
 #include <iostream>
 #include <numeric>
-#include <initializer_list>
-#include <cstdlib>
 
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
+#include "ck/library/utility/algorithm.hpp"
+#include "ck/library/utility/array.hpp"
 #include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/convolution_parameter.hpp"
-#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
 
 void print_helper_msg()
 {
@@ -53,10 +55,10 @@ int run_conv_bwd_data_bias_relu(bool do_verification,
     Tensor<InDataType> in_host(in_g_n_c_wis_desc);
     Tensor<InDataType> in_device(in_g_n_c_wis_desc);
 
-    std::cout << "out: " << out.mDesc << std::endl;
-    std::cout << "wei: " << wei.mDesc << std::endl;
-    std::cout << "bias: " << bias.mDesc << std::endl;
-    std::cout << "in: " << in_host.mDesc << std::endl;
+    std::cout << "out: " << out.GetDesc() << std::endl;
+    std::cout << "wei: " << wei.GetDesc() << std::endl;
+    std::cout << "bias: " << bias.GetDesc() << std::endl;
+    std::cout << "in: " << in_host.GetDesc() << std::endl;
 
     switch(init_method)
     {
@@ -72,66 +74,47 @@ int run_conv_bwd_data_bias_relu(bool do_verification,
         bias.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{0.0, 1.0});
     }
 
-    DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
-    DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
-    DeviceMem bias_device_buf(sizeof(BiasDataType) * bias.mDesc.GetElementSpaceSize());
-    DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize());
+    DeviceMem out_device_buf(out.GetMemorySize());
+    DeviceMem wei_device_buf(wei.GetMemorySize());
+    DeviceMem bias_device_buf(bias.GetMemorySize());
+    DeviceMem in_device_buf(in_device.GetMemorySize());
 
-    out_device_buf.ToDevice(out.mData.data());
-    wei_device_buf.ToDevice(wei.mData.data());
-    bias_device_buf.ToDevice(bias.mData.data());
+    out_device_buf.ToDevice(out.data());
+    wei_device_buf.ToDevice(wei.data());
+    bias_device_buf.ToDevice(bias.data());
 
     // reset input to zero
     in_device_buf.SetZero();
 
-    std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_lengths{};
-    std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_strides{};
-    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
-    std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
     std::array<ck::index_t, NDimSpatial + 3> d0_g_n_c_wis_lengths{};
     std::array<ck::index_t, NDimSpatial + 3> d0_g_n_c_wis_strides{};
-    std::array<ck::index_t, NDimSpatial + 3> e_g_n_c_wis_lengths{};
-    std::array<ck::index_t, NDimSpatial + 3> e_g_n_c_wis_strides{};
-    std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
-    std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
-    std::array<ck::index_t, NDimSpatial> input_left_pads{};
-    std::array<ck::index_t, NDimSpatial> input_right_pads{};
 
-    auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
+    auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
 
-    copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths);
-    copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides);
-    copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
-    copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
     copy(bias_g_n_c_wis_desc.GetLengths(), d0_g_n_c_wis_lengths);
     copy(bias_g_n_c_wis_desc.GetStrides(), d0_g_n_c_wis_strides);
-    copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths);
-    copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides);
-    copy(conv_param.conv_filter_strides_, conv_filter_strides);
-    copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
-    copy(conv_param.input_left_pads_, input_left_pads);
-    copy(conv_param.input_right_pads_, input_right_pads);
+
+    using ck::utils::to_array;
 
     // do conv
     auto conv     = DeviceInstance{};
     auto invoker  = conv.MakeInvoker();
-    auto argument = conv.MakeArgument(
-        out_device_buf.GetDeviceBuffer(),
-        wei_device_buf.GetDeviceBuffer(),
-        std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
-        in_device_buf.GetDeviceBuffer(),
-        a_g_n_k_wos_lengths,
-        a_g_n_k_wos_strides,
-        b_g_k_c_xs_lengths,
-        b_g_k_c_xs_strides,
-        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{d0_g_n_c_wis_lengths},
-        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{d0_g_n_c_wis_strides},
-        e_g_n_c_wis_lengths,
-        e_g_n_c_wis_strides,
-        conv_filter_strides,
-        conv_filter_dilations,
-        input_left_pads,
-        input_right_pads,
-        out_element_op,
-        wei_element_op,
-        in_element_op);
+    auto argument = conv.MakeArgument(out_device_buf.GetDeviceBuffer(),
+                                      wei_device_buf.GetDeviceBuffer(),
+                                      to_array({bias_device_buf.GetDeviceBuffer()}),
+                                      in_device_buf.GetDeviceBuffer(),
+                                      to_array(out_g_n_k_wos_desc.GetLengths()),
+                                      to_array(out_g_n_k_wos_desc.GetStrides()),
+                                      to_array(wei_g_k_c_xs_desc.GetLengths()),
+                                      to_array(wei_g_k_c_xs_desc.GetStrides()),
+                                      to_array({d0_g_n_c_wis_lengths}),
+                                      to_array({d0_g_n_c_wis_strides}),
+                                      to_array(in_g_n_c_wis_desc.GetLengths()),
+                                      to_array(in_g_n_c_wis_desc.GetStrides()),
+                                      to_array(conv_param.conv_filter_strides_),
+                                      to_array(conv_param.conv_filter_dilations_),
+                                      to_array(conv_param.input_left_pads_),
+                                      to_array(conv_param.input_right_pads_),
+                                      out_element_op,
+                                      wei_element_op,
+                                      in_element_op);
@@ -190,9 +173,9 @@ int run_conv_bwd_data_bias_relu(bool do_verification,
         in_host.ForEach(
             [&](auto&, auto idx) { in_element_op(in_host(idx), c_host(idx), bias(idx)); });
 
-        in_device_buf.FromDevice(in_device.mData.data());
+        in_device_buf.FromDevice(in_device.data());
 
-        return ck::utils::check_err(in_device.mData, in_host.mData) ? 0 : 1;
+        return ck::utils::check_err(in_device, in_host) ? 0 : 1;
     }
 
     return 0;
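
The rewritten copy lambda above forwards to ck::ranges::copy, a range-in/iterator-out convenience over std::copy; its original definition appears as removed code further down in this commit. A self-contained sketch of the same shape (sketch::copy is a local stand-in, not CK API), assuming the relocated utility keeps that behavior:

#include <algorithm>
#include <array>
#include <iterator>
#include <vector>

namespace sketch {

// Range-based copy in the style of the ranges::copy shown as removed code
// later in this commit; one call replaces the begin()/end() pair.
template <typename InputRange, typename OutputIterator>
OutputIterator copy(InputRange&& range, OutputIterator iter)
{
    return std::copy(std::begin(range), std::end(range), iter);
}

} // namespace sketch

int main()
{
    const std::vector<int> lengths{3, 4, 5};
    std::array<int, 3> staged{};

    sketch::copy(lengths, staged.begin());
    return staged[2] == 5 ? 0 : 1;
}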
@@ -19,11 +19,14 @@
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
 #include "ck/utility/type.hpp"
 
+#include "ck/library/utility/algorithm.hpp"
+#include "ck/library/utility/array.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/fill.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/ranges.hpp"
 
 using F16 = ck::half_t;
 using F32 = float;
@@ -150,37 +153,6 @@ struct is_random_access_iterator<Iterator,
 template <typename Iterator>
 inline constexpr bool is_random_access_iterator_v = is_random_access_iterator<Iterator>::value;
 
-template <typename T, typename = void>
-struct is_range : std::false_type
-{
-};
-
-template <typename T>
-struct is_range<T,
-                std::void_t<decltype(begin(std::declval<T>())),
-                            decltype(end(std::declval<T>())),
-                            decltype(begin(std::declval<T>()) != end(std::declval<T>()))>>
-    : std::bool_constant<is_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<T>()))>>>
-{
-};
-
-template <typename T>
-inline constexpr bool is_range_v = is_range<T>::value;
-
-template <typename Range, typename = void>
-struct is_sized_range : std::false_type
-{
-};
-
-template <typename Range>
-struct is_sized_range<Range, std::void_t<decltype(size(std::declval<Range>()))>>
-    : std::bool_constant<is_range_v<Range>>
-{
-};
-
-template <typename Range>
-inline constexpr bool is_sized_range_v = is_sized_range<Range>::value;
-
 template <typename Range, typename = void>
 struct is_bidirectional_range : std::false_type
 {
@@ -189,7 +161,7 @@ struct is_bidirectional_range : std::false_type
 template <typename Range>
 struct is_bidirectional_range<Range, std::void_t<>>
     : std::bool_constant<
-          is_range_v<Range> &&
+          ck::ranges::is_range_v<Range> &&
          is_bidirectional_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
 {
 };
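
These range traits are all instances of the C++17 detection idiom: the std::void_t partial specialization participates only when every probed expression is well-formed for the type under test. A condensed, self-contained illustration (has_begin is a toy trait, not part of CK):

#include <type_traits>
#include <utility>
#include <vector>

// Primary template: selected when the probe below is ill-formed.
template <typename T, typename = void>
struct has_begin : std::false_type
{
};

// Partial specialization: viable only when T has a member begin().
template <typename T>
struct has_begin<T, std::void_t<decltype(std::declval<T>().begin())>> : std::true_type
{
};

static_assert(has_begin<std::vector<int>>::value, "vector is range-like");
static_assert(!has_begin<int>::value, "int has no begin()");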
@@ -205,7 +177,7 @@ struct is_random_access_range : std::false_type
 template <typename Range>
 struct is_random_access_range<Range, std::void_t<>>
     : std::bool_constant<
-          is_range_v<Range> &&
+          ck::ranges::is_range_v<Range> &&
          is_random_access_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
 {
 };
@@ -213,53 +185,8 @@ struct is_random_access_range<Range, std::void_t<>>
 template <typename Range>
 inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::value;
 
-template <typename Range>
-class to_array_proxy
-{
-    static_assert(is_range_v<Range>);
-
-    public:
-    explicit to_array_proxy(const Range& source) noexcept : source_(source) {}
-
-    template <typename T, std::size_t Size>
-    operator std::array<T, Size>() const
-    {
-        std::array<T, Size> destination;
-
-        std::copy_n(std::begin(source_),
-                    std::min<std::size_t>(Size, std::size(source_)),
-                    std::begin(destination));
-
-        return destination;
-    }
-
-    private:
-    const Range& source_;
-};
-
 } // namespace detail
 
-template <typename Range>
-inline auto to_array(Range& range) noexcept
-    -> std::enable_if_t<detail::is_range_v<Range>,
-                        detail::to_array_proxy<ck::remove_cvref_t<Range>>>
-{
-    return detail::to_array_proxy<ck::remove_cvref_t<Range>>{range};
-}
-
-namespace ranges {
-
-template <typename InputRange, typename OutputIterator>
-inline auto copy(InputRange&& range, OutputIterator iter)
-    -> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
-                          std::end(std::forward<InputRange>(range)),
-                          iter))
-{
-    return std::copy(std::begin(std::forward<InputRange>(range)),
-                     std::end(std::forward<InputRange>(range)),
-                     iter);
-}
-
-} // namespace ranges
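
With the local helpers deleted, call-sites rely on the relocated ck::utils::to_array and ck::ranges::copy (see the array.hpp and ranges.hpp includes added at the top of this file). A usage sketch, assuming the moved code preserves the proxy behavior removed here:

#include <array>
#include <cstddef>
#include <vector>

#include "ck/library/utility/array.hpp"  // assumed new home of ck::utils::to_array
#include "ck/library/utility/ranges.hpp" // assumed new home of ck::ranges::copy

int main()
{
    const std::vector<std::size_t> shape{2, 3, 4};

    // The proxy converts to the std::array type named on the left-hand side,
    // so the element count is not repeated at the call-site.
    const std::array<std::size_t, 3> shape_array = ck::utils::to_array(shape);

    std::array<std::size_t, 3> copied{};
    ck::ranges::copy(shape_array, copied.begin());

    return copied == shape_array ? 0 : 1;
}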
 template <typename Axes>
 inline auto is_valid_axes(const Axes& axes)
     -> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
@@ -281,7 +208,8 @@ inline auto is_valid_axes(const Axes& axes)
 }
 
 template <typename Shape>
-inline auto is_valid_shape(const Shape& shape) -> std::enable_if_t<detail::is_range_v<Shape>, bool>
+inline auto is_valid_shape(const Shape& shape)
+    -> std::enable_if_t<ck::ranges::is_range_v<Shape>, bool>
 {
     static_assert(std::is_unsigned_v<ck::remove_cvref_t<decltype(*std::begin(shape))>>);
@@ -291,8 +219,8 @@ inline auto is_valid_shape(const Shape& shape) -> std::enable_if_t<detail::is_ra
 }
 
 template <typename Shape, typename Indices>
-inline auto is_valid_indices(const Shape& shape, const Indices& indices)
-    -> std::enable_if_t<detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Indices>, bool>
+inline auto is_valid_indices(const Shape& shape, const Indices& indices) -> std::
+    enable_if_t<ck::ranges::is_sized_range_v<Shape> && ck::ranges::is_sized_range_v<Indices>, bool>
 {
     static_assert(std::is_unsigned_v<ck::remove_cvref_t<decltype(*std::begin(indices))>>);
@@ -348,9 +276,9 @@ auto extend_shape(const Problem::Shape& shape, std::size_t new_dim)
 {
     detail::enlarge_array_size_t<Problem::Shape, 1> extended_shape;
 
-    using std::begin, std::end;
-    std::copy(begin(shape), end(shape), begin(extended_shape));
+    using std::begin;
+    ck::ranges::copy(shape, begin(extended_shape));
 
     extended_shape.back() = new_dim;
     return extended_shape;
@@ -360,9 +288,9 @@ auto extend_axes(const Problem::Axes& axes)
 {
     detail::enlarge_array_size_t<Problem::Axes, 1> extended_axes;
 
-    using std::begin, std::end;
-    std::copy(begin(axes), end(axes), begin(extended_axes));
+    using std::begin;
+    ck::ranges::copy(axes, begin(extended_axes));
 
     extended_axes.back() = detail::get_array_size_v<Problem::Axes>;
     return extended_axes;
@@ -370,8 +298,8 @@ auto extend_axes(const Problem::Axes& axes)
 template <typename Shape, typename Indices>
 auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t<
-    detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
-    detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
+    detail::is_bidirectional_range_v<Shape> && ck::ranges::is_sized_range_v<Shape> &&
+    detail::is_bidirectional_range_v<Indices> && ck::ranges::is_sized_range_v<Indices>,
     bool>
 {
     using std::size;
@@ -396,14 +324,15 @@ auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t<
 template <typename Src, typename Axes, typename Functor, typename Dest>
 auto host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<Dest>& dest)
-    -> std::enable_if_t<detail::is_random_access_range_v<Axes> && detail::is_sized_range_v<Axes> &&
+    -> std::enable_if_t<detail::is_random_access_range_v<Axes> &&
+                            ck::ranges::is_sized_range_v<Axes> &&
                             std::is_invocable_v<Functor,
                                                 std::add_lvalue_reference_t<Dest>,
                                                 std::add_lvalue_reference_t<Src>>,
                         bool>
 {
-    const auto& shape            = src.mDesc.GetLengths();
-    const auto& transposed_shape = dest.mDesc.GetLengths();
+    const auto& shape            = src.GetLengths();
+    const auto& transposed_shape = dest.GetLengths();
 
     if(!(is_valid_shape(shape) && is_valid_shape(transposed_shape)))
     {
         return false;
@@ -415,8 +344,8 @@ auto host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Ten
         return false;
     }
 
-    static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
-                  detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
+    static_assert(ck::ranges::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
+                  ck::ranges::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
 
     if(size(shape) != size(transposed_shape))
     {
@@ -16,14 +16,16 @@ bool run_permute_bundle(const Problem& problem)
     // initialize tensor by assigning DataType values
     ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(input_bundle_tensor.AsSpan<DataType>());
 
-    DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes());
-    DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes());
+    DeviceMem input_device_buf(input_bundle_tensor.GetMemorySize());
+    DeviceMem output_device_buf(output_bundle_tensor.GetMemorySize());
 
     using std::data;
     input_device_buf.ToDevice(data(input_bundle_tensor));
 
     static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
 
+    using ck::utils::to_array;
+
     auto permute  = DevicePermuteInstance{};
     auto argument = permute.MakeArgument(to_array(input_bundle_shape),
                                          to_array(input_bundle_tensor.GetStrides()),
@@ -57,7 +59,7 @@ bool run_permute_bundle(const Problem& problem)
     using std::begin;
 
     Tensor<DataType> input_tensor(input_shape);
-    ranges::copy(input_bundle_tensor.AsSpan<const DataType>(), begin(input_tensor));
+    ck::ranges::copy(input_bundle_tensor.AsSpan<const DataType>(), begin(input_tensor));
 
     Tensor<DataType> output_tensor(transpose(input_shape, input_axes));
     if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor))
@@ -15,14 +15,16 @@ bool run_permute_element(const Problem& problem)
     ck::utils::FillUniformDistribution<InDataType>{-1.f, 1.f}(input_tensor);
 
-    DeviceMem input_device_buf(input_tensor.GetElementSpaceSizeInBytes());
-    DeviceMem output_device_buf(output_tensor.GetElementSpaceSizeInBytes());
+    DeviceMem input_device_buf(input_tensor.GetMemorySize());
+    DeviceMem output_device_buf(output_tensor.GetMemorySize());
 
     using std::data;
     input_device_buf.ToDevice(data(input_tensor));
 
     static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
 
+    using ck::utils::to_array;
+
     auto permute  = DevicePermuteInstance{};
     auto argument = permute.MakeArgument(to_array(input_shape),
                                          to_array(input_tensor.GetStrides()),
@@ -7,17 +7,19 @@
 #include <type_traits>
 
 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/convolution_parameter.hpp"
-#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/numeric.hpp"
 
 using In0DataType  = ck::bhalf_t;
 using Wei0DataType = ck::bhalf_t;
@@ -7,17 +7,19 @@
 #include <type_traits>
 
 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/convolution_parameter.hpp"
-#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/numeric.hpp"
 
 using In0DataType  = ck::half_t;
 using Wei0DataType = ck::half_t;
@@ -7,17 +7,19 @@
 #include <type_traits>
 
 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
+#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
+#include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/utility/convolution_parameter.hpp"
-#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
+#include "ck/library/utility/numeric.hpp"
 
 using In0DataType  = float;
 using Wei0DataType = float;