"git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "af436f5951d25e0ed0f6ec66431f39ce6a02c801"
Commit d27c06a7 authored by Astha Rai's avatar Astha Rai
Browse files

updating scalar multiplication as an operator

parent c6b98c98
...@@ -19,16 +19,16 @@ using BDataType = F16; ...@@ -19,16 +19,16 @@ using BDataType = F16;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; using UnaryOp = ck::tensor_operation::element_wise::UnarySquare;
// ck::index_t scalar_mult = 2; using Scale = ck::tensor_operation::element_wise::Scale;
//float scale = 1.f;
using DeviceElementwisePermuteInstance = using DeviceElementwisePermuteInstance =
ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple ck::tensor_operation::device::DeviceElementwiseImpl<ck::Tuple<ADataType>, // InDataTypeTuple
ck::Tuple<BDataType>, // OutDataTypeTuple ck::Tuple<BDataType>, // OutDataTypeTuple
PassThrough, // ElementwiseOp PassThrough, // ElementwiseOp
UnaryOp, // UnaryOp UnaryOp, // UnaryOp
Scale, // Scalar
4, // NumDim 4, // NumDim
8, // MPerThread 8, // MPerThread
2, // ScalarMult (alpha)
ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>, // InScalarPerVectorSeq
ck::Sequence<1>>; // OutScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq
...@@ -46,7 +46,7 @@ void host_elementwise4D(HostTensorB& B_nhwc, ...@@ -46,7 +46,7 @@ void host_elementwise4D(HostTensorB& B_nhwc,
ADataType tmp_val; ADataType tmp_val;
auto a_val = A_nchw(n, c, h, w); auto a_val = A_nchw(n, c, h, w);
functor_b(tmp_val, a_val); functor_b(tmp_val, a_val);
functor_a(B_nhwc(n, h, w, c), 2 * tmp_val); functor_a(B_nhwc(n, h, w, c), 1 * tmp_val);
} }
} }
...@@ -59,7 +59,7 @@ int main() ...@@ -59,7 +59,7 @@ int main()
std::vector<std::size_t> nhwc = {16, 32, 64, 128}; std::vector<std::size_t> nhwc = {16, 32, 64, 128};
Tensor<ADataType> a(nchw); Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc); Tensor<BDataType> b(nhwc);
float scale = 1.f;
a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
...@@ -84,7 +84,7 @@ int main() ...@@ -84,7 +84,7 @@ int main()
auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer( auto argument = broadcastPermute.MakeArgumentPointer(
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{}); ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}, UnaryOp{}, Scale{scale});
if(!broadcastPermute.IsSupportedArgument(argument.get())) if(!broadcastPermute.IsSupportedArgument(argument.get()))
{ {
......
...@@ -17,6 +17,7 @@ template <typename InDataTypeTuple, ...@@ -17,6 +17,7 @@ template <typename InDataTypeTuple,
typename OutDataTypeTuple, typename OutDataTypeTuple,
typename ElementwiseOperation, typename ElementwiseOperation,
typename UnaryOperation, typename UnaryOperation,
typename Scale,
index_t NumDim> index_t NumDim>
struct DeviceElementwise : public BaseOperator struct DeviceElementwise : public BaseOperator
{ {
...@@ -30,7 +31,8 @@ struct DeviceElementwise : public BaseOperator ...@@ -30,7 +31,8 @@ struct DeviceElementwise : public BaseOperator
const std::array<const void*, NumInput> in_dev_buffers, const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers, const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op, ElementwiseOperation elementwise_op,
UnaryOperation unary_op) = 0; UnaryOperation unary_op,
Scale scale_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; // namespace device }; // namespace device
...@@ -39,11 +41,13 @@ template <typename InDataTypeTuple, ...@@ -39,11 +41,13 @@ template <typename InDataTypeTuple,
typename OutDataTypeTuple, typename OutDataTypeTuple,
typename ElementwiseOperation, typename ElementwiseOperation,
typename UnaryOperation, typename UnaryOperation,
typename Scale,
index_t NumDim> index_t NumDim>
using DeviceElementwisePtr = std::unique_ptr<DeviceElementwise<InDataTypeTuple, using DeviceElementwisePtr = std::unique_ptr<DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple, OutDataTypeTuple,
ElementwiseOperation, ElementwiseOperation,
UnaryOperation, UnaryOperation,
Scale,
NumDim>>; NumDim>>;
} // namespace device } // namespace device
......
...@@ -23,15 +23,16 @@ template <typename InDataTypeTuple, ...@@ -23,15 +23,16 @@ template <typename InDataTypeTuple,
typename OutDataTypeTuple, typename OutDataTypeTuple,
typename ElementwiseOperation, typename ElementwiseOperation,
typename UnaryOperation, typename UnaryOperation,
typename Scale,
index_t NumDim, index_t NumDim,
index_t MPerThread, index_t MPerThread,
index_t ScalarMult,
typename InScalarPerVectorSeq, typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq> typename OutScalarPerVectorSeq>
struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple, OutDataTypeTuple,
ElementwiseOperation, ElementwiseOperation,
UnaryOperation, UnaryOperation,
Scale,
NumDim> NumDim>
{ {
static constexpr int NumInput = InDataTypeTuple::Size(); static constexpr int NumInput = InDataTypeTuple::Size();
...@@ -134,8 +135,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -134,8 +135,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypePointerTuple, OutDataTypePointerTuple,
ElementwiseOperation, ElementwiseOperation,
UnaryOperation, UnaryOperation,
Scale,
MPerThread, MPerThread,
ScalarMult,
InScalarPerVectorSeq, InScalarPerVectorSeq,
OutScalarPerVectorSeq>; OutScalarPerVectorSeq>;
...@@ -147,13 +148,15 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -147,13 +148,15 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
const std::array<const void*, NumInput> in_dev_buffers, const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers, const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op, ElementwiseOperation elementwise_op,
UnaryOperation unary_op) UnaryOperation unary_op,
Scale scale_op)
: lengths_(lengths), : lengths_(lengths),
inStridesArray_(inStridesArray), inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray), outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op), elementwise_op_(elementwise_op),
unary_op_(unary_op), unary_op_(unary_op),
scale_op_(scale_op),
blockSize_(256) blockSize_(256)
{ {
in_dev_buffers_ = generate_tuple( in_dev_buffers_ = generate_tuple(
...@@ -180,6 +183,7 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -180,6 +183,7 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
ElementwiseOperation elementwise_op_; ElementwiseOperation elementwise_op_;
UnaryOperation unary_op_; UnaryOperation unary_op_;
Scale scale_op_;
index_t blockSize_; index_t blockSize_;
}; };
...@@ -209,7 +213,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -209,7 +213,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
InDataTypePointerTuple, InDataTypePointerTuple,
OutDataTypePointerTuple, OutDataTypePointerTuple,
ElementwiseOperation, ElementwiseOperation,
UnaryOperation>; UnaryOperation,
Scale>;
float elapsed_time = launch_and_time_kernel(stream_config, float elapsed_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -221,7 +226,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -221,7 +226,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
arg.in_dev_buffers_, arg.in_dev_buffers_,
arg.out_dev_buffers_, arg.out_dev_buffers_,
arg.elementwise_op_, arg.elementwise_op_,
arg.unary_op_); arg.unary_op_,
arg.scale_op_);
return elapsed_time; return elapsed_time;
} }
...@@ -278,7 +284,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -278,7 +284,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
const std::array<const void*, NumInput> in_dev_buffers, const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers, const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op, ElementwiseOperation elementwise_op,
UnaryOperation unary_op) UnaryOperation unary_op,
Scale scale_op)
{ {
return Argument{lengths, return Argument{lengths,
inStridesArray, inStridesArray,
...@@ -286,7 +293,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -286,7 +293,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
in_dev_buffers, in_dev_buffers,
out_dev_buffers, out_dev_buffers,
elementwise_op, elementwise_op,
unary_op}; unary_op,
scale_op};
} }
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
...@@ -296,7 +304,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -296,7 +304,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
const std::array<const void*, NumInput> in_dev_buffers, const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers, const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op, ElementwiseOperation elementwise_op,
UnaryOperation unary_op) override UnaryOperation unary_op,
Scale scale_op) override
{ {
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
inStridesArray, inStridesArray,
...@@ -304,7 +313,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -304,7 +313,8 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
in_dev_buffers, in_dev_buffers,
out_dev_buffers, out_dev_buffers,
elementwise_op, elementwise_op,
unary_op); unary_op,
scale_op);
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
...@@ -21,20 +21,23 @@ template <typename GridwiseElementwise1dFunctor, ...@@ -21,20 +21,23 @@ template <typename GridwiseElementwise1dFunctor,
typename InDataTypePointerTuple, typename InDataTypePointerTuple,
typename OutDataTypePointerTuple, typename OutDataTypePointerTuple,
typename ElementwiseOperation, typename ElementwiseOperation,
typename UnaryOperation> typename UnaryOperation,
typename Scale>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, __global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple, const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple, const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple, const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op, const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op) const UnaryOperation unary_op,
const Scale scale_op)
{ {
GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple, GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple, out_grid_1d_desc_tuple,
p_in_global_tuple, p_in_global_tuple,
p_out_global_tuple, p_out_global_tuple,
elementwise_op, elementwise_op,
unary_op); unary_op,
scale_op);
} }
template <typename InGrid1dDescTuple, template <typename InGrid1dDescTuple,
...@@ -43,8 +46,8 @@ template <typename InGrid1dDescTuple, ...@@ -43,8 +46,8 @@ template <typename InGrid1dDescTuple,
typename OutDataTypePointerTuple, typename OutDataTypePointerTuple,
typename ElementwiseOperation, typename ElementwiseOperation,
typename UnaryOperation, typename UnaryOperation,
typename Scale,
index_t MPerThread, index_t MPerThread,
index_t ScalarMult,
typename InScalarPerVectorSeq, typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq> typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D struct GridwiseElementwise_1D
...@@ -70,7 +73,8 @@ struct GridwiseElementwise_1D ...@@ -70,7 +73,8 @@ struct GridwiseElementwise_1D
const InDataTypePointerTuple p_in_global_tuple, const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple, const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op, const ElementwiseOperation elementwise_op,
const UnaryOperation unary_op) const UnaryOperation unary_op,
const Scale scale_op)
{ {
const index_t thread_global_id = get_thread_global_1d_id(); const index_t thread_global_id = get_thread_global_1d_id();
...@@ -83,15 +87,6 @@ struct GridwiseElementwise_1D ...@@ -83,15 +87,6 @@ struct GridwiseElementwise_1D
}, },
Number<NumInput>{}); Number<NumInput>{});
auto tmp_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple( auto out_thread_buf_tuple = generate_tuple(
[&](auto I) { [&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>; using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
...@@ -168,7 +163,7 @@ struct GridwiseElementwise_1D ...@@ -168,7 +163,7 @@ struct GridwiseElementwise_1D
}, },
Number<NumOutput>{}); Number<NumOutput>{});
const auto& scalar = ScalarMult; //const auto& scalar = ScalarMult;
index_t num_iter = M / (loop_step); index_t num_iter = M / (loop_step);
do do
{ {
...@@ -183,17 +178,8 @@ struct GridwiseElementwise_1D ...@@ -183,17 +178,8 @@ struct GridwiseElementwise_1D
loop_step_index); loop_step_index);
}); });
// static_for<0, MPerThread, 1>{}(
// [&](auto I){
// InDataTypePointerTuple tmp;
// unary_op(in_thread_buf_tuple(I), in_thread_buf_tuple(I));
// in_thread_buf_tuple(I) = tmp;
//});
static_for<0, MPerThread, 1>{}([&](auto iM) { static_for<0, MPerThread, 1>{}([&](auto iM) {
// tmp_thread_buf_tuple = [&](auto I){ unary_op(in_thread_buf_tuple(I)(iM),
// in_thread_buf_tuple(I)(iM)); }; unary_op(in_thread_buf_tuple(iM),
// in_thread_buf_tuple(iM));
// get reference to in data // get reference to in data
auto uop_data_refs = generate_tie( auto uop_data_refs = generate_tie(
// return type should be lvalue // return type should be lvalue
...@@ -208,13 +194,25 @@ struct GridwiseElementwise_1D ...@@ -208,13 +194,25 @@ struct GridwiseElementwise_1D
unpack2(unary_op, uop_data_refs, uop_data_refs); unpack2(unary_op, uop_data_refs, uop_data_refs);
auto sop_in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
auto sop_out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
const auto in_data_refs = generate_tie( const auto in_data_refs = generate_tie(
// return type should be lvalue // return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM) *= scalar; }, [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{}); Number<NumInput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs); unpack2(elementwise_op, out_data_refs, in_data_refs);
UNUSED(tmp_thread_buf_tuple); UNUSED(scale_op);
}); });
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment