Unverified Commit bd7a1466 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #195 from InfiniTensor/issue/194_fix_elementwise_overallocation

issue/194/fix: Fix ElementwiseInfo Meta Allocation
parents 62b7dd01 85bff3ae
...@@ -102,7 +102,7 @@ struct DeviceImpl::Opaque {}; ...@@ -102,7 +102,7 @@ struct DeviceImpl::Opaque {};
template <typename... Args> template <typename... Args>
utils::Result<DeviceImpl> DeviceImpl::create(Args &&...args) { utils::Result<DeviceImpl> DeviceImpl::create(Args &&...args) {
return utils::Result<DeviceImpl>(nullptr); return INFINI_STATUS_NOT_IMPLEMENTED;
} }
// Perform elementwise operation for different input types // Perform elementwise operation for different input types
......
...@@ -84,8 +84,9 @@ private: ...@@ -84,8 +84,9 @@ private:
_output_contiguous(output_contiguous) {} _output_contiguous(output_contiguous) {}
public: public:
// Get the Memory size of the meta data in bytes
inline size_t getMetaMemSize() const { inline size_t getMetaMemSize() const {
return _meta.size(); return _meta.size() * sizeof(size_t);
} }
inline const int8_t *getMetaStart() const { inline const int8_t *getMetaStart() const {
return reinterpret_cast<const int8_t *>(_meta.data()); return reinterpret_cast<const int8_t *>(_meta.data());
...@@ -167,7 +168,7 @@ public: ...@@ -167,7 +168,7 @@ public:
+ input_size * ndim * sizeof(shape_unit) + input_size * ndim * sizeof(shape_unit)
+ input_size * ndim * sizeof(stride_unit) + input_size * ndim * sizeof(stride_unit)
+ 2 * input_size * sizeof(bool); + 2 * input_size * sizeof(bool);
std::vector<size_t> meta(meta_mem_size); std::vector<size_t> meta(CEIL_DIV(meta_mem_size, sizeof(size_t)));
int8_t *meta_ptr = reinterpret_cast<int8_t *>(meta.data()); int8_t *meta_ptr = reinterpret_cast<int8_t *>(meta.data());
const auto output_shape = output_desc->shape(); const auto output_shape = output_desc->shape();
......
...@@ -14,7 +14,7 @@ private: ...@@ -14,7 +14,7 @@ private:
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x)))))); return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x))))));
} else if constexpr (std::is_same_v<T, float>) { } else if constexpr (std::is_same_v<T, float>) {
return __frcp_rd(__fadd_rd(1, __expf(-x))); return __frcp_rn(__fadd_rn(1, __expf(-x)));
} else { } else {
return 1 / (1 + std::exp(-x)); return 1 / (1 + std::exp(-x));
} }
...@@ -29,7 +29,7 @@ public: ...@@ -29,7 +29,7 @@ public:
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up); return __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, float>) { } else if constexpr (std::is_same_v<T, float>) {
return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up); return __fmul_rn(__fmul_rn(gate, sigmoid(gate)), up);
} else { } else {
return gate * sigmoid(gate) * up; return gate * sigmoid(gate) * up;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment