Commit 057ffb90 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Unify naming style in 'DevicePermute'

parent ee40f5a9
...@@ -21,10 +21,10 @@ struct DevicePermute : BaseOperator ...@@ -21,10 +21,10 @@ struct DevicePermute : BaseOperator
using Strides = Lengths; using Strides = Lengths;
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths& inLengths, MakeArgumentPointer(const Lengths& in_lengths,
const Strides& inStrides, const Strides& in_strides,
const Lengths& outLengths, const Lengths& out_lengths,
const Strides& outStrides, const Strides& out_strides,
const void* in_dev_buffer, const void* in_dev_buffer,
void* out_dev_buffer, void* out_dev_buffer,
ElementwiseOperation elementwise_op) = 0; ElementwiseOperation elementwise_op) = 0;
......
...@@ -60,7 +60,7 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen ...@@ -60,7 +60,7 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen
return generate_tuple([&](auto I) { return array[I]; }, Number<N>{}); return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
} }
static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides stride) static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride)
{ {
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
// d[NumDim-1]] // d[NumDim-1]]
...@@ -111,21 +111,21 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen ...@@ -111,21 +111,21 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const Lengths& inLengths, Argument(const Lengths& in_lengths,
const Strides& inStrides, const Strides& in_strides,
const Lengths& outLengths, const Lengths& out_lengths,
const Strides& outStrides, const Strides& out_strides,
const void* in_dev_buffer, const void* in_dev_buffer,
void* out_dev_buffer, void* out_dev_buffer,
ElementwiseOperation elementwise_op) ElementwiseOperation elementwise_op)
: in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)), : in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)), out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
in_grid_desc_(MakeDescriptor_N_H_W(inLengths, inStrides)), in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)),
out_grid_desc_(MakeDescriptor_N_H_W(outLengths, outStrides)), out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)),
inLengths_(inLengths), in_lengths_(in_lengths),
inStrides_(inStrides), in_strides_(in_strides),
outLengths_(outLengths), out_lengths_(out_lengths),
outStrides_(outStrides), out_strides_(out_strides),
elementwise_op_(elementwise_op), elementwise_op_(elementwise_op),
block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_)) block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
{ {
...@@ -136,10 +136,10 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen ...@@ -136,10 +136,10 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen
InGridDesc in_grid_desc_; InGridDesc in_grid_desc_;
OutGridDesc out_grid_desc_; OutGridDesc out_grid_desc_;
Lengths inLengths_; Lengths in_lengths_;
Strides inStrides_; Strides in_strides_;
Lengths outLengths_; Lengths out_lengths_;
Strides outStrides_; Strides out_strides_;
ElementwiseOperation elementwise_op_; ElementwiseOperation elementwise_op_;
...@@ -195,21 +195,21 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen ...@@ -195,21 +195,21 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen
return false; return false;
}; };
return IsScalarPerVectorValid(arg.inLengths_[SrcVectorDim], return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim],
arg.inStrides_[SrcVectorDim], arg.in_strides_[SrcVectorDim],
SrcScalarPerVector) && SrcScalarPerVector) &&
IsScalarPerVectorValid( IsScalarPerVectorValid(
GetPaddedLength(arg.inLengths_[SrcVectorDim], GetPaddedLength(arg.in_lengths_[SrcVectorDim],
(SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)), (SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
arg.inStrides_[SrcVectorDim], arg.in_strides_[SrcVectorDim],
SrcScalarPerVector) && SrcScalarPerVector) &&
IsScalarPerVectorValid(arg.outLengths_[DstVectorDim], IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim],
arg.outStrides_[DstVectorDim], arg.out_strides_[DstVectorDim],
DstScalarPerVector) && DstScalarPerVector) &&
IsScalarPerVectorValid( IsScalarPerVectorValid(
GetPaddedLength(arg.outLengths_[DstVectorDim], GetPaddedLength(arg.out_lengths_[DstVectorDim],
(DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)), (DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
arg.inStrides_[DstVectorDim], arg.in_strides_[DstVectorDim],
DstScalarPerVector) && DstScalarPerVector) &&
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_); GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
}; };
...@@ -228,18 +228,18 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen ...@@ -228,18 +228,18 @@ struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, Elemen
// override methods inherited from 'DevicePermute' // override methods inherited from 'DevicePermute'
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths& inLengths, MakeArgumentPointer(const Lengths& in_lengths,
const Strides& inStrides, const Strides& in_strides,
const Lengths& outLengths, const Lengths& out_lengths,
const Strides& outStrides, const Strides& out_strides,
const void* in_dev_buffer, const void* in_dev_buffer,
void* out_dev_buffer, void* out_dev_buffer,
ElementwiseOperation elementwise_op) override final ElementwiseOperation elementwise_op) override final
{ {
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(in_lengths,
inStrides, in_strides,
outLengths, out_lengths,
outStrides, out_strides,
in_dev_buffer, in_dev_buffer,
out_dev_buffer, out_dev_buffer,
elementwise_op); elementwise_op);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment