Commit faaa6637 authored by Astha Rai

functioning version with scalar operator

parent d27c06a7
@@ -20,7 +20,6 @@ using BDataType = F16;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
 using Scale = ck::tensor_operation::element_wise::Scale;
-// float scale = 1.f;
 using DeviceElementwisePermuteInstance =
     ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
                                                         ck::Tuple<BDataType>, // OutDataTypeTuple
@@ -36,7 +35,8 @@ template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
 void host_elementwise4D(HostTensorB& B_nhwc,
                         const HostTensorA& A_nchw,
                         FunctorA functor_a,
-                        FunctorB functor_b)
+                        FunctorB functor_b,
+                        float scale)
 {
     for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
         for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
@@ -46,7 +46,7 @@ void host_elementwise4D(HostTensorB& B_nhwc,
                 ADataType tmp_val;
                 auto a_val = A_nchw(n, c, h, w);
                 functor_b(tmp_val, a_val);
-                functor_a(B_nhwc(n, h, w, c), 1 * tmp_val);
+                functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
             }
 }
@@ -59,7 +59,7 @@ int main()
     std::vector<std::size_t> nhwc = {16, 32, 64, 128};
     Tensor<ADataType> a(nchw);
     Tensor<BDataType> b(nhwc);
-    float scale = 1.f;
+    float scale = 2.f;
     a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

     DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
@@ -83,8 +83,14 @@ int main()
     ck::ranges::copy(nchw, ab_lengths.begin());

     auto broadcastPermute = DeviceElementwisePermuteInstance{};
-    auto argument = broadcastPermute.MakeArgumentPointer(
-        ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{}, Scale{scale});
+    auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
+                                                         {a_strides},
+                                                         {b_strides},
+                                                         input,
+                                                         output,
+                                                         PassThrough{},
+                                                         UnaryOp{},
+                                                         Scale{scale});

     if(!broadcastPermute.IsSupportedArgument(argument.get()))
     {
@@ -116,7 +122,7 @@ int main()
     {
         b_device_buf.FromDevice(b.mData.data());
         Tensor<BDataType> host_b(nhwc);
-        host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{});
+        host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);

         pass &=
             ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
...
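Note on the F16 example above: the host reference now takes the scale as a runtime argument instead of a hard-coded 1. Below is a minimal standalone sketch (plain C++, not CK code; flat vectors stand in for Tensor, and the shapes are arbitrary) of the value the device result is checked against, B[n][h][w][c] = scale * square(A[n][c][h][w]):

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    const std::size_t N = 2, C = 3, H = 4, W = 5;
    const float scale = 2.f; // runtime value, matching the diff above

    // A stored NCHW, B stored NHWC, both flattened row-major.
    std::vector<float> A(N * C * H * W), B(N * H * W * C);
    for(std::size_t i = 0; i < A.size(); ++i)
        A[i] = static_cast<float>(i % 7) * 0.125f;

    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                {
                    const float a  = A[((n * C + c) * H + h) * W + w]; // NCHW read
                    const float sq = a * a;                            // UnarySquare
                    B[((n * H + h) * W + w) * C + c] = scale * sq;     // Scale, permuted to NHWC
                }

    std::cout << "B[0] = " << B[0] << "\n";
}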
@@ -19,16 +19,15 @@ using BDataType = F32;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
-// ck::index_t scalar_mult = 2;
+using Scale = ck::tensor_operation::element_wise::Scale;
 using DeviceElementwisePermuteInstance =
     ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
                                                         ck::Tuple<BDataType>, // OutDataTypeTuple
                                                         PassThrough,      // ElementwiseOp
                                                         UnaryOp,          // UnaryOp
+                                                        Scale,            // Scalar
                                                         4,                // NumDim
                                                         8,                // MPerThread
-                                                        2,                // ScalarMult (alpha)
                                                         ck::Sequence<8>,  // InScalarPerVectorSeq
                                                         ck::Sequence<1>>; // OutScalarPerVectorSeq
@@ -36,7 +35,8 @@ template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
 void host_elementwise4D(HostTensorB& B_nhwc,
                         const HostTensorA& A_nchw,
                         FunctorA functor_a,
-                        FunctorB functor_b)
+                        FunctorB functor_b,
+                        float scale)
 {
     for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
         for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
@@ -46,7 +46,7 @@ void host_elementwise4D(HostTensorB& B_nhwc,
                 ADataType tmp_val;
                 auto a_val = A_nchw(n, c, h, w);
                 functor_b(tmp_val, a_val);
-                functor_a(B_nhwc(n, h, w, c), 2 * tmp_val);
+                functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
             }
 }
@@ -59,7 +59,7 @@ int main()
     std::vector<std::size_t> nhwc = {16, 32, 64, 128};
     Tensor<ADataType> a(nchw);
     Tensor<BDataType> b(nhwc);
+    float scale = 2.f;
     a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

     DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
@@ -83,8 +83,14 @@ int main()
     ck::ranges::copy(nchw, ab_lengths.begin());

     auto broadcastPermute = DeviceElementwisePermuteInstance{};
-    auto argument = broadcastPermute.MakeArgumentPointer(
-        ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{});
+    auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
+                                                         {a_strides},
+                                                         {b_strides},
+                                                         input,
+                                                         output,
+                                                         PassThrough{},
+                                                         UnaryOp{},
+                                                         Scale{scale});

     if(!broadcastPermute.IsSupportedArgument(argument.get()))
     {
@@ -116,7 +122,7 @@ int main()
     {
         b_device_buf.FromDevice(b.mData.data());
         Tensor<BDataType> host_b(nhwc);
-        host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{});
+        host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);

         pass &=
             ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
...
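This F32 example carries the main design change of the commit: the multiplier used to be baked into the kernel instance as a template argument (2, // ScalarMult (alpha)), so each multiplier value required its own instantiation; it is now the Scale functor type in the instance plus a runtime Scale{scale} passed to MakeArgumentPointer. A small sketch of the difference between the two styles (names hypothetical, not the CK API):

#include <iostream>

template <int ScalarMult> // "before": one compiled instance per multiplier
float apply_fixed(float x) { return ScalarMult * x; }

struct Scale // "after": one instance, multiplier chosen at runtime
{
    explicit Scale(float s) : scale_(s) {}
    void operator()(float& y, const float& x) const { y = scale_ * x; }
    float scale_;
};

int main()
{
    float y = 0.f;
    std::cout << apply_fixed<2>(3.f) << "\n"; // 6: multiplier frozen into the type
    Scale{2.f}(y, 3.f);                       // 6: multiplier is ordinary data
    std::cout << y << "\n";
}

With the functor form, one compiled instance serves every scale value, which is presumably why the ScalarMult template parameter was dropped.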
@@ -225,6 +225,12 @@ struct Scale
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;

+    template <>
+    __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        y = scale_ * x;
+    };
+
     template <>
     __host__ __device__ void operator()<float, float>(float& y, const float& x) const
     {
...
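The hunk above adds a half_t specialization to Scale so the F16 example can use the functor; scale_ itself stays a float, so the multiply relies on the usual float/half arithmetic conversions. A standalone sketch of the same pattern (not the CK header; assumes a compiler with _Float16 support, e.g. recent Clang or GCC, where ck::half_t plays the analogous role; specializations are written at namespace scope, the strictly conforming spelling):

#include <iostream>

struct Scale
{
    explicit Scale(float s) : scale_(s) {}
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const;
    float scale_;
};

template <>
void Scale::operator()<_Float16, _Float16>(_Float16& y, const _Float16& x) const
{
    // float * _Float16 promotes x to float; the result narrows on assignment to y
    y = scale_ * x;
}

template <>
void Scale::operator()<float, float>(float& y, const float& x) const
{
    y = scale_ * x;
}

int main()
{
    _Float16 yh = 0, xh = 3;
    float yf = 0.f;
    Scale s{2.f};
    s(yh, xh);  // 6 in half precision
    s(yf, 3.f); // 6 in single precision
    std::cout << static_cast<float>(yh) << " " << yf << "\n";
}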
@@ -163,7 +163,7 @@ struct GridwiseElementwise_1D
             },
             Number<NumOutput>{});

-        //const auto& scalar = ScalarMult;
+        // const auto& scalar = ScalarMult;
         index_t num_iter = M / (loop_step);
         do
         {
@@ -178,7 +178,6 @@ struct GridwiseElementwise_1D
                     loop_step_index);
             });

-
         static_for<0, MPerThread, 1>{}([&](auto iM) {
             // get reference to in data
             auto uop_data_refs = generate_tie(
@@ -196,7 +195,7 @@ struct GridwiseElementwise_1D
             auto sop_in_data_refs = generate_tie(
                 // return type should be lvalue
-                [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
+                [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
                 Number<NumInput>{});

             auto sop_out_data_refs = generate_tie(
...
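On the last hunk: the in-place comment already says the return type should be an lvalue, and a plausible reading of the fix is that returning const auto& made the tied elements read-only, so any operator that writes back through those references would fail to compile or bind. A small sketch isolating the reference-binding issue (generate_tie and the thread buffers are CK-specific; plain lambdas and an array stand in here):

#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    std::array<float, 4> buf{1.f, 2.f, 3.f, 4.f};

    auto mutable_ref = [&](std::size_t i) -> auto& { return buf[i]; };
    // auto const_ref = [&](std::size_t i) -> const auto& { return buf[i]; };

    // An in-place scalar op in the style of Scale: writes through its first argument.
    auto scale_in_place = [](float& y, const float& x) { y = 2.f * x; };

    scale_in_place(mutable_ref(1), mutable_ref(1)); // OK: buf[1] becomes 4
    // scale_in_place(const_ref(1), const_ref(1));  // would not compile:
    //   a float& parameter cannot bind to a const float& result

    std::cout << buf[1] << "\n"; // 4
}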