Commit faaa6637 authored by Astha Rai

functioning version with scalar operator

parent d27c06a7
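This commit replaces the hard-coded multipliers in the elementwise-permute examples (a literal `1 *` / `2 *` in the host reference and a `ScalarMult` template parameter in the F32 example) with a runtime `float scale` passed through the `Scale` element-wise operator: `Scale{scale}` is handed to `MakeArgumentPointer`, the host reference `host_elementwise4D` takes `scale` as an extra argument, and the gridwise kernel applies the scale functor per element. The net per-element effect is B(n, h, w, c) = scale * A(n, c, h, w)^2. A minimal host-side sketch of that computation, using plain std::vector with illustrative sizes in place of CK's Tensor class (names and sizes here are made up for illustration):

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const std::size_t N = 2, C = 3, H = 4, W = 5;
    const float scale = 2.f; // runtime scalar, mirrors `float scale = 2.f;` in the example

    std::vector<float> a(N * C * H * W, 0.5f); // NCHW input
    std::vector<float> b(N * H * W * C, 0.f);  // NHWC output

    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                {
                    const float x  = a[((n * C + c) * H + h) * W + w]; // A(n, c, h, w)
                    const float sq = x * x;                            // UnarySquare
                    b[((n * H + h) * W + w) * C + c] = scale * sq;     // Scale, stored as B(n, h, w, c)
                }

    std::printf("b[0] = %f\n", b[0]); // 2 * 0.5^2 = 0.5
    return 0;
}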
......@@ -20,13 +20,12 @@ using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
using Scale = ck::tensor_operation::element_wise::Scale;
//float scale = 1.f;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // ElementwiseOp
UnaryOp, // UnaryOp
Scale, // Scalar
4, // NumDim
8, // MPerThread
ck::Sequence<8>, // InScalarPerVectorSeq
......@@ -36,7 +35,8 @@ template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
FunctorA functor_a,
FunctorB functor_b)
FunctorB functor_b,
float scale)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
......@@ -46,7 +46,7 @@ void host_elementwise4D(HostTensorB& B_nhwc,
ADataType tmp_val;
auto a_val = A_nchw(n, c, h, w);
functor_b(tmp_val, a_val);
functor_a(B_nhwc(n, h, w, c), 1 * tmp_val);
functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
}
}
......@@ -59,7 +59,7 @@ int main()
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 1.f;
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
......@@ -83,8 +83,14 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{}, Scale{scale});
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
PassThrough{},
UnaryOp{},
Scale{scale});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
......@@ -116,7 +122,7 @@ int main()
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{});
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
......@@ -19,16 +19,15 @@ using BDataType = F32;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
// ck::index_t scalar_mult = 2;
using Scale = ck::tensor_operation::element_wise::Scale;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // ElementwiseOp
UnaryOp, // UnaryOp
Scale, // Scalar
4, // NumDim
8, // MPerThread
2, // ScalarMult (alpha)
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
......@@ -36,7 +35,8 @@ template <typename HostTensorA, typename HostTensorB, typename FunctorA, typename FunctorB>
void host_elementwise4D(HostTensorB& B_nhwc,
const HostTensorA& A_nchw,
FunctorA functor_a,
FunctorB functor_b)
FunctorB functor_b,
float scale)
{
for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n)
for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c)
......@@ -46,7 +46,7 @@ void host_elementwise4D(HostTensorB& B_nhwc,
ADataType tmp_val;
auto a_val = A_nchw(n, c, h, w);
functor_b(tmp_val, a_val);
functor_a(B_nhwc(n, h, w, c), 2 * tmp_val);
functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
}
}
......@@ -59,7 +59,7 @@ int main()
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 2.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
......@@ -83,8 +83,14 @@ int main()
ck::ranges::copy(nchw, ab_lengths.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{});
auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths,
{a_strides},
{b_strides},
input,
output,
PassThrough{},
UnaryOp{},
Scale{scale});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
......@@ -116,7 +122,7 @@ int main()
{
b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{});
host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale);
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
......@@ -225,6 +225,12 @@ struct Scale
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = scale_ * x;
};
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
......
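The hunk above adds a half_t specialization to the Scale element-wise operator so the F16 example can use it; like the existing float specialization, it multiplies the input by the stored scale_ member. A simplified, host-only sketch of the same pattern, using plain overloads (and a made-up ScaleSketch name) rather than CK's in-class template specializations:

#include <cstdio>

// Hypothetical analogue of the Scale operator: store a float multiplier and
// compute y = scale_ * x for each supported element type.
struct ScaleSketch
{
    explicit ScaleSketch(float scale) : scale_(scale) {}

    void operator()(float& y, const float& x) const { y = scale_ * x; }
    void operator()(double& y, const double& x) const { y = scale_ * x; }

    float scale_;
};

int main()
{
    ScaleSketch op{2.f};
    float y = 0.f;
    op(y, 3.f); // y becomes 6.f
    std::printf("%f\n", y);
    return 0;
}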
......@@ -22,14 +22,14 @@ template <typename GridwiseElementwise1dFunctor,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op,
const Scale scale_op)
{
GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple,
......@@ -37,7 +37,7 @@ __global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
p_out_global_tuple,
elementwise_op,
unary_op,
scale_op);
}
template <typename InGrid1dDescTuple,
......@@ -46,7 +46,7 @@ template <typename InGrid1dDescTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
......@@ -74,7 +74,7 @@ struct GridwiseElementwise_1D
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op,
const Scale scale_op)
{
const index_t thread_global_id = get_thread_global_1d_id();
......@@ -163,8 +163,8 @@ struct GridwiseElementwise_1D
},
Number<NumOutput>{});
// const auto& scalar = ScalarMult;
index_t num_iter = M / (loop_step);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
......@@ -178,7 +178,6 @@ struct GridwiseElementwise_1D
loop_step_index);
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
// get reference to in data
auto uop_data_refs = generate_tie(
......@@ -194,17 +193,17 @@ struct GridwiseElementwise_1D
unpack2(unary_op, uop_data_refs, uop_data_refs);
auto sop_in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
auto sop_in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
auto sop_out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
const auto in_data_refs = generate_tie(
// return type should be lvalue
......
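In the gridwise kernel, the scale functor is applied in place to each element of the per-thread buffer after the unary op, via the sop_in/sop_out reference tuples and unpack2. A standalone sketch of that per-element step, with a std::array and an explicit loop standing in for CK's StaticBuffer, static_for, and unpack2 (all names here are illustrative):

#include <array>
#include <cstdio>

// Hypothetical stand-in for the per-thread step in GridwiseElementwise_1D: once the
// unary op has filled the thread buffer, the scale functor is applied element by
// element, writing back into the same buffer (the in/out references alias, as the
// sop_in/sop_out tuples do in the diff above).
template <typename ScaleOp, std::size_t MPerThread>
void apply_scale_per_thread(std::array<float, MPerThread>& thread_buf, const ScaleOp& scale_op)
{
    for(std::size_t iM = 0; iM < MPerThread; ++iM)
        scale_op(thread_buf[iM], thread_buf[iM]); // y and x reference the same slot
}

int main()
{
    auto scale_op = [](float& y, const float& x) { y = 2.f * x; }; // stands in for Scale{2.f}
    std::array<float, 8> buf{1, 2, 3, 4, 5, 6, 7, 8};              // MPerThread = 8
    apply_scale_per_thread(buf, scale_op);
    std::printf("%f %f\n", buf[0], buf[7]); // 2 and 16
    return 0;
}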