Unverified commit e1a5137e, authored by arai713 and committed by GitHub

Merge branch 'develop' into transpose_5d

parents eb57178d 718065eb
@@ -16,7 +16,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
 // FP16
 void add_device_normalization_rank_2_1_f16_instances(
     std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);
@@ -27,7 +27,7 @@ void add_device_normalization_rank_4_3_f16_instances(
 void add_device_normalization_rank_5_3_f16_instances(
     std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);
 #endif
-#ifdef __fp32__
+#ifdef CK_ENABLE_FP32
 // FP32
 void add_device_normalization_rank_2_1_f32_instances(
     std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
@@ -66,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
                      is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
         {
@@ -84,7 +84,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
             }
         }
 #endif
-#ifdef __fp32__
+#ifdef CK_ENABLE_FP32
         if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
                      is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
         {
...
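The switch from compiler-defined guards (__fp16__) to the CK_ENABLE_* feature macros lets the build configuration, rather than the compiler, decide which data-type instances exist. A minimal usage sketch, assuming the factory header path below and a build configured with CK_ENABLE_FP16; none of this code is part of the diff:

// Hypothetical sketch: query the rank-2 fp16 normalization instances through
// the factory specialization shown above. Returns 0 if fp16 was compiled out.
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" // assumed path

int count_fp16_normalization_instances()
{
    using ck::half_t;
    using Pass     = ck::tensor_operation::element_wise::PassThrough;
    using DeviceOp = ck::tensor_operation::device::
        DeviceNormalization<half_t, half_t, half_t, float, half_t, Pass, 2, 1>;

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    return static_cast<int>(op_ptrs.size());
}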
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 3;
 static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
 static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
 // FP16
 void add_device_pool3d_fwd_ndhwc_f16_instances(
     std::vector<std::unique_ptr<
@@ -37,7 +37,22 @@ void add_device_pool3d_fwd_ndhwc_index_f16_instances(
     std::vector<std::unique_ptr<
         DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
 #endif
-#ifdef __fp32__
+#ifdef CK_ENABLE_BF16
+// BF16
+void add_device_pool3d_fwd_ndhwc_bf16_instances(
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NDHWC, NDHWC, MaxOp, false>>>&);
+void add_device_pool3d_fwd_ndhwc_bf16_instances(
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NDHWC, NDHWC, AvgOp, false>>>&);
+// BF16 - return index
+void add_device_pool3d_fwd_ndhwc_index_bf16_instances(
+    std::vector<std::unique_ptr<
+        DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
+#endif
+#ifdef CK_ENABLE_FP32
 // FP32
 void add_device_pool3d_fwd_ndhwc_f32_instances(
     std::vector<std::unique_ptr<
@@ -84,7 +99,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
         if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
         {
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
             if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
                          is_same_v<IndexDataType, I32>)
             {
@@ -98,9 +113,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
                 }
             }
 #endif
-#ifdef __fp32__
-            if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
-                         is_same_v<IndexDataType, I32>)
+#ifdef CK_ENABLE_BF16
+            else if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
+                              is_same_v<IndexDataType, I32>)
+            {
+                if constexpr(OutputIndex && ReduceOpId == MaxOp)
+                {
+                    add_device_pool3d_fwd_ndhwc_index_bf16_instances(op_ptrs);
+                }
+                else
+                {
+                    add_device_pool3d_fwd_ndhwc_bf16_instances(op_ptrs);
+                }
+            }
+#endif
+#ifdef CK_ENABLE_FP32
+            else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
+                              is_same_v<IndexDataType, I32>)
             {
                 if constexpr(OutputIndex && ReduceOpId == MaxOp)
                 {
...
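Because the data-type branches are now chained with else if, at most one branch survives per template specialization, and the BF16 branch routes max pooling with index output to the index variant. A sketch of requesting those instances, with an assumed header path; not part of the diff:

// Hypothetical sketch: 5-D bf16 max pool that also returns int32 indices.
#include "ck/library/tensor_operation_instance/gpu/pool_fwd.hpp" // assumed path

int count_bf16_maxpool_index_instances()
{
    using namespace ck::tensor_operation::device;
    using NDHWC    = ck::tensor_layout::convolution::NDHWC;
    using DeviceOp = DevicePoolFwd<5, 3, ck::bhalf_t, ck::bhalf_t, int32_t,
                                   NDHWC, NDHWC, ck::ReduceTensorOp::MAX, true>;

    auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    return static_cast<int>(op_ptrs.size()); // empty unless CK_ENABLE_BF16 was set
}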
@@ -11,7 +11,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
-#ifdef __int8__
+#ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
...
@@ -11,7 +11,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
-#ifdef __int8__
+#ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
...
@@ -11,7 +11,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
-#ifdef __int8__
+#ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
...
@@ -11,7 +11,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
-#ifdef __int8__
+#ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
...
@@ -11,7 +11,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
-#ifdef __int8__
+#ifdef CK_ENABLE_INT8
 namespace ck {
 namespace tensor_operation {
 namespace device {
...
@@ -89,13 +89,13 @@ void add_device_reduce_instance_blockwise(
 {
     static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
         [&](auto i) {
-            using cfg1 = remove_cvref_t<decltype(
-                std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>;
+            using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
+                reduce_configuration_1_instances_blockwise{}))>;
             static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
                 [&](auto j) {
-                    using cfg2 = remove_cvref_t<decltype(
-                        std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
+                    using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
+                        reduce_configuration_2_instances_blockwise{}))>;
                     using ReduceOpInstance =
                         DeviceReduceMultiBlock<InDataType,
...
@@ -90,14 +90,14 @@ void add_device_reduce_instance_multiblock_atomic_add(
     static_for<0,
                std::tuple_size<reduce_configuration_1_instances_multiblock_atomic_add>::value,
                1>{}([&](auto i) {
-        using cfg1 = remove_cvref_t<decltype(
-            std::get<i.value>(reduce_configuration_1_instances_multiblock_atomic_add{}))>;
+        using cfg1 = remove_cvref_t<decltype(std::get<i.value>(
+            reduce_configuration_1_instances_multiblock_atomic_add{}))>;
         static_for<0,
                    std::tuple_size<reduce_configuration_2_instances_multiblock_atomic_add>::value,
                    1>{}([&](auto j) {
-            using cfg2 = remove_cvref_t<decltype(
-                std::get<j.value>(reduce_configuration_2_instances_multiblock_atomic_add{}))>;
+            using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
+                reduce_configuration_2_instances_multiblock_atomic_add{}))>;
             using ReduceOpInstance = DeviceReduceMultiBlock<InDataType,
                                                             AccDataType,
...
@@ -77,8 +77,8 @@ void add_device_reduce_instance_threadwise(
     static_for<0, std::tuple_size<reduce_configuration_2_instances_threadwise>::value, 1>{}(
         [&](auto j) {
-            using cfg2 = remove_cvref_t<decltype(
-                std::get<j.value>(reduce_configuration_2_instances_threadwise{}))>;
+            using cfg2 = remove_cvref_t<decltype(std::get<j.value>(
+                reduce_configuration_2_instances_threadwise{}))>;
             using ReduceOpInstance = DeviceReduceThreadWise<InDataType,
                                                             AccDataType,
...
@@ -40,7 +40,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
-#ifdef __fp16__
+#ifdef CK_ENABLE_FP16
         if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<AccDataType, F32> &&
                      std::is_same_v<OutDataType, F16>)
         {
@@ -66,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
             }
         }
 #endif
-#ifdef __fp32__
+#ifdef CK_ENABLE_FP32
         if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
                      std::is_same_v<OutDataType, F32>)
         {
...
@@ -65,7 +65,11 @@ check_err(const Range& out,
     }
     if(!res)
     {
-        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        const float error_percent =
+            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+        std::cerr << "max err: " << max_err;
+        std::cerr << ", number of errors: " << err_count;
+        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
     }
     return res;
 }
@@ -112,7 +116,11 @@ check_err(const Range& out,
     }
     if(!res)
     {
-        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        const float error_percent =
+            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+        std::cerr << "max err: " << max_err;
+        std::cerr << ", number of errors: " << err_count;
+        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
     }
     return res;
 }
@@ -158,7 +166,11 @@ check_err(const Range& out,
     }
     if(!res)
     {
-        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+        const float error_percent =
+            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+        std::cerr << "max err: " << max_err;
+        std::cerr << ", number of errors: " << err_count;
+        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
     }
     return res;
 }
@@ -209,10 +221,108 @@ check_err(const Range& out,
     }
     if(!res)
     {
-        std::cerr << "max err: " << max_err << std::endl;
+        const float error_percent =
+            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+        std::cerr << "max err: " << max_err;
+        std::cerr << ", number of errors: " << err_count;
+        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
     }
     return res;
 }
+
+#if defined CK_ENABLE_FP8
+template <typename Range, typename RefRange>
+std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+                  std::is_same_v<ranges::range_value_t<Range>, f8_t>),
+                 bool>
+check_err(const Range& out,
+          const RefRange& ref,
+          const std::string& msg = "Error: Incorrect results!",
+          double rtol            = 1e-3,
+          double atol            = 1e-3)
+{
+    if(out.size() != ref.size())
+    {
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+                  << std::endl;
+        return false;
+    }
+
+    bool res{true};
+    int err_count  = 0;
+    double err     = 0;
+    double max_err = std::numeric_limits<float>::min();
+    for(std::size_t i = 0; i < ref.size(); ++i)
+    {
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
+        err            = std::abs(o - r);
+        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
+        {
+            max_err = err > max_err ? err : max_err;
+            err_count++;
+            if(err_count < 5)
+            {
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
+            }
+            res = false;
+        }
+    }
+    if(!res)
+    {
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+    }
+    return res;
+}
+#endif
+
+#if defined CK_ENABLE_BF8
+template <typename Range, typename RefRange>
+std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
+                  std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
+                 bool>
+check_err(const Range& out,
+          const RefRange& ref,
+          const std::string& msg = "Error: Incorrect results!",
+          double rtol            = 1e-3,
+          double atol            = 1e-3)
+{
+    if(out.size() != ref.size())
+    {
+        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+                  << std::endl;
+        return false;
+    }
+
+    bool res{true};
+    int err_count  = 0;
+    double err     = 0;
+    double max_err = std::numeric_limits<float>::min();
+    for(std::size_t i = 0; i < ref.size(); ++i)
+    {
+        const double o = type_convert<float>(*std::next(std::begin(out), i));
+        const double r = type_convert<float>(*std::next(std::begin(ref), i));
+        err            = std::abs(o - r);
+        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
+        {
+            max_err = err > max_err ? err : max_err;
+            err_count++;
+            if(err_count < 5)
+            {
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
+            }
+            res = false;
+        }
+    }
+    if(!res)
+    {
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+    }
+    return res;
+}
+#endif
 } // namespace utils
 } // namespace ck
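The new overloads widen each f8_t/bf8_t element to float via type_convert and apply the usual atol + rtol * |ref| test, so callers compare fp8 ranges exactly like the other element types. A usage sketch, not part of the diff, assuming a build with CK_ENABLE_FP8:

// Hypothetical sketch: validate an fp8 result buffer against a reference.
#include <vector>
#include "ck/library/utility/check_err.hpp"

bool validate_fp8(const std::vector<ck::f8_t>& out, const std::vector<ck::f8_t>& ref)
{
    // On failure this overload prints up to four mismatches and the max error.
    return ck::utils::check_err(out, ref, "Error: fp8 mismatch!", 1e-3, 1e-3);
}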
@@ -26,7 +26,9 @@ struct DeviceMem
     void* GetDeviceBuffer() const;
     std::size_t GetBufferSize() const;
     void ToDevice(const void* p) const;
+    void ToDevice(const void* p, const std::size_t cpySize) const;
     void FromDevice(void* p) const;
+    void FromDevice(void* p, const std::size_t cpySize) const;
     void SetZero() const;
     template <typename T>
     void SetValue(T x) const;
...
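The diff only declares the size-bounded copies, which transfer cpySize bytes instead of the whole allocation. A plausible implementation sketch, assuming DeviceMem keeps its allocation in an mpDeviceBuf member and that CK's hip_check_error helper is available (both assumptions, since neither appears in this diff):

// Hypothetical sketch of the new overloads: copy only cpySize bytes,
// e.g. when just a prefix of the allocation is populated.
void DeviceMem::ToDevice(const void* p, const std::size_t cpySize) const
{
    hip_check_error(hipMemcpy(mpDeviceBuf, p, cpySize, hipMemcpyHostToDevice));
}

void DeviceMem::FromDevice(void* p, const std::size_t cpySize) const
{
    hip_check_error(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}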
@@ -102,9 +102,10 @@ struct FillMonotonicSeq
     }
     template <typename ForwardRange>
-    auto operator()(ForwardRange&& range) const -> std::void_t<decltype(
-        std::declval<const FillMonotonicSeq&>()(std::begin(std::forward<ForwardRange>(range)),
-                                                std::end(std::forward<ForwardRange>(range))))>
+    auto operator()(ForwardRange&& range) const
+        -> std::void_t<decltype(std::declval<const FillMonotonicSeq&>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
     {
         (*this)(std::begin(std::forward<ForwardRange>(range)),
                 std::end(std::forward<ForwardRange>(range)));
...
@@ -95,6 +95,22 @@ struct GeneratorTensor_2<int8_t>
     }
 };
+
+#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
+template <>
+struct GeneratorTensor_2<ck::f8_t>
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    ck::f8_t operator()(Is...)
+    {
+        float tmp = (std::rand() % (max_value - min_value)) + min_value;
+        return ck::type_convert<ck::f8_t>(tmp);
+    }
+};
+#endif
 template <typename T>
 struct GeneratorTensor_3
 {
@@ -127,6 +143,25 @@ struct GeneratorTensor_3<ck::bhalf_t>
     }
 };
+
+#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
+template <>
+struct GeneratorTensor_3<ck::f8_t>
+{
+    float min_value = 0;
+    float max_value = 1;
+
+    template <typename... Is>
+    ck::f8_t operator()(Is...)
+    {
+        float tmp      = float(std::rand()) / float(RAND_MAX);
+        float fp32_tmp = min_value + tmp * (max_value - min_value);
+        return ck::type_convert<ck::f8_t>(fp32_tmp);
+    }
+};
+#endif
 template <typename T>
 struct GeneratorTensor_4
 {
...
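GeneratorTensor_2<ck::f8_t> draws integers in [min_value, max_value) and GeneratorTensor_3<ck::f8_t> draws uniform floats in [min_value, max_value), both rounded through type_convert to fp8; note that GeneratorTensor_2's default bounds of 0 and 1 always yield 0, so callers are expected to set the bounds. A usage sketch, not part of the diff, assuming the generators live in the header below:

// Hypothetical sketch: fill a host buffer with random fp8 values in [-1, 1).
#include <algorithm>
#include <vector>
#include "ck/library/utility/host_tensor.hpp" // assumed location of the generators

std::vector<ck::f8_t> make_random_fp8(std::size_t n)
{
    GeneratorTensor_3<ck::f8_t> gen{-1.f, 1.f};
    std::vector<ck::f8_t> buf(n);
    std::generate(buf.begin(), buf.end(), [&] { return gen(); });
    return buf;
}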
set(DEVICE_AVGPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
endif()
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})
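With this pattern each instance file is compiled only when its data type matches the DTYPES list (a regex match against the configured value), and all three are built when DTYPES is left undefined; for example, a configure step that sets DTYPES to only fp16 and bf16 would drop the f32 avgpool-backward instances from the library.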
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using I32 = int32_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using NDHWC = ck::tensor_layout::convolution::NDHWC;
using device_avgpool_bwd_ndhwc_f16_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F16, F16, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
using device_avgpool_bwd_ndhwc_bf16_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<BF16, BF16, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
using device_avgpool_bwd_ndhwc_f32_instances =
// clang-format off
std::tuple <
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool3dBwd_NDHWC_NDHWC<F32, F32, F32, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_bwd_ndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, BF16, BF16, NDHWC, NDHWC>>>& instances)
{
add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_bf16_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_bwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F16, F16, NDHWC, NDHWC>>>& instances)
{
add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f16_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "avg_pool3d_bwd_ndhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_bwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F32, F32, NDHWC, NDHWC>>>& instances)
{
add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f32_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
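Each of the three new translation units registers its tuple of tile configurations through add_device_operation_instances. A sketch of consuming one of them directly, not part of the diff and with an assumed declaration-header path:

// Hypothetical sketch: collect the bf16 avg-pool-3d backward instances.
#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/gpu/avg_pool3d_bwd.hpp" // assumed path

int count_avgpool3d_bwd_bf16_instances()
{
    using namespace ck::tensor_operation::device;
    using NDHWC = ck::tensor_layout::convolution::NDHWC;
    std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, ck::bhalf_t, ck::bhalf_t, NDHWC, NDHWC>>>
        instances;
    instance::add_device_avgpool_bwd_ndhwc_bf16_instances(instances);
    return static_cast<int>(instances.size());
}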