Commit da881f4d authored by Zimin Li's avatar Zimin Li
Browse files

issue/127: change meta within ElementwiseInfo to std::vector<size_t> for...

issue/127: change meta within ElementwiseInfo to std::vector<size_t> for correct alignment and change the reference name of the Opaque struct to Opaque instead of struct Opaque
parent fe9c4aa5
...@@ -38,7 +38,7 @@ namespace op::elementwise::cpu { ...@@ -38,7 +38,7 @@ namespace op::elementwise::cpu {
*/ */
class DeviceImpl final { class DeviceImpl final {
struct Opaque; struct Opaque;
std::shared_ptr<struct Opaque> _opaque; std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {} DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
......
...@@ -10,7 +10,7 @@ namespace op::elementwise::cuda { ...@@ -10,7 +10,7 @@ namespace op::elementwise::cuda {
*/ */
class DeviceImpl final { class DeviceImpl final {
struct Opaque; struct Opaque;
std::shared_ptr<struct Opaque> _opaque; std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {} DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
......
...@@ -68,13 +68,13 @@ namespace op::elementwise { ...@@ -68,13 +68,13 @@ namespace op::elementwise {
*/ */
struct ElementwiseInfo { struct ElementwiseInfo {
private: private:
std::vector<int8_t> _meta; std::vector<size_t> _meta;
size_t _output_size; size_t _output_size;
size_t _input_size; size_t _input_size;
size_t _ndim; size_t _ndim;
bool _output_contiguous; bool _output_contiguous;
ElementwiseInfo(std::vector<int8_t> meta, ElementwiseInfo(std::vector<size_t> meta,
size_t output_size, size_t output_size,
size_t input_size, size_t input_size,
size_t ndim, size_t ndim,
...@@ -88,7 +88,7 @@ public: ...@@ -88,7 +88,7 @@ public:
return _meta.size(); return _meta.size();
} }
inline const int8_t *getMetaStart() const { inline const int8_t *getMetaStart() const {
return _meta.data(); return reinterpret_cast<const int8_t *>(_meta.data());
} }
inline size_t getOutputSize() const { inline size_t getOutputSize() const {
return _output_size; return _output_size;
...@@ -167,8 +167,8 @@ public: ...@@ -167,8 +167,8 @@ public:
+ input_size * ndim * sizeof(shape_unit) + input_size * ndim * sizeof(shape_unit)
+ input_size * ndim * sizeof(stride_unit) + input_size * ndim * sizeof(stride_unit)
+ 2 * input_size * sizeof(bool); + 2 * input_size * sizeof(bool);
std::vector<int8_t> meta(meta_mem_size); std::vector<size_t> meta(meta_mem_size);
int8_t *meta_ptr = meta.data(); int8_t *meta_ptr = reinterpret_cast<int8_t *>(meta.data());
const auto output_shape = output_desc->shape(); const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides(); const auto output_strides = output_desc->strides();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment