Commit 6dfb4e78 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents 397a68f2 1ced00a5
...@@ -36,6 +36,11 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N> ...@@ -36,6 +36,11 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
{ {
return base::operator()(i); return base::operator()(i);
} }
__host__ __device__ void Clear()
{
static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; });
}
}; };
#ifndef CK_NOGPU #ifndef CK_NOGPU
...@@ -147,9 +152,9 @@ struct StaticBufferTupleOfVector ...@@ -147,9 +152,9 @@ struct StaticBufferTupleOfVector
__host__ __device__ void Clear() __host__ __device__ void Clear()
{ {
const index_t numScalars = NumOfVector * ScalarPerVector; constexpr index_t NumScalars = NumOfVector * ScalarPerVector;
static_for<0, Number<numScalars>{}, 1>{}([&](auto i) { SetAsType(i, S{0}); }); static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); });
} }
}; };
#endif #endif
......
...@@ -93,6 +93,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x) ...@@ -93,6 +93,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
return r; return r;
} }
// MultiIndex = MultiIndex * index_t
template <typename... Xs>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, index_t a)
{
return a * x;
}
template <typename... Xs> template <typename... Xs>
__host__ __device__ void print_multi_index(const Tuple<Xs...>& x) __host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
{ {
......
...@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type; ...@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T> template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>; using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
template <typename T> template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value; inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
......
...@@ -12,6 +12,15 @@ ...@@ -12,6 +12,15 @@
#include "stream_config.hpp" #include "stream_config.hpp"
#ifndef CK_NOGPU #ifndef CK_NOGPU
template <typename T>
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
{
for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x)
{
p[i] = x;
}
}
inline void hip_check_error(hipError_t x) inline void hip_check_error(hipError_t x)
{ {
if(x != hipSuccess) if(x != hipSuccess)
...@@ -32,6 +41,16 @@ struct DeviceMem ...@@ -32,6 +41,16 @@ struct DeviceMem
void ToDevice(const void* p); void ToDevice(const void* p);
void FromDevice(void* p); void FromDevice(void* p);
void SetZero(); void SetZero();
template <typename T>
void SetValue(T x)
{
if(mMemSize % sizeof(T) != 0)
{
throw std::runtime_error("wrong! not entire DeviceMem will be set");
}
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
}
~DeviceMem(); ~DeviceMem();
void* mpDeviceBuf; void* mpDeviceBuf;
...@@ -78,8 +97,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -78,8 +97,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
printf("Warm up 1 time\n"); printf("Warm up 1 time\n");
// warm up // warm up
hipLaunchKernelGGL( kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
printf("Start running %d times...\n", nrepeat); printf("Start running %d times...\n", nrepeat);
...@@ -88,8 +106,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -88,8 +106,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
for(int i = 0; i < nrepeat; ++i) for(int i = 0; i < nrepeat; ++i)
{ {
hipLaunchKernelGGL( kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
} }
timer.End(); timer.End();
...@@ -98,13 +115,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -98,13 +115,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
} }
else else
{ {
hipLaunchKernelGGL( kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
return 0; return 0;
} }
#else #else
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
return 0; return 0;
#endif #endif
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_COMMON_UTIL_HPP
#define GUARD_HOST_COMMON_UTIL_HPP
#include <vector>
#include <iostream>
#include <fstream>
#include <string>
#include "config.hpp"
namespace ck {
namespace host_common {
template <typename T>
static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
{
std::ofstream outFile(fileName, std::ios::binary);
if(outFile)
{
outFile.write(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
outFile.close();
std::cout << "Write output to file " << fileName << std::endl;
}
else
{
std::cout << "Could not open file " << fileName << " for writing" << std::endl;
}
};
template <typename T>
static inline T getSingleValueFromString(const std::string& valueStr)
{
std::istringstream iss(valueStr);
T val;
iss >> val;
return (val);
};
template <typename T>
static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
{
std::string valuesStr(cstr_values);
std::vector<T> values;
std::size_t pos = 0;
std::size_t new_pos;
new_pos = valuesStr.find(',', pos);
while(new_pos != std::string::npos)
{
const std::string sliceStr = valuesStr.substr(pos, new_pos - pos);
T val = getSingleValueFromString<T>(sliceStr);
values.push_back(val);
pos = new_pos + 1;
new_pos = valuesStr.find(',', pos);
};
std::string sliceStr = valuesStr.substr(pos);
T val = getSingleValueFromString<T>(sliceStr);
values.push_back(val);
return (values);
}
}; // namespace host_common
}; // namespace ck
#endif
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_REDUCE_UTIL_HPP
#define GUARD_HOST_REDUCE_UTIL_HPP
#include <limits>
#include <cmath>
#include <cassert>
#include <stdexcept>
#include <string>
#include "reduction_enums.hpp"
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace host_reduce {
using ck::NanPropagation;
using ck::ReduceTensorOp;
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
{
using ck::math::abs;
if constexpr(ReduceOpId == ReduceTensorOp::NORM1)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = a_ * a_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else
{
// ReduceTensorOp::AVG:
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
return ([&](AccDataType&) {});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
{
using std::sqrt;
if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = sqrt(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AVG)
{
return ([&, divider](AccDataType& a_) {
a_ = a_ / static_cast<AccDataType>(static_cast<float>(divider));
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
// ReduceTensorOp::AMAX:
return ([&](AccDataType&) {});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
{
if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG ||
ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ > b_)
a_ = b_;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ < b_)
a_ = b_;
});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
{
if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ > b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ < b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::AVG:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::NORM2:
return (std::function<void(AccDataType&, AccDataType, bool&)>{});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline AccDataType ReduceOpZeroVal()
{
if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return (static_cast<AccDataType>(1.0f));
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return (ck::NumericLimits<AccDataType>::Max());
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX)
{
return (ck::NumericLimits<AccDataType>::Lowest());
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return (static_cast<AccDataType>(0.0f));
}
else
{
// ReduceTensorOp::ADD
// ReduceTensorOp::AVG
// ReduceTensorOp::NORM1
// ReduceTensorOp::NORM2
return (static_cast<AccDataType>(0.0f));
};
};
template <typename AccDataType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
AccDataType& accuVal,
AccDataType currVal)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
opReduce(accuVal, currVal);
}
else
{
if(isnan(currVal))
accuVal = currVal;
else
opReduce(accuVal, currVal);
};
};
template <typename AccDataType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check2(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
AccDataType& accuVal,
AccDataType currVal,
int& accuIndex,
int currIndex)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
}
else
{
if(isnan(currVal))
{
accuVal = currVal;
accuIndex = currIndex;
}
else
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
};
};
};
}; // namespace host_reduce
static inline std::vector<int> to_int_vector(const std::vector<size_t>& inData)
{
std::vector<int> outData;
for(auto elem : inData)
outData.push_back(static_cast<int>(elem));
return (outData);
};
}; // namespace ck
#endif
...@@ -33,9 +33,10 @@ ...@@ -33,9 +33,10 @@
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "host_reduce_util.hpp" #include "host_common_util.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#include "reduction_functions_accumulate.hpp"
template <int NDim> template <int NDim>
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths, static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
...@@ -105,11 +106,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides, ...@@ -105,11 +106,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides,
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
ck::ReduceTensorOp ReduceOpId, typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
int Rank, int Rank,
int NumReduceDim, int NumReduceDim,
bool PropagateNan, bool PropagateNan,
bool NeedIndices> bool OutputIndex>
struct ReductionHost struct ReductionHost
{ {
using IndexDataType = int32_t; using IndexDataType = int32_t;
...@@ -121,8 +124,6 @@ struct ReductionHost ...@@ -121,8 +124,6 @@ struct ReductionHost
std::vector<int> reduceDims; std::vector<int> reduceDims;
IndexDataType divider; IndexDataType divider;
std::function<void(AccDataType&)> preUnaryOp;
std::function<void(AccDataType&)> posUnaryOp;
std::array<size_t, NumReduceDim> reduceLengths; std::array<size_t, NumReduceDim> reduceLengths;
std::array<size_t, NumReduceDim> reduceStrides; std::array<size_t, NumReduceDim> reduceStrides;
std::array<size_t, NumInvariantDim> invariantLengths; std::array<size_t, NumInvariantDim> invariantLengths;
...@@ -136,9 +137,6 @@ struct ReductionHost ...@@ -136,9 +137,6 @@ struct ReductionHost
const std::vector<int>& invariantDims_, const std::vector<int>& invariantDims_,
const std::vector<int>& reduceDims_) const std::vector<int>& reduceDims_)
{ {
using ck::host_reduce::PosUnaryOpFn;
using ck::host_reduce::PreUnaryOpFn;
// this->outLengths = to_int_vector(outDesc.GetLengths()); // this->outLengths = to_int_vector(outDesc.GetLengths());
this->outStrides = outDesc.GetStrides(); this->outStrides = outDesc.GetStrides();
...@@ -170,9 +168,6 @@ struct ReductionHost ...@@ -170,9 +168,6 @@ struct ReductionHost
invariant_dim_indexes.clear(); invariant_dim_indexes.clear();
get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes); get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
}; };
preUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
posUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
}; };
void Run(float alpha, void Run(float alpha,
...@@ -181,7 +176,7 @@ struct ReductionHost ...@@ -181,7 +176,7 @@ struct ReductionHost
OutDataType* out_data, OutDataType* out_data,
IndexDataType* out_indices) IndexDataType* out_indices)
{ {
if constexpr(NeedIndices) if constexpr(OutputIndex)
{ {
RunImpl_with_index(alpha, in_data, beta, out_data, out_indices); RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
} }
...@@ -200,34 +195,34 @@ struct ReductionHost ...@@ -200,34 +195,34 @@ struct ReductionHost
using ck::float_equal_one; using ck::float_equal_one;
using ck::float_equal_zero; using ck::float_equal_zero;
using ck::type_convert; using ck::type_convert;
using ck::host_reduce::binop_with_nan_check2;
using ck::host_reduce::ReduceOpFn2;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce2 = ReduceOpFn2<AccDataType, ReduceOpId>(); using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
{ {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(IndexDataType i = 0; i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size()); for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
i++)
{ {
auto offset_reduce = auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]); get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]); auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
preUnaryOp(currVal); in_elementwise_op(currVal, currVal);
auto currIndex = i; auto currIndex = static_cast<IndexDataType>(i);
binop_with_nan_check2<AccDataType, PropagateNan>( Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
opReduce2, accuVal, currVal, accuIndex, currIndex);
}; };
posUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha); accuVal *= type_convert<AccDataType>(alpha);
...@@ -241,15 +236,13 @@ struct ReductionHost ...@@ -241,15 +236,13 @@ struct ReductionHost
else else
{ {
auto thread_reduce_func = [&](auto invariant_index) { auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
auto offset_invariant = auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index); get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
for(IndexDataType i = 0; for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size());
i++)
{ {
auto offset_reduce = auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]); get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
...@@ -257,15 +250,14 @@ struct ReductionHost ...@@ -257,15 +250,14 @@ struct ReductionHost
auto currVal = auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]); type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
preUnaryOp(currVal); in_elementwise_op(currVal, currVal);
auto currIndex = i; auto currIndex = static_cast<IndexDataType>(i);
binop_with_nan_check2<AccDataType, PropagateNan>( Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
opReduce2, accuVal, currVal, accuIndex, currIndex);
}; };
posUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha); accuVal *= type_convert<AccDataType>(alpha);
...@@ -310,15 +302,16 @@ struct ReductionHost ...@@ -310,15 +302,16 @@ struct ReductionHost
using ck::float_equal_one; using ck::float_equal_one;
using ck::float_equal_zero; using ck::float_equal_zero;
using ck::type_convert; using ck::type_convert;
using ck::host_reduce::binop_with_nan_check;
using ck::host_reduce::ReduceOpFn;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>(); using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
{ {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOperation::GetIdentityValue();
for(const auto& reduce_index : reduce_dim_indexes) for(const auto& reduce_index : reduce_dim_indexes)
{ {
...@@ -327,12 +320,12 @@ struct ReductionHost ...@@ -327,12 +320,12 @@ struct ReductionHost
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]); auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
preUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal); Accumulation::Calculate(accuVal, currVal);
}; };
posUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha); accuVal *= type_convert<AccDataType>(alpha);
...@@ -345,7 +338,7 @@ struct ReductionHost ...@@ -345,7 +338,7 @@ struct ReductionHost
else else
{ {
auto thread_reduce_func = [&](auto invariant_index) { auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOperation::GetIdentityValue();
auto offset_invariant = auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index); get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
...@@ -358,12 +351,12 @@ struct ReductionHost ...@@ -358,12 +351,12 @@ struct ReductionHost
auto currVal = auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]); type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
preUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal); Accumulation::Calculate(accuVal, currVal);
}; };
posUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha); accuVal *= type_convert<AccDataType>(alpha);
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once
#include <iostream>
#include <sstream>
#include "device_base.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// FIXME: support arbitrary elementwise operation for A/B/C
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct ReferenceCGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k_real,
const Tensor<ADataType>& a_m_k_imag,
const Tensor<BDataType>& b_k_n_real,
const Tensor<BDataType>& b_k_n_imag,
Tensor<CDataType>& c_m_n_real,
Tensor<CDataType>& c_m_n_imag,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_real_{a_m_k_real},
a_m_k_imag_{a_m_k_imag},
b_k_n_real_{b_k_n_real},
b_k_n_imag_{b_k_n_imag},
c_m_n_real_{c_m_n_real},
c_m_n_imag_{c_m_n_imag},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_real_;
const Tensor<ADataType>& a_m_k_imag_;
const Tensor<BDataType>& b_k_n_real_;
const Tensor<BDataType>& b_k_n_imag_;
Tensor<CDataType>& c_m_n_real_;
Tensor<CDataType>& c_m_n_imag_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceCGemm::Argument;
float Run(const Argument& arg)
{
const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];
if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
{
throw std::runtime_error("wrong! Incompatible real and imag sizes in CGEMM");
}
auto f_mk_kn_mn_real = [&](auto m, auto n) {
float v_c_real = 0;
for(std::size_t k = 0; k < K; ++k)
{
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag;
}
arg.c_m_n_real_(m, n) = v_c_real;
};
auto f_mk_kn_mn_imag = [&](auto m, auto n) {
float v_c_imag = 0;
for(std::size_t k = 0; k < K; ++k)
{
float v_a_real = ck::type_convert<float>(arg.a_m_k_real_(m, k));
float v_a_imag = ck::type_convert<float>(arg.a_m_k_imag_(m, k));
float v_b_real = ck::type_convert<float>(arg.b_k_n_real_(k, n));
float v_b_imag = ck::type_convert<float>(arg.b_k_n_imag_(k, n));
v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real;
}
arg.c_m_n_imag_(m, n) = v_c_imag;
};
make_ParallelTensorFunctor(f_mk_kn_mn_real,
arg.c_m_n_real_.mDesc.GetLengths()[0],
arg.c_m_n_real_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_mk_kn_mn_imag,
arg.c_m_n_imag_.mDesc.GetLengths()[0],
arg.c_m_n_imag_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k_real,
const Tensor<ADataType>& a_m_k_imag,
const Tensor<BDataType>& b_k_n_real,
const Tensor<BDataType>& b_k_n_imag,
Tensor<CDataType>& c_m_n_real,
Tensor<CDataType>& c_m_n_imag,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k_real,
a_m_k_imag,
b_k_n_real,
b_k_n_imag,
c_m_n_real,
c_m_n_imag,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceCGemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
#ifndef REFERENCE_CONV_WRW_HPP #pragma once
#define REFERENCE_CONV_WRW_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -16,7 +15,9 @@ template <typename InDataType, ...@@ -16,7 +15,9 @@ template <typename InDataType,
typename OutDataType, typename OutDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation,
ck::index_t NumDimSpatial = 2,
typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator struct ReferenceConvBwdWeight : public device::BaseOperator
{ {
// Argument // Argument
...@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) OutElementwiseOperation out_element_op)
: in_n_c_hi_wi_{in_n_c_hi_wi}, : input_{in_n_c_hi_wi},
wei_k_c_y_x_{wei_k_c_y_x}, weight_{wei_k_c_y_x},
out_n_k_ho_wo_{out_n_k_ho_wo}, output_{out_n_k_ho_wo},
conv_strides_{conv_filter_strides}, conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations}, conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads}, in_left_pads_{input_left_pads},
...@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
{ {
} }
const Tensor<InDataType>& in_n_c_hi_wi_; const Tensor<InDataType>& input_;
Tensor<WeiDataType>& wei_k_c_y_x_; Tensor<WeiDataType>& weight_;
const Tensor<OutDataType>& out_n_k_ho_wo_; const Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_; std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<index_t> conv_dilations_;
...@@ -66,59 +67,180 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -66,59 +67,180 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
constexpr auto I0 = Number<0>{}; if constexpr(NumDimSpatial == 1)
constexpr auto I1 = Number<1>{}; {
auto f_kcyx = [&](auto k, auto c, auto y, auto x) { constexpr auto I0 = Number<0>{};
float v_acc = 0; auto f_kcx = [&](auto k, auto c, auto x) {
for(std::size_t n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n) float v_acc = 0;
{ for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
for(std::size_t ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho)
{ {
auto hi = ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo)
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo)
{ {
auto wi = auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) + ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I1]) - ck::type_convert<ck::long_index_t>(x * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]); ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
if(hi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(hi) < ck::type_convert<std::size_t>(wi) < arg.input_.mDesc.GetLengths()[2])
arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
{ {
float v_out; float v_out;
float v_in; float v_in;
arg.out_element_op_( arg.out_element_op_(v_out,
v_out, ck::type_convert<float>(arg.output_(n, k, wo)));
ck::type_convert<float>(arg.out_n_k_ho_wo_(n, k, ho, wo))); arg.in_element_op_(v_in,
arg.in_element_op_( ck::type_convert<float>(arg.input_(n, c, wi)));
v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
v_acc += v_out * v_in; v_acc += v_out * v_in;
} }
} }
} }
} float v_wei;
float v_wei;
arg.wei_element_op_(v_wei, v_acc); arg.wei_element_op_(v_wei, v_acc);
arg.wei_k_c_y_x_(k, c, y, x) = ck::type_convert<OutDataType>(v_wei); arg.weight_(k, c, x) = ck::type_convert<WeiDataType>(v_wei);
}; };
make_ParallelTensorFunctor(f_kcyx, make_ParallelTensorFunctor(f_kcx,
arg.wei_k_c_y_x_.mDesc.GetLengths()[0], arg.weight_.mDesc.GetLengths()[0],
arg.wei_k_c_y_x_.mDesc.GetLengths()[1], arg.weight_.mDesc.GetLengths()[1],
arg.wei_k_c_y_x_.mDesc.GetLengths()[2], arg.weight_.mDesc.GetLengths()[2])(
arg.wei_k_c_y_x_.mDesc.GetLengths()[3])( std::thread::hardware_concurrency());
std::thread::hardware_concurrency());
return 0; return 0;
}
else if constexpr(NumDimSpatial == 2)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
{
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(y * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(x *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[2] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[3])
{
float v_out;
float v_in;
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(n, k, ho, wo)));
arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kcyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NumDimSpatial == 3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n)
{
for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_)
{
auto di =
ck::type_convert<ck::long_index_t>(do_ * arg.conv_strides_[I0]) +
ck::type_convert<ck::long_index_t>(z * arg.conv_dilations_[I0]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I0]);
for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho)
{
auto hi =
ck::type_convert<ck::long_index_t>(ho * arg.conv_strides_[I1]) +
ck::type_convert<ck::long_index_t>(y *
arg.conv_dilations_[I1]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I1]);
for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4];
++wo)
{
auto wi =
ck::type_convert<ck::long_index_t>(wo *
arg.conv_strides_[I2]) +
ck::type_convert<ck::long_index_t>(
x * arg.conv_dilations_[I2]) -
ck::type_convert<ck::long_index_t>(arg.in_left_pads_[I2]);
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.mDesc.GetLengths()[2] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.mDesc.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.mDesc.GetLengths()[4])
{
float v_out;
float v_in;
arg.out_element_op_(v_out,
ck::type_convert<float>(
arg.output_(n, k, do_, ho, wo)));
arg.in_element_op_(
v_in,
ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
v_acc += v_out * v_in;
}
}
}
}
}
float v_wei;
arg.wei_element_op_(v_wei, v_acc);
arg.weight_(k, c, z, y, x) = ck::type_convert<WeiDataType>(v_wei);
};
make_ParallelTensorFunctor(f_kczyx,
arg.weight_.mDesc.GetLengths()[0],
arg.weight_.mDesc.GetLengths()[1],
arg.weight_.mDesc.GetLengths()[2],
arg.weight_.mDesc.GetLengths()[3],
arg.weight_.mDesc.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
} }
float Run(const device::BaseArgument* p_arg, float Run(const device::BaseArgument* p_arg,
...@@ -182,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -182,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
} // namespace host } // namespace host
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -11,6 +11,7 @@ namespace host { ...@@ -11,6 +11,7 @@ namespace host {
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
...@@ -53,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -53,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator
auto f_mk_kn_mn = [&](auto m, auto n) { auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1]; const int K = arg.a_m_k_.mDesc.GetLengths()[1];
float v_acc = 0; AccDataType v_acc = 0;
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
float v_a; AccDataType v_a;
float v_b; AccDataType v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k))); arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n))); arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
float v_c; AccDataType v_c;
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
......
...@@ -9,26 +9,11 @@ ...@@ -9,26 +9,11 @@
#include "device_reduce_instance_blockwise_i8_i8_i8.hpp" #include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_i8_i32_i8.hpp" #include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
#include "device_reduce_instance_blockwise_b16_f32_b16.hpp" #include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp"
#include "device_reduce_instance_threadwise_f16_f16_f16.hpp" #include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
#include "device_reduce_instance_threadwise_f16_f32_f16.hpp" #include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
#include "device_reduce_instance_threadwise_f32_f32_f32.hpp" #include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
......
...@@ -3,13 +3,27 @@ ...@@ -3,13 +3,27 @@
#include "reduction_operator_mapping.hpp" #include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp" #include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_blockwise.hpp" #include "device_reduce_multiblock.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
using reduce_configuration_1_instances_blockwise = std::tuple<
// clang-format off
// BlockSize | MThreadClusterSize | KThreadClusterSize
ReductionConfiguration_1<256, 128, 2>,
ReductionConfiguration_1<256, 64, 4>,
ReductionConfiguration_1<256, 32, 8>,
ReductionConfiguration_1<256, 16, 16>,
ReductionConfiguration_1<256, 8, 32>,
ReductionConfiguration_1<256, 4, 64>,
ReductionConfiguration_1<256, 2, 128>,
ReductionConfiguration_1<256, 1, 256>
// clang-format on
>;
#ifdef QUICK_REDUCE_TEST #ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_blockwise = std::tuple< using reduce_configuration_2_instances_blockwise = std::tuple<
// clang-format off // clang-format off
...@@ -58,8 +72,8 @@ template <typename InDataType, ...@@ -58,8 +72,8 @@ template <typename InDataType,
int Rank, int Rank,
int NumReduceDim, int NumReduceDim,
ReduceTensorOp ReduceOpId, ReduceTensorOp ReduceOpId,
NanPropagation NanOpt, bool PropagateNan,
ReduceTensorIndices IndicesOpt> bool UseIndex>
void add_device_reduce_instance_blockwise( void add_device_reduce_instance_blockwise(
std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances) std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
{ {
...@@ -73,92 +87,94 @@ void add_device_reduce_instance_blockwise( ...@@ -73,92 +87,94 @@ void add_device_reduce_instance_blockwise(
constexpr bool Indexable = constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
ReduceOpId == ReduceTensorOp::AMAX); ReduceOpId == ReduceTensorOp::AMAX);
constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES); constexpr bool OutputIndex = Indexable && UseIndex;
constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true; static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
[&](auto i) {
static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) { using cfg1 = remove_cvref_t<decltype(
using cfg1 = std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>;
remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}( [&](auto j) {
[&](auto j) { using cfg2 = remove_cvref_t<decltype(
using cfg2 = remove_cvref_t<decltype( std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
using ReduceOpInstance =
using ReduceOpInstance = DeviceReduceBlockWise<InDataType, DeviceReduceMultiBlock<InDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
Rank, Rank,
NumReduceDim, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
PropagateNan, InMemoryDataOperationEnum::Set,
NeedIndices, PropagateNan,
cfg1::BlockSize_, OutputIndex,
cfg1::MThreadClusterSize_, false, // HaveIndexInputIfOutputIndex
cfg1::KThreadClusterSize_, cfg1::BlockSize_,
cfg2::MThreadSliceSize_, cfg1::MThreadClusterSize_,
cfg2::KThreadSliceSize_, cfg1::KThreadClusterSize_,
cfg2::InSrcVectorDim_, cfg2::MThreadSliceSize_,
cfg2::InSrcVectorSize_, cfg2::KThreadSliceSize_,
cfg2::OutDstVectorSize_>; cfg2::InSrcVectorDim_,
cfg2::InSrcVectorSize_,
device_op_instances.push_back( cfg2::OutDstVectorSize_>;
std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
}); device_op_instances.push_back(
}); std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
});
});
}; };
#define ADD_BLOCKWISE_INST_BY_TYPE( \ #define ADD_BLOCKWISE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise<inT, \ template void add_device_reduce_instance_blockwise<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
NumReduceDim, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
NanOpt, \ PropagateNan, \
IndicesOpt>( \ UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances) std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID( \ #define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \ ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT, \ compT, \
outT, \ outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \ static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<NanPropagation>(NanOpt), \ static_cast<bool>(NanOpt), \
static_cast<ReduceTensorIndices>(IndicesOpt), \ static_cast<bool>(IndicesOpt), \
Rank, \ Rank, \
NumReduceDim) NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ #define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \ extern template void add_device_reduce_instance_blockwise<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
NumReduceDim, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
NanOpt, \ PropagateNan, \
IndicesOpt>( \ UseIndex>( \
std::vector<DeviceReducePtr< \ std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \ typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \ typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \ AccElementwiseOperation>> & \
device_op_instances) device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID( \ #define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \ ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT, \ compT, \
outT, \ outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \ static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<NanPropagation>(NanOpt), \ static_cast<bool>(NanOpt), \
static_cast<ReduceTensorIndices>(IndicesOpt), \ static_cast<bool>(IndicesOpt), \
Rank, \ Rank, \
NumReduceDim) NumReduceDim)
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#include "reduction_enums.hpp" #include "data_type.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
#include "reduction_enums.hpp" #include "data_type.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
#include "reduction_enums.hpp" #include "data_type.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment