Commit d27c06a7 authored by Astha Rai

updating scalar multiplication as an operator

parent c6b98c98
@@ -19,16 +19,16 @@ using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
// ck::index_t scalar_mult = 2;
using Scale = ck::tensor_operation::element_wise::Scale;
//float scale = 1.f;
using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // ElementwiseOp
UnaryOp, // UnaryOp
Scale, // Scalar
4, // NumDim
8, // MPerThread
2, // ScalarMult (alpha)
ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq
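Note the shape of this change: the multiplier is no longer the compile-time ScalarMult template parameter (the 2 removed above) but an ordinary element-wise functor, Scale, constructed with a runtime factor. A minimal sketch of what such a functor looks like, written with an illustrative name so it is not mistaken for the CK source:

```cpp
// Illustrative stand-in for ck::tensor_operation::element_wise::Scale (not the
// actual CK implementation): holds a runtime factor and applies y = scale * x.
struct ScaleSketch
{
    explicit ScaleSketch(float scale = 1.f) : scale_(scale) {}

    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = static_cast<Y>(scale_ * static_cast<float>(x));
    }

    float scale_; // runtime multiplier, replacing the old compile-time ScalarMult
};
```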
@@ -46,7 +46,7 @@ void host_elementwise4D(HostTensorB& B_nhwc,
ADataType tmp_val;
auto a_val = A_nchw(n, c, h, w);
functor_b(tmp_val, a_val);
functor_a(B_nhwc(n, h, w, c), 2 * tmp_val);
functor_a(B_nhwc(n, h, w, c), 1 * tmp_val);
}
}
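With the hard-coded factor gone, the verification helper above matches the device result only when scale == 1.f, which is why 2 * tmp_val becomes 1 * tmp_val. A self-contained sketch of the same check with the scale passed in explicitly; the names, plain float types, and flat index math are illustrative, not this example's actual helpers:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical, simplified host reference: square each input element, then apply
// the same runtime scale the device op receives via Scale{scale}, so CPU and GPU
// results agree for any scale value.
void host_square_then_scale_nchw_to_nhwc(std::vector<float>& b_nhwc,
                                         const std::vector<float>& a_nchw,
                                         std::size_t N, std::size_t C,
                                         std::size_t H, std::size_t W,
                                         float scale)
{
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                {
                    const float x   = a_nchw[((n * C + c) * H + h) * W + w];
                    const float tmp = x * x;                              // UnarySquare
                    b_nhwc[((n * H + h) * W + w) * C + c] = scale * tmp;  // Scale
                }
}
```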
@@ -59,7 +59,7 @@ int main()
std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 1.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
@@ -84,7 +84,7 @@ int main()
auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{});
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{}, Scale{scale});
if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
......
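The remainder of the example (collapsed above) follows the usual CK example pattern. A sketch of that flow with the new Scale{scale} operand; the invoker call and StreamConfig usage are assumptions based on other CK examples, not this file's exact code:

```cpp
auto broadcastPermute = DeviceElementwisePermuteInstance{};

// Build the argument: lengths/strides, device buffers, and now three operators.
auto argument = broadcastPermute.MakeArgumentPointer(
    ab_lengths, {a_strides}, {b_strides}, input, output,
    PassThrough{}, UnaryOp{}, Scale{scale});

if(!broadcastPermute.IsSupportedArgument(argument.get()))
{
    throw std::runtime_error("This problem is not supported by the device op");
}

auto invoker   = broadcastPermute.MakeInvokerPointer();
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, true});
```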
@@ -17,6 +17,7 @@ template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t NumDim>
struct DeviceElementwise : public BaseOperator
{
@@ -30,7 +31,8 @@ struct DeviceElementwise : public BaseOperator
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op,
UnaryOperation unary_op) = 0;
UnaryOperation unary_op,
Scale scale_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; // namespace device
@@ -39,11 +41,13 @@ template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t NumDim>
using DeviceElementwisePtr = std::unique_ptr<DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
UnaryOperation,
Scale,
NumDim>>;
} // namespace device
......
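Because Scale is now part of the interface's template parameter list, anything that names this operator type through DeviceElementwisePtr has to spell the extra parameter between UnaryOperation and NumDim. A hypothetical alias matching the example's types:

```cpp
// Hypothetical alias for the example's configuration; the added Scale parameter
// sits between the unary op and the rank, mirroring the updated declaration.
using ScaledPermutePtr =
    ck::tensor_operation::device::DeviceElementwisePtr<ck::Tuple<ADataType>, // inputs
                                                       ck::Tuple<BDataType>, // outputs
                                                       PassThrough,          // elementwise op
                                                       UnaryOp,              // unary op
                                                       Scale,                // scale op (new)
                                                       4>;                   // NumDim
```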
@@ -23,15 +23,16 @@ template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t NumDim,
index_t MPerThread,
index_t ScalarMult,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
UnaryOperation,
Scale,
NumDim>
{
static constexpr int NumInput = InDataTypeTuple::Size();
@@ -134,8 +135,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
UnaryOperation,
Scale,
MPerThread,
ScalarMult,
InScalarPerVectorSeq,
OutScalarPerVectorSeq>;
@@ -147,13 +148,15 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op,
UnaryOperation unary_op)
UnaryOperation unary_op,
Scale scale_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op),
unary_op_(unary_op),
scale_op_(scale_op),
blockSize_(256)
{
in_dev_buffers_ = generate_tuple(
@@ -180,6 +183,7 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
ElementwiseOperation elementwise_op_;
UnaryOperation unary_op_;
Scale scale_op_;
index_t blockSize_;
};
@@ -209,7 +213,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
ElementwiseOperation,
UnaryOperation>;
UnaryOperation,
Scale>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
@@ -221,7 +226,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
arg.elementwise_op_,
arg.unary_op_);
arg.unary_op_,
arg.scale_op_);
return elapsed_time;
}
@@ -278,7 +284,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op,
UnaryOperation unary_op)
UnaryOperation unary_op,
Scale scale_op)
{
return Argument{lengths,
inStridesArray,
@@ -286,7 +293,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
in_dev_buffers,
out_dev_buffers,
elementwise_op,
unary_op};
unary_op,
scale_op};
}
std::unique_ptr<BaseArgument>
@@ -296,7 +304,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op,
UnaryOperation unary_op) override
UnaryOperation unary_op,
Scale scale_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
@@ -304,7 +313,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
in_dev_buffers,
out_dev_buffers,
elementwise_op,
unary_op);
unary_op,
scale_op);
}
static auto MakeInvoker() { return Invoker{}; }
......
@@ -21,20 +21,23 @@ template <typename GridwiseElementwise1dFunctor,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation>
typename UnaryOperation,
typename Scale>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op)
const UnaryOperation unary_op,
const Scale scale_op)
{
GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
unary_op);
unary_op,
scale_op);
}
template <typename InGrid1dDescTuple,
@@ -43,8 +46,8 @@ template <typename InGrid1dDescTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
typename UnaryOperation,
typename Scale,
index_t MPerThread,
index_t ScalarMult,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D
@@ -70,7 +73,8 @@ struct GridwiseElementwise_1D
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op)
const UnaryOperation unary_op,
const Scale scale_op)
{
const index_t thread_global_id = get_thread_global_1d_id();
@@ -83,15 +87,6 @@ struct GridwiseElementwise_1D
},
Number<NumInput>{});
auto tmp_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
@@ -168,7 +163,7 @@ struct GridwiseElementwise_1D
},
Number<NumOutput>{});
const auto& scalar = ScalarMult;
//const auto& scalar = ScalarMult;
index_t num_iter = M / (loop_step);
do
{
@@ -183,17 +178,8 @@ struct GridwiseElementwise_1D
loop_step_index);
});
// static_for<0, MPerThread, 1>{}(
// [&](auto I){
// InDataTypePointerTuple tmp;
// unary_op(in_thread_buf_tuple(I), in_thread_buf_tuple(I));
// in_thread_buf_tuple(I) = tmp;
//});
static_for<0, MPerThread, 1>{}([&](auto iM) {
// tmp_thread_buf_tuple = [&](auto I){ unary_op(in_thread_buf_tuple(I)(iM),
// in_thread_buf_tuple(I)(iM)); }; unary_op(in_thread_buf_tuple(iM),
// in_thread_buf_tuple(iM));
// get reference to in data
auto uop_data_refs = generate_tie(
// return type should be lvalue
@@ -208,13 +194,25 @@ struct GridwiseElementwise_1D
unpack2(unary_op, uop_data_refs, uop_data_refs);
auto sop_in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
auto sop_out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM) *= scalar; },
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
UNUSED(tmp_thread_buf_tuple);
UNUSED(scale_op);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
......
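Per element and per input, the new generate_tie/unpack2 sequence in GridwiseElementwise_1D::Run replaces the old in_thread_buf_tuple(I)(iM) *= scalar with: apply the unary op in place, apply the scale op in place, then run the element-wise op into the output buffer. A tiny stand-alone illustration with plain lambdas standing in for the CK functors (the 2.f factor is arbitrary):

```cpp
#include <cstdio>

int main()
{
    // Stand-ins for the three CK element-wise functors used in this kernel.
    auto unary_op       = [](float& y, const float& x) { y = x * x; };        // UnarySquare
    auto scale_op       = [s = 2.f](float& y, const float& x) { y = s * x; }; // Scale{2.f}
    auto elementwise_op = [](float& y, const float& x) { y = x; };            // PassThrough

    float in = 3.f, out = 0.f;
    unary_op(in, in);        // in  = 9
    scale_op(in, in);        // in  = 18
    elementwise_op(out, in); // out = 18, written to the output thread buffer
    std::printf("%f\n", out);
    return 0;
}
```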