Commit c6b98c98 authored by Astha Rai

Added fp16 type check in UnarySquare

parent aa61ccf0
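Why the change is needed: the fp16 example below instantiates the element-wise pipeline with UnarySquare on F16 data, but the operator's static_assert previously accepted only float, double, int32_t, int8_t (and optionally int4_t), so an fp16 instantiation failed to compile. This commit adds half_t to that list. A minimal illustrative sketch of what now compiles; the header paths and the helper name square_half_example are assumptions for illustration, not part of this diff:

#include "ck/utility/data_type.hpp"                                          // ck::half_t (path assumed)
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"  // UnarySquare (path assumed)

void square_half_example()
{
    ck::half_t x = static_cast<ck::half_t>(3.0f);
    ck::half_t y;
    // With this commit, T = half_t passes the static_assert; y becomes x * x.
    ck::tensor_operation::element_wise::UnarySquare{}(y, x);
}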
Changes to the fp16 elementwise permute example (file name not shown in this excerpt):

@@ -3,7 +3,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl_ht.hpp"
 #include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -18,25 +18,35 @@ using ADataType = F16;
 using BDataType = F16;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
+// ck::index_t scalar_mult = 2;

 using DeviceElementwisePermuteInstance =
-    ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>,
-                                                        ck::Tuple<BDataType>,
-                                                        PassThrough,
-                                                        4,
-                                                        8,
-                                                        ck::Sequence<8>,
-                                                        ck::Sequence<1>>;
+    ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
+                                                        ck::Tuple<BDataType>, // OutDataTypeTuple
+                                                        PassThrough, // ElementwiseOp
+                                                        UnaryOp, // UnaryOp
+                                                        4, // NumDim
+                                                        8, // MPerThread
+                                                        2, // ScalarMult (alpha)
+                                                        ck::Sequence<8>, // InScalarPerVectorSeq
+                                                        ck::Sequence<1>>; // OutScalarPerVectorSeq

-template <typename HostTensorA, typename HostTensorB, typename Functor>
-void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor)
+template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
+void host_elementwise4D(HostTensorB& B_nhwc,
+                        const HostTensorA& A_nchw,
+                        FunctorA functor_a,
+                        FunctorB functor_b)
 {
     for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
         for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
             for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
                 for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
                 {
+                    ADataType tmp_val;
                     auto a_val = A_nchw(n, c, h, w);
-                    functor(B_nhwc(n, h, w, c), a_val);
+                    functor_b(tmp_val, a_val);
+                    functor_a(B_nhwc(n, h, w, c), 2 * tmp_val);
                 }
 }
@@ -74,7 +84,7 @@ int main()
     auto broadcastPermute = DeviceElementwisePermuteInstance{};

     auto argument = broadcastPermute.MakeArgumentPointer(
-        ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
+        ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{});

     if(!broadcastPermute.IsSupportedArgument(argument.get()))
     {
@@ -106,7 +116,7 @@ int main()
     {
         b_device_buf.FromDevice(b.mData.data());
         Tensor<BDataType> host_b(nhwc);
-        host_elementwise4D(host_b, a, PassThrough{});
+        host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{});

         pass &=
             ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
...
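Net effect of the updated host reference in the fp16 example: the unary op squares each input, the result is scaled by 2 (matching the ScalarMult template argument), and the value is written in NHWC order. A compact restatement, assuming PassThrough copies its argument and UnarySquare writes x * x; the helper name host_square_scale_permute is hypothetical:

// Illustrative scalar form of host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}) above:
// B[n][h][w][c] = 2 * A[n][c][h][w]^2
template <typename HostTensorA, typename HostTensorB>
void host_square_scale_permute(HostTensorB& B_nhwc, const HostTensorA& A_nchw)
{
    for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
        for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
            for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h)
                for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w)
                {
                    auto a_val         = A_nchw(n, c, h, w);
                    B_nhwc(n, h, w, c) = 2 * (a_val * a_val); // square, scale by 2, NCHW -> NHWC
                }
}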
Changes to the fp32 elementwise permute example (file name not shown in this excerpt):

@@ -18,19 +18,19 @@ using ADataType = F32;
 using BDataType = F32;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using Square = ck::tensor_operation::element_wise::UnarySquare;
+using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
 // ck::index_t scalar_mult = 2;

 using DeviceElementwisePermuteInstance =
-    ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>,
-                                                        ck::Tuple<BDataType>,
-                                                        PassThrough,
-                                                        Square,
-                                                        4,
-                                                        8,
-                                                        2,
-                                                        ck::Sequence<8>,
-                                                        ck::Sequence<1>>;
+    ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
+                                                        ck::Tuple<BDataType>, // OutDataTypeTuple
+                                                        PassThrough, // ElementwiseOp
+                                                        UnaryOp, // UnaryOp
+                                                        4, // NumDim
+                                                        8, // MPerThread
+                                                        2, // ScalarMult (alpha)
+                                                        ck::Sequence<8>, // InScalarPerVectorSeq
+                                                        ck::Sequence<1>>; // OutScalarPerVectorSeq

 template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
 void host_elementwise4D(HostTensorB& B_nhwc,
@@ -84,7 +84,7 @@ int main()
     auto broadcastPermute = DeviceElementwisePermuteInstance{};

     auto argument = broadcastPermute.MakeArgumentPointer(
-        ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, Square{});
+        ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{});

     if(!broadcastPermute.IsSupportedArgument(argument.get()))
     {
@@ -116,7 +116,7 @@ int main()
     {
         b_device_buf.FromDevice(b.mData.data());
         Tensor<BDataType> host_b(nhwc);
-        host_elementwise4D(host_b, a, PassThrough{}, Square{});
+        host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{});

         pass &=
             ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
...
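The fp32 example only renames its alias from Square to UnaryOp; the launch sequence around MakeArgumentPointer is otherwise unchanged. For orientation, a hedged sketch of how these examples typically drive the instance; the wrapper name run_square_permute is hypothetical, and the MakeInvokerPointer/StreamConfig usage is assumed from the unchanged example code, which this diff does not show:

#include <stdexcept>

// Hypothetical helper wrapping the launch pattern used by the examples above.
// Assumes DeviceElementwisePermuteInstance, PassThrough and UnaryOp are the aliases defined there.
template <typename Lengths, typename Strides, typename InPtrs, typename OutPtrs>
float run_square_permute(const Lengths& ab_lengths,
                         const Strides& a_strides,
                         const Strides& b_strides,
                         const InPtrs& input,
                         const OutPtrs& output)
{
    auto broadcastPermute = DeviceElementwisePermuteInstance{};
    auto argument         = broadcastPermute.MakeArgumentPointer(
        ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{});

    if(!broadcastPermute.IsSupportedArgument(argument.get()))
        throw std::runtime_error("argument not supported by this device instance");

    auto invoker = broadcastPermute.MakeInvokerPointer();
    // Run returns the measured kernel time when timing is requested (assumed CK invoker API).
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}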
Changes to the UnarySquare element-wise operation (file name not shown in this excerpt):

@@ -278,8 +278,8 @@ struct UnarySquare
     template <typename T>
     __host__ __device__ void operator()(T& y, const T& x) const
     {
-        static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t> ||
-                          is_same_v<T, int8_t>
+        static_assert(is_same_v<T, float> || is_same_v<T, half_t> || is_same_v<T, double> ||
+                          is_same_v<T, int32_t> || is_same_v<T, int8_t>
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                           || is_same_v<T, int4_t>
 #endif
...
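Putting the last hunk in context, UnarySquare after this commit reads roughly as follows. This is a sketch: the assert message and the y = x * x body belong to the unchanged parts of the operator, which this diff does not show, and are assumed here:

struct UnarySquare
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        // half_t is newly accepted; the other types were already allowed.
        static_assert(is_same_v<T, float> || is_same_v<T, half_t> || is_same_v<T, double> ||
                          is_same_v<T, int32_t> || is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                          || is_same_v<T, int4_t>
#endif
                      ,
                      "Data type is not supported by this operation!"); // message assumed
        y = x * x;
    }
};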