Commit f9c478e2 authored by ltqin

Merge branch 'develop' into bmatrix_skip_lds

parents 7d85d04a 91d8b7d6
@@ -21,9 +21,9 @@ struct TupleElement
 {
     __host__ __device__ constexpr TupleElement() = default;

-    template <typename T,
-              typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
-                                 bool>::type = false>
+    template <
+        typename T,
+        typename enable_if<!is_same<remove_cvref_t<T>, TupleElement>::value, bool>::type = false>
     __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
     {
     }
@@ -60,7 +60,7 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
     template <typename Y,
               typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
-                                     !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
+                                     !is_same<remove_cvref_t<Y>, TupleImpl>::value,
                                  bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Y&& y)
         : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
@@ -101,8 +101,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     __host__ __device__ constexpr Tuple() = default;

     template <typename Y,
-              typename enable_if<sizeof...(Xs) == 1 &&
-                                     !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
+              typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
                                  bool>::type = false>
     __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
     {
...
@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
 template <typename T>
 using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;

+template <typename T>
+using remove_pointer_t = typename std::remove_pointer<T>::type;
+
 template <typename T>
 inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
...
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "stream_config.hpp"
#include "config.hpp"
#include "device_base.hpp"
struct DeviceConvFwdPtr_t
{
using BaseArgument = ck::tensor_operation::device::BaseArgument;
using BaseInvoker = ck::tensor_operation::device::BaseInvoker;
struct DeviceConvFwdPtrImpl;
std::unique_ptr<DeviceConvFwdPtrImpl> pImpl;
DeviceConvFwdPtr_t();
~DeviceConvFwdPtr_t();
DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&);
DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&);
DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete;
DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete;
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* in_ptr,
void* wei_ptr,
void* out_ptr,
size_t N,
size_t K,
size_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
        const; // in, wei and out element ops are ignored for now; even if we changed them, they
               // can't be linked
std::unique_ptr<BaseInvoker>
MakeInvokerPointer() const; // requires including BaseInvoker headers
std::string GetTypeString();
bool IsSupportedArgument(const BaseArgument* arg_ptr);
};
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances);
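A minimal usage sketch of this type-erased wrapper (not part of the commit): the instance-population functions above fill a vector of DeviceConvFwdPtr_t, and each instance exposes MakeArgumentPointer, IsSupportedArgument and MakeInvokerPointer. Device pointers, tensor extents and the invoker's Run interface below are placeholders/assumptions for illustration.

// Hypothetical example, assuming an invoker->Run(arg_ptr, ...) interface on BaseInvoker:
//   std::vector<DeviceConvFwdPtr_t> conv_instances;
//   add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(conv_instances);
//   for(auto& op : conv_instances)
//   {
//       auto arg = op.MakeArgumentPointer(in_dev, wei_dev, out_dev,
//                                         N, K, C,
//                                         {Hi, Wi}, {Y, X}, {Ho, Wo},
//                                         {1, 1}, {1, 1}, {1, 1}, {1, 1});
//       if(op.IsSupportedArgument(arg.get()))
//       {
//           auto invoker = op.MakeInvokerPointer();
//           // invoker->Run(arg.get()); // assumed signature, see BaseInvoker
//           break;
//       }
//   }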
-#ifndef DEVICE_HPP
-#define DEVICE_HPP
+#pragma once

 #include <memory>
 #include <functional>
 #include <thread>
 #include <chrono>
-#include "hip/hip_runtime.h"
-#include "hip/hip_fp16.h"
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include "stream_config.hpp"
+#include "ck/options.hpp"
template <typename T>
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
{
for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x)
{
p[i] = x;
}
}
inline void hip_check_error(hipError_t x)
{
if(x != hipSuccess)
{
std::ostringstream ss;
ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__
<< "in function: " << __func__;
throw std::runtime_error(ss.str());
}
}
 struct DeviceMem
 {
@@ -17,6 +39,16 @@ struct DeviceMem
     void ToDevice(const void* p);
     void FromDevice(void* p);
     void SetZero();
template <typename T>
void SetValue(T x)
{
if(mMemSize % sizeof(T) != 0)
{
throw std::runtime_error("wrong! not entire DeviceMem will be set");
}
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
}
     ~DeviceMem();

     void* mpDeviceBuf;
@@ -36,49 +68,56 @@ struct KernelTimer
     std::unique_ptr<KernelTimerImpl> impl;
 };

 using device_stream_t = hipStream_t;

 template <typename... Args, typename F>
-void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
-{
-    hipStream_t stream_id = nullptr;
-
-    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
-}
-
-template <typename... Args, typename F>
-float launch_and_time_kernel(
-    F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
+float launch_and_time_kernel(const StreamConfig& stream_config,
+                             F kernel,
+                             dim3 grid_dim,
+                             dim3 block_dim,
+                             std::size_t lds_byte,
+                             Args... args)
 {
-    KernelTimer timer;
-
-    printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
-           __func__,
-           grid_dim.x,
-           grid_dim.y,
-           grid_dim.z,
-           block_dim.x,
-           block_dim.y,
-           block_dim.z);
-
-    printf("Warm up\n");
-
-    hipStream_t stream_id = nullptr;
-
-    // warm up
-    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
-
-    printf("Start running %d times...\n", nrepeat);
-
-    timer.Start();
-
-    for(int i = 0; i < nrepeat; ++i)
-    {
-        hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
-    }
-
-    timer.End();
-
-    return timer.GetElapsedTime() / nrepeat;
+#if CK_TIME_KERNEL
+    if(stream_config.time_kernel_)
+    {
+        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
+               __func__,
+               grid_dim.x,
+               grid_dim.y,
+               grid_dim.z,
+               block_dim.x,
+               block_dim.y,
+               block_dim.z);
+
+        const int nrepeat = 10;
+
+        printf("Warm up 1 time\n");
+
+        // warm up
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+
+        printf("Start running %d times...\n", nrepeat);
+
+        KernelTimer timer;
+        timer.Start();
+
+        for(int i = 0; i < nrepeat; ++i)
+        {
+            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        }
+
+        timer.End();
+
+        return timer.GetElapsedTime() / nrepeat;
+    }
+    else
+    {
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+
+        return 0;
+    }
+#else
+    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+
+    return 0;
+#endif
 }
-
-#endif
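A short, hedged usage sketch of the utilities above (not part of the diff); the DeviceMem constructor and the StreamConfig member order shown are assumptions for illustration, while SetValue, hip_check_error and launch_and_time_kernel are declared above.

// Hypothetical example: fill a device buffer, then time a kernel launch.
//   DeviceMem buf(sizeof(float) * 1024);  // assumed constructor taking a byte size
//   buf.SetValue(1.0f);                   // launches set_buffer_value<float><<<1, 1024>>>
//   hip_check_error(hipDeviceSynchronize());
//   StreamConfig cfg{nullptr, true};      // members used above: stream_id_, time_kernel_
//   float avg_ms =
//       launch_and_time_kernel(cfg, some_kernel, dim3(grid), dim3(block), 0, kernel_args...);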
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_COMMON_UTIL_HPP
#define GUARD_HOST_COMMON_UTIL_HPP
#include <vector>
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include "config.hpp"
namespace ck {
namespace host_common {
template <typename T>
static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
{
std::ofstream outFile(fileName, std::ios::binary);
if(outFile)
{
outFile.write(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
outFile.close();
std::cout << "Write output to file " << fileName << std::endl;
}
else
{
std::cout << "Could not open file " << fileName << " for writing" << std::endl;
}
};
template <typename T>
static inline T getSingleValueFromString(const std::string& valueStr)
{
std::istringstream iss(valueStr);
T val;
iss >> val;
return (val);
};
template <typename T>
static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
{
std::string valuesStr(cstr_values);
std::vector<T> values;
std::size_t pos = 0;
std::size_t new_pos;
new_pos = valuesStr.find(',', pos);
while(new_pos != std::string::npos)
{
const std::string sliceStr = valuesStr.substr(pos, new_pos - pos);
T val = getSingleValueFromString<T>(sliceStr);
values.push_back(val);
pos = new_pos + 1;
new_pos = valuesStr.find(',', pos);
};
std::string sliceStr = valuesStr.substr(pos);
T val = getSingleValueFromString<T>(sliceStr);
values.push_back(val);
return (values);
}
}; // namespace host_common
}; // namespace ck
#endif
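A brief usage note (added here, not part of the file): getTypeValuesFromString splits the input on ',' and parses each slice with getSingleValueFromString, so it is a convenient way to read length/stride lists from command-line arguments.

// Example (hypothetical caller):
//   std::vector<int> lens = ck::host_common::getTypeValuesFromString<int>("64,3,224,224");
//   // lens == {64, 3, 224, 224}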
@@ -28,9 +28,7 @@
 #include <limits>
 #include <cmath>
-#include <cassert>
-#include <stdexcept>
-#include <string>
+#include <functional>

 #include "reduction_enums.hpp"
 #include "data_type.hpp"
@@ -214,13 +212,13 @@ binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
     };
 };

-template <typename AccDataType, bool PropagateNan>
+template <typename AccDataType, typename IndexDataType, bool PropagateNan>
 __host__ static inline void
-binop_with_nan_check2(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
-                      AccDataType& accuVal,
-                      AccDataType currVal,
-                      int& accuIndex,
-                      int currIndex)
+binop_with_index_and_nan_check(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
+                               AccDataType& accuVal,
+                               AccDataType currVal,
+                               IndexDataType& accuIndex,
+                               IndexDataType currIndex)
 {
     using ck::math::isnan;
@@ -254,16 +252,6 @@ binop_with_nan_check2(std::function<void(AccDataType&, AccDataType, bool&)> opRe
 }; // namespace host_reduce

-static inline std::vector<int> to_int_vector(const std::vector<size_t>& inData)
-{
-    std::vector<int> outData;
-
-    for(auto elem : inData)
-        outData.push_back(static_cast<int>(elem));
-
-    return (outData);
-};
-
 }; // namespace ck

 #endif
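An illustrative sketch of how the renamed helper is driven (hypothetical caller, float values with int indices; the real callers are the ReductionHost loops below): a MAX-style opReduce2 sets the `changed` flag when the running value is replaced, and the helper then records the current index.

//   float currVal = /* value read from the tensor */;
//   int currIdx   = /* linear index of that value */;
//   float accuVal = std::numeric_limits<float>::lowest();
//   int accuIdx   = 0;
//   std::function<void(float&, float, bool&)> opReduce2 = [](float& a, float b, bool& changed) {
//       changed = (b > a);
//       if(changed)
//           a = b;
//   };
//   binop_with_index_and_nan_check<float, int, false>(opReduce2, accuVal, currVal, accuIdx, currIdx);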
@@ -34,6 +34,7 @@
 #include "reduction_enums.hpp"
 #include "reduction_common.hpp"
 #include "host_reduce_util.hpp"
+#include "host_common_util.hpp"
 #include "host_tensor.hpp"
 #include "data_type.hpp"
@@ -200,7 +201,7 @@ struct ReductionHost
         using ck::float_equal_one;
         using ck::float_equal_zero;
         using ck::type_convert;
-        using ck::host_reduce::binop_with_nan_check2;
+        using ck::host_reduce::binop_with_index_and_nan_check;
         using ck::host_reduce::ReduceOpFn2;
         using ck::host_reduce::ReduceOpZeroVal;
@@ -211,7 +212,7 @@ struct ReductionHost
             AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
             IndexDataType accuIndex = 0;

-            for(IndexDataType i = 0; i < reduce_dim_indexes.size(); i++)
+            for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
             {
                 auto offset_reduce =
                     get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
@@ -220,9 +221,9 @@ struct ReductionHost
                 preUnaryOp(currVal);

-                auto currIndex = i;
+                auto currIndex = static_cast<IndexDataType>(i);

-                binop_with_nan_check2<AccDataType, PropagateNan>(
+                binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
                     opReduce2, accuVal, currVal, accuIndex, currIndex);
             };
@@ -246,7 +247,7 @@ struct ReductionHost
                 auto offset_invariant =
                     get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);

-                for(IndexDataType i = 0; i < reduce_dim_indexes.size(); i++)
+                for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
                 {
                     auto offset_reduce =
                         get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
@@ -256,9 +257,9 @@ struct ReductionHost
                     preUnaryOp(currVal);

-                    auto currIndex = i;
+                    auto currIndex = static_cast<IndexDataType>(i);

-                    binop_with_nan_check2<AccDataType, PropagateNan>(
+                    binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
                         opReduce2, accuVal, currVal, accuIndex, currIndex);
                 };
...
@@ -154,7 +154,7 @@ struct ParallelTensorFunctor
     {
         std::array<std::size_t, NDIM> indices;

-        for(int idim = 0; idim < NDIM; ++idim)
+        for(std::size_t idim = 0; idim < NDIM; ++idim)
         {
             indices[idim] = i / mStrides[idim];
             i -= indices[idim] * mStrides[idim];
@@ -316,7 +316,7 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
     constexpr float eps = 1e-10;

-    for(int i = 0; i < ref.mData.size(); ++i)
+    for(std::size_t i = 0; i < ref.mData.size(); ++i)
     {
         float ref_v = ck::type_convert<float>(ref.mData[i]);
         float result_v = ck::type_convert<float>(result.mData[i]);
...
@@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
         return 0;
     }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /* stream_config */ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
...
-#ifndef REFERENCE_CONV_WRW_HPP
-#define REFERENCE_CONV_WRW_HPP
+#pragma once

 #include <iostream>
 #include <sstream>
@@ -16,7 +15,9 @@ template <typename InDataType,
           typename OutDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
-          typename OutElementwiseOperation>
+          typename OutElementwiseOperation,
+          ck::index_t NumDimSpatial = 2,
+          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdWeight : public device::BaseOperator
 {
     // Argument
@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op)
-            : in_n_c_hi_wi_{in_n_c_hi_wi},
-              wei_k_c_y_x_{wei_k_c_y_x},
-              out_n_k_ho_wo_{out_n_k_ho_wo},
+            : input_{in_n_c_hi_wi},
+              weight_{wei_k_c_y_x},
+              output_{out_n_k_ho_wo},
               conv_strides_{conv_filter_strides},
               conv_dilations_{conv_filter_dilations},
               in_left_pads_{input_left_pads},
@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
         {
         }

-        const Tensor<InDataType>& in_n_c_hi_wi_;
-        Tensor<WeiDataType>& wei_k_c_y_x_;
-        const Tensor<OutDataType>& out_n_k_ho_wo_;
+        const Tensor<InDataType>& input_;
+        Tensor<WeiDataType>& weight_;
+        const Tensor<OutDataType>& output_;

         std::vector<index_t> conv_strides_;
         std::vector<index_t> conv_dilations_;
@@ -66,55 +67,184 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
-        auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
-            float v_acc = 0;
-
-            for(int n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n)
-            {
-                for(int ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho)
-                {
-                    int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] -
-                             arg.in_left_pads_[I0];
-                    for(int wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
-                    {
-                        int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] -
-                                 arg.in_left_pads_[I1];
-                        if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                           wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
-                        {
-                            float v_out;
-                            float v_in;
-
-                            arg.out_element_op_(
-                                v_out, ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo)));
-                            arg.in_element_op_(
-                                v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
-
-                            v_acc += v_out * v_in;
-                        }
-                    }
-                }
-            }
-
-            float v_wei;
-
-            arg.wei_element_op_(v_wei, v_acc);
-
-            arg.wei_k_c_y_x_(k, c, y, x) = ck::type_convert<OutDataType>(v_wei);
-        };
-
-        make_ParallelTensorFunctor(f_kcyx,
-                                   arg.wei_k_c_y_x_.mDesc.GetLengths()[0],
-                                   arg.wei_k_c_y_x_.mDesc.GetLengths()[1],
-                                   arg.wei_k_c_y_x_.mDesc.GetLengths()[2],
-                                   arg.wei_k_c_y_x_.mDesc.GetLengths()[3])(
-            std::thread::hardware_concurrency());
-
-        return 0;
+        if constexpr(NumDimSpatial == 1)
+        {
+            constexpr auto I0 = Number<0>{};
+
+            auto f_kcx = [&](auto k, auto c, auto x) {
+                float v_acc = 0;
+
+                for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
+                {
+                    for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
+                    {
+                        auto wi =
+                            ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
+                            ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
+                            ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
+                        if(wi >= 0 &&
+                           ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
+                        {
+                            float v_out;
+                            float v_in;
+
+                            arg.out_element_op_(v_out,
+                                                ck::type_convert<float>(arg.output_(n, k, wo)));
+                            arg.in_element_op_(v_in,
+                                               ck::type_convert<float>(arg.input_(n, c, wi)));
+
+                            v_acc += v_out * v_in;
+                        }
+                    }
+                }
+
+                float v_wei;
+
+                arg.wei_element_op_(v_wei, v_acc);
+
+                arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
+            };
+
+            make_ParallelTensorFunctor(f_kcx,
+                                       arg.weight_.mDesc.GetLengths()[0],
+                                       arg.weight_.mDesc.GetLengths()[1],
+                                       arg.weight_.mDesc.GetLengths()[2])(
+                std::thread::hardware_concurrency());
+
+            return 0;
}
else if constexpr(NumDimSpatial == 2)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
{
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
{
float v_out;
float v_in;
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
{
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_)
{
auto di =
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo *
arg.conv_strides_[I2]) +
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[I2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
{
float v_out;
float v_in;
arg.out_element_op_(v_out,
ck::type_convert<float>(
arg.output_(n, k, do_, ho, wo)));
arg.in_element_op_(
v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kczyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3],
arg.weight_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
    }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /*stream_config*/ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
@@ -174,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
 } // namespace host
 } // namespace tensor_operation
 } // namespace ck
-#endif
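For orientation (added here, not part of the diff), the weight gradient computed by each NumDimSpatial branch above is the same correlation; the 2D case, as a descriptive comment:

// weight(k, c, y, x) = sum over n, ho, wo of
//     out_element_op(output(n, k, ho, wo)) * in_element_op(input(n, c, hi, wi)),
// where hi = ho * stride[0] + y * dilation[0] - left_pad[0]
//       wi = wo * stride[1] + x * dilation[1] - left_pad[1],
// out-of-range (hi, wi) contribute nothing, and wei_element_op is applied to the final sum.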
@@ -78,15 +78,18 @@ struct ReferenceConvBwdData : public device::BaseOperator
                 AccDataType v_acc = 0;

-                for(int x = 0; x < X; ++x)
+                for(std::size_t x = 0; x < X; ++x)
                 {
-                    int w_tmp = wi + arg.in_left_pads_[0] - x * arg.conv_dilations_[0];
+                    auto w_tmp = ck::type_convert<ck::long_index_t>(wi) +
+                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]);
                     if(w_tmp % arg.conv_strides_[0] == 0)
                     {
-                        int wo = w_tmp / arg.conv_strides_[0];
-                        if(wo >= 0 && wo < Wo)
+                        auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
+                                  ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                         {
-                            for(int k = 0; k < K; ++k)
+                            for(std::size_t k = 0; k < K; ++k)
                             {
                                 AccDataType v_out = 0;
                                 AccDataType v_wei = 0;
@@ -128,24 +131,32 @@ struct ReferenceConvBwdData : public device::BaseOperator
                 AccDataType v_acc = 0;

-                for(int y = 0; y < Y; ++y)
+                for(std::size_t y = 0; y < Y; ++y)
                 {
-                    int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
+                    auto h_tmp = ck::type_convert<ck::long_index_t>(hi) +
+                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]);
                     if(h_tmp % arg.conv_strides_[0] == 0)
                     {
-                        int ho = h_tmp / arg.conv_strides_[0];
-                        if(ho >= 0 && ho < Ho)
+                        auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
+                                  ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                         {
-                            for(int x = 0; x < X; ++x)
+                            for(std::size_t x = 0; x < X; ++x)
                             {
-                                int w_tmp =
-                                    wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
+                                auto w_tmp =
+                                    ck::type_convert<ck::long_index_t>(wi) +
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
+                                    ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]);
                                 if(w_tmp % arg.conv_strides_[1] == 0)
                                 {
-                                    int wo = w_tmp / arg.conv_strides_[1];
-                                    if(wo >= 0 && wo < Wo)
+                                    auto wo = ck::type_convert<ck::long_index_t>(w_tmp) /
+                                              ck::type_convert<ck::long_index_t>(
+                                                  arg.conv_strides_[1]);
+                                    if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
                                     {
-                                        for(int k = 0; k < K; ++k)
+                                        for(std::size_t k = 0; k < K; ++k)
                                         {
                                             AccDataType v_out = 0;
                                             AccDataType v_wei = 0;
@@ -194,33 +205,49 @@ struct ReferenceConvBwdData : public device::BaseOperator
                 AccDataType v_acc = 0;

-                for(int z = 0; z < Z; ++z)
+                for(std::size_t z = 0; z < Z; ++z)
                 {
-                    int d_tmp = di + arg.in_left_pads_[0] - z * arg.conv_dilations_[0];
+                    auto d_tmp = ck::type_convert<ck::long_index_t>(di) +
+                                 ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]) -
+                                 ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]);
                     if(d_tmp % arg.conv_strides_[0] == 0)
                     {
-                        int do_ = d_tmp / arg.conv_strides_[0];
-                        if(do_ >= 0 && do_ < Do)
+                        auto do_ = ck::type_convert<ck::long_index_t>(d_tmp) /
+                                   ck::type_convert<ck::long_index_t>(arg.conv_strides_[0]);
+                        if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
                         {
-                            for(int y = 0; y < Y; ++y)
+                            for(std::size_t y = 0; y < Y; ++y)
                             {
-                                int h_tmp =
-                                    hi + arg.in_left_pads_[1] - y * arg.conv_dilations_[1];
+                                auto h_tmp =
+                                    ck::type_convert<ck::long_index_t>(hi) +
+                                    ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]) -
+                                    ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]);
                                 if(h_tmp % arg.conv_strides_[1] == 0)
                                 {
-                                    int ho = h_tmp / arg.conv_strides_[1];
-                                    if(ho >= 0 && ho < Ho)
+                                    auto ho = ck::type_convert<ck::long_index_t>(h_tmp) /
+                                              ck::type_convert<ck::long_index_t>(
+                                                  arg.conv_strides_[1]);
+                                    if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
                                     {
-                                        for(int x = 0; x < X; ++x)
+                                        for(std::size_t x = 0; x < X; ++x)
                                         {
-                                            int w_tmp = wi + arg.in_left_pads_[2] -
-                                                        x * arg.conv_dilations_[2];
+                                            auto w_tmp =
+                                                ck::type_convert<ck::long_index_t>(wi) +
+                                                ck::type_convert<ck::long_index_t>(
+                                                    arg.in_left_pads_[2]) -
+                                                ck::type_convert<ck::long_index_t>(
+                                                    x * arg.conv_dilations_[2]);
                                             if(w_tmp % arg.conv_strides_[2] == 0)
                                             {
-                                                int wo = w_tmp / arg.conv_strides_[2];
-                                                if(wo >= 0 && wo < Wo)
+                                                auto wo =
+                                                    ck::type_convert<ck::long_index_t>(w_tmp) /
+                                                    ck::type_convert<ck::long_index_t>(
+                                                        arg.conv_strides_[2]);
+                                                if(wo >= 0 &&
+                                                   ck::type_convert<std::size_t>(wo) < Wo)
                                                 {
-                                                    for(int k = 0; k < K; ++k)
+                                                    for(std::size_t k = 0; k < K; ++k)
                                                     {
                                                         AccDataType v_out = 0;
                                                         AccDataType v_wei = 0;
@@ -264,7 +291,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
         }
     }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /* stream_config */ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
...
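One orientation note on the backward-data reference above (added here, not in the diff), as a descriptive comment:

// The bwd-data loops invert the forward index map: a filter tap x contributes to input position
// wi only when w_tmp = wi + left_pad - x * dilation is exactly divisible by the stride and the
// resulting wo = w_tmp / stride lies inside [0, Wo); the same divisibility-and-range test is
// repeated per spatial dimension (do_, ho, wo) in the 2D and 3D branches.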
-#ifndef REFERENCE_CONV_FWD_HPP
-#define REFERENCE_CONV_FWD_HPP
+#pragma once

 #include <iostream>
 #include <type_traits>
 #include <sstream>
+#include "stream_config.hpp"
 #include "device_base.hpp"
 #include "host_tensor.hpp"
@@ -88,13 +89,16 @@ struct ReferenceConvFwd : public device::BaseOperator
         auto f_ncw = [&](auto n, auto k, auto wo) {
             float v_acc = 0;

-            for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+            for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
             {
-                for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
+                for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
                 {
-                    int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
-                             arg.in_left_pads_[0];
-                    if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2])
+                    auto wi =
+                        ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[0]) +
+                        ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[0]) -
+                        ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                    if(wi >= 0 &&
+                       ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
                     {
                         float v_in;
                         float v_wei;
@@ -128,18 +132,26 @@ struct ReferenceConvFwd : public device::BaseOperator
         auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
             float v_acc = 0;

-            for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+            for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
             {
-                for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
+                for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
                 {
-                    int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                             arg.in_left_pads_[0];
-                    for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
+                    auto hi =
+                        ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                        ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                        ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                    for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
                     {
-                        int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                 arg.in_left_pads_[1];
-                        if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 &&
-                           wi < arg.input_.mDesc.GetLengths()[3])
+                        auto wi =
+                            ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                            ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                            ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                        if(hi >= 0 &&
+                           ck::type_convert<std::size_t>(hi) <
+                               arg.input_.mDesc.GetLengths()[2] &&
+                           wi >= 0 &&
+                           ck::type_convert<std::size_t>(wi) <
+                               arg.input_.mDesc.GetLengths()[3])
                         {
                             float v_in;
                             float v_wei;
@@ -174,23 +186,37 @@ struct ReferenceConvFwd : public device::BaseOperator
         auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
             float v_acc = 0;

-            for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+            for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
             {
-                for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
+                for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
                 {
-                    int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
-                             arg.in_left_pads_[0];
-                    for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
+                    auto di =
+                        ck::type_convert<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
+                        ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[0]) -
+                        ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                    for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
                     {
-                        int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
-                                 arg.in_left_pads_[1];
-                        for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
+                        auto hi =
+                            ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[1]) +
+                            ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[1]) -
+                            ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                        for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
                         {
-                            int wi = wo * arg.conv_strides_[2] +
-                                     x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
-                            if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] &&
-                               hi >= 0 && hi < arg.input_.mDesc.GetLengths()[3] &&
-                               wi >= 0 && wi < arg.input_.mDesc.GetLengths()[4])
+                            auto wi =
+                                ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[2]) +
+                                ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[2]) -
+                                ck::type_convert<ck::long_index_t>(arg.in_left_pads_[2]);
+                            if(di >= 0 &&
+                               ck::type_convert<std::size_t>(di) <
+                                   arg.input_.mDesc.GetLengths()[2] &&
+                               hi >= 0 &&
+                               ck::type_convert<std::size_t>(hi) <
+                                   arg.input_.mDesc.GetLengths()[3] &&
+                               wi >= 0 &&
+                               ck::type_convert<std::size_t>(wi) <
+                                   arg.input_.mDesc.GetLengths()[4])
                             {
                                 float v_in;
                                 float v_wei;
@@ -226,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator
         }
     }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /*stream_config*/ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
@@ -286,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator
 } // namespace host
 } // namespace tensor_operation
 } // namespace ck
-#endif
@@ -73,18 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
         auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
             float v_acc = 0;

-            for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
+            for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
             {
-                for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
+                for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
                 {
-                    int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                             arg.in_left_pads_[0];
-                    for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
+                    auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                              ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                              ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                    for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
                     {
-                        int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                 arg.in_left_pads_[1];
-                        if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                           wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
+                        auto wi =
+                            ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                            ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                            ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                        if(hi >= 0 &&
+                           ck::type_convert<std::size_t>(hi) <
+                               arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
+                           wi >= 0 &&
+                           ck::type_convert<std::size_t>(wi) <
+                               arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
                         {
                             float v_in;
                             float v_wei;
@@ -117,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
         return 0;
     }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /* stream_config */ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
...
@@ -76,18 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
         auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
             float v_acc = 0;

-            for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
+            for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
             {
-                for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
+                for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
                 {
-                    int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
-                             arg.in_left_pads_[0];
-                    for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
+                    auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[0]) +
+                              ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[0]) -
+                              ck::type_convert<ck::long_index_t>(arg.in_left_pads_[0]);
+                    for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
                     {
-                        int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
-                                 arg.in_left_pads_[1];
-                        if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
-                           wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
+                        auto wi =
+                            ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[1]) +
+                            ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[1]) -
+                            ck::type_convert<ck::long_index_t>(arg.in_left_pads_[1]);
+                        if(hi >= 0 &&
+                           ck::type_convert<std::size_t>(hi) <
+                               arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
+                           wi >= 0 &&
+                           ck::type_convert<std::size_t>(wi) <
+                               arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
                         {
                             float v_in;
                             float v_wei;
@@ -123,7 +130,8 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
         return 0;
     }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /*stream_config*/ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
...
-#ifndef REFERENCE_GEMM_HPP
-#define REFERENCE_GEMM_HPP
+#pragma once

 #include <iostream>
 #include <sstream>
 #include "device_base.hpp"
@@ -13,6 +11,7 @@ namespace host {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename AccDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
@@ -55,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator
         auto f_mk_kn_mn = [&](auto m, auto n) {
             const int K = arg.a_m_k_.mDesc.GetLengths()[1];

-            float v_acc = 0;
+            AccDataType v_acc = 0;

             for(int k = 0; k < K; ++k)
             {
-                float v_a;
-                float v_b;
+                AccDataType v_a;
+                AccDataType v_b;

-                arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
-                arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
+                arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
+                arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));

                 v_acc += v_a * v_b;
             }

-            float v_c;
+            AccDataType v_c;

             arg.c_element_op_(v_c, v_acc);
@@ -82,7 +81,8 @@ struct ReferenceGemm : public device::BaseOperator
         return 0;
     }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /* stream_config */ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
@@ -129,4 +129,3 @@ struct ReferenceGemm : public device::BaseOperator
 } // namespace host
 } // namespace tensor_operation
 } // namespace ck
-#endif
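A hypothetical instantiation showing why the new AccDataType parameter is useful: low-precision inputs can be accumulated in float. The PassThrough element-wise operator name is assumed from the library's element-wise ops and is illustrative only.

// using RefGemmF16 = ck::tensor_operation::host::
//     ReferenceGemm<ck::half_t, ck::half_t, ck::half_t, float,
//                   PassThrough, PassThrough, PassThrough>;
// // A/B/C are half precision; v_acc, v_a, v_b and v_c inside Run() become float.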
@@ -82,7 +82,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
         return 0;
     }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /* stream_config */ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
...
@@ -85,7 +85,8 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
         return 0;
     }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /* stream_config */ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
...
@@ -91,7 +91,8 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
         return 0;
     }

-    float Run(const device::BaseArgument* p_arg, int) override
+    float Run(const device::BaseArgument* p_arg,
+              const StreamConfig& /* stream_config */ = StreamConfig{}) override
     {
         return Run(*dynamic_cast<const Argument*>(p_arg));
     }
...
@@ -9,26 +9,11 @@
 #include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
 #include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
 #include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
-#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
-#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
-#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
-#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
-#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
-#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp"
-#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp"
-#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
-#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp"
-#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp"
 #include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
 #include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
 #include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
...
@@ -3,13 +3,27 @@
 #include "reduction_operator_mapping.hpp"
 #include "device_reduce_instance_impl_common.hpp"
-#include "device_reduce_blockwise.hpp"
+#include "device_reduce_multiblock.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace device_reduce_instance {
using reduce_configuration_1_instances_blockwise = std::tuple<
// clang-format off
// BlockSize | MThreadClusterSize | KThreadClusterSize
ReductionConfiguration_1<256, 128, 2>,
ReductionConfiguration_1<256, 64, 4>,
ReductionConfiguration_1<256, 32, 8>,
ReductionConfiguration_1<256, 16, 16>,
ReductionConfiguration_1<256, 8, 32>,
ReductionConfiguration_1<256, 4, 64>,
ReductionConfiguration_1<256, 2, 128>,
ReductionConfiguration_1<256, 1, 256>
// clang-format on
>;
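// Note (added for orientation): every configuration above keeps
// MThreadClusterSize * KThreadClusterSize equal to the 256-thread block size; the entries only
// trade how the block is split between the invariant (M) and reduced (K) dimensions.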
 #ifdef QUICK_REDUCE_TEST
 using reduce_configuration_2_instances_blockwise = std::tuple<
     // clang-format off
@@ -58,8 +72,8 @@ template <typename InDataType,
           int Rank,
           int NumReduceDim,
           ReduceTensorOp ReduceOpId,
-          NanPropagation NanOpt,
-          ReduceTensorIndices IndicesOpt>
+          bool PropagateNan,
+          bool UseIndex>
 void add_device_reduce_instance_blockwise(
     std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
 {
@@ -73,92 +87,94 @@ void add_device_reduce_instance_blockwise(
     constexpr bool Indexable =
         (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
          ReduceOpId == ReduceTensorOp::AMAX);
-    constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES);
-
-    constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
-
-    static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
-        using cfg1 =
-            remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;
-
-        static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
-            [&](auto j) {
-                using cfg2 = remove_cvref_t<decltype(
-                    std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
-
-                using ReduceOpInstance = DeviceReduceBlockWise<InDataType,
-                                                               AccDataType,
-                                                               OutDataType,
-                                                               Rank,
-                                                               NumReduceDim,
-                                                               ReduceOperation,
-                                                               InElementwiseOperation,
-                                                               AccElementwiseOperation,
-                                                               PropagateNan,
-                                                               NeedIndices,
-                                                               cfg1::BlockSize_,
-                                                               cfg1::MThreadClusterSize_,
-                                                               cfg1::KThreadClusterSize_,
-                                                               cfg2::MThreadSliceSize_,
-                                                               cfg2::KThreadSliceSize_,
-                                                               cfg2::InSrcVectorDim_,
-                                                               cfg2::InSrcVectorSize_,
-                                                               cfg2::OutDstVectorSize_>;
-
-                device_op_instances.push_back(
-                    std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
-            });
-    });
+    constexpr bool OutputIndex = Indexable && UseIndex;
+
+    static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
+        [&](auto i) {
+            using cfg1 = remove_cvref_t<decltype(
+                std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>;
+
+            static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
+                [&](auto j) {
+                    using cfg2 = remove_cvref_t<decltype(
+                        std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
+
+                    using ReduceOpInstance =
+                        DeviceReduceMultiBlock<InDataType,
+                                               AccDataType,
+                                               OutDataType,
+                                               Rank,
+                                               NumReduceDim,
+                                               ReduceOperation,
+                                               InElementwiseOperation,
+                                               AccElementwiseOperation,
+                                               InMemoryDataOperationEnum::Set,
+                                               PropagateNan,
+                                               OutputIndex,
+                                               false, // HaveIndexInputIfOutputIndex
+                                               cfg1::BlockSize_,
+                                               cfg1::MThreadClusterSize_,
+                                               cfg1::KThreadClusterSize_,
+                                               cfg2::MThreadSliceSize_,
+                                               cfg2::KThreadSliceSize_,
+                                               cfg2::InSrcVectorDim_,
+                                               cfg2::InSrcVectorSize_,
+                                               cfg2::OutDstVectorSize_>;
+
+                    device_op_instances.push_back(
+                        std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
+                });
+        });
 };
 #define ADD_BLOCKWISE_INST_BY_TYPE(                                                       \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                 \
+    inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim)             \
     template void add_device_reduce_instance_blockwise<inT,                               \
                                                        compT,                             \
                                                        outT,                              \
                                                        Rank,                              \
                                                        NumReduceDim,                      \
                                                        ReduceOpId,                        \
-                                                       NanOpt,                            \
-                                                       IndicesOpt>(                       \
+                                                       PropagateNan,                      \
+                                                       UseIndex>(                         \
         std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)

 #define ADD_BLOCKWISE_INST_BY_ID(                                                         \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                 \
     ADD_BLOCKWISE_INST_BY_TYPE(inT,                                                       \
                                compT,                                                     \
                                outT,                                                      \
                                static_cast<ReduceTensorOp>(ReduceOpId),                   \
-                               static_cast<NanPropagation>(NanOpt),                       \
-                               static_cast<ReduceTensorIndices>(IndicesOpt),              \
+                               static_cast<bool>(NanOpt),                                 \
+                               static_cast<bool>(IndicesOpt),                             \
                                Rank,                                                      \
                                NumReduceDim)

 #define ADD_BLOCKWISE_INST_REF_BY_TYPE(                                                   \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                 \
+    inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim)             \
     extern template void add_device_reduce_instance_blockwise<inT,                        \
                                                               compT,                      \
                                                               outT,                       \
                                                               Rank,                       \
                                                               NumReduceDim,               \
                                                               ReduceOpId,                 \
-                                                              NanOpt,                     \
-                                                              IndicesOpt>(                \
+                                                              PropagateNan,               \
+                                                              UseIndex>(                  \
         std::vector<DeviceReducePtr<                                                      \
             typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
             typename reduce_unary_operator<compT, ReduceOpId, true, true>::               \
                 AccElementwiseOperation>> &                                               \
             device_op_instances)

 #define ADD_BLOCKWISE_INST_REF_BY_ID(                                                     \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                 \
     ADD_BLOCKWISE_INST_REF_BY_TYPE(inT,                                                   \
                                    compT,                                                 \
                                    outT,                                                  \
                                    static_cast<ReduceTensorOp>(ReduceOpId),               \
-                                   static_cast<NanPropagation>(NanOpt),                   \
-                                   static_cast<ReduceTensorIndices>(IndicesOpt),          \
+                                   static_cast<bool>(NanOpt),                             \
+                                   static_cast<bool>(IndicesOpt),                         \
                                    Rank,                                                  \
                                    NumReduceDim)

 } // namespace device_reduce_instance
...
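A hypothetical use of the instantiation macro above (argument values are illustrative only, assuming ReduceTensorOp provides an ADD member as in the library's reduction enums):

// Emit blockwise instances for a float ADD reduction over 2 of 4 dimensions,
// with NaN propagation and index output both disabled:
//   ADD_BLOCKWISE_INST_BY_TYPE(float, float, float, ReduceTensorOp::ADD, false, false, 4, 2);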