Commit b3275d7c authored by wooway777's avatar wooway777
Browse files

issue/21 - Slightly improved key creation process

parent e641693d
...@@ -10,30 +10,6 @@ ...@@ -10,30 +10,6 @@
#include "../utils.hpp" #include "../utils.hpp"
#include "infinicore_infer.h" #include "infinicore_infer.h"
// Hash combine utility (similar to boost::hash_combine)
inline void hash_combine(size_t &seed, size_t value) {
seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
// Specialization for enum types
template <typename T>
inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_enum<T>::value>::type * = 0) {
hash_combine(seed, static_cast<size_t>(value));
}
// Helper function to compute hash for tensor descriptors
inline size_t computeTensorDescHash(std::shared_ptr<Tensor> &tensor) {
size_t seed = 0;
hash_combine(seed, tensor->dtype());
for (auto dim : tensor->shape()) {
hash_combine(seed, dim);
}
for (auto stride : tensor->strides()) {
hash_combine(seed, static_cast<size_t>(stride));
}
return seed;
}
class IDescriptorDestroyer { class IDescriptorDestroyer {
public: public:
virtual ~IDescriptorDestroyer() = default; virtual ~IDescriptorDestroyer() = default;
...@@ -260,7 +236,7 @@ public: ...@@ -260,7 +236,7 @@ public:
template <typename... Tensors> template <typename... Tensors>
static size_t createDescriptorKey(Tensors... tensors) { static size_t createDescriptorKey(Tensors... tensors) {
size_t seed = 0; size_t seed = 0;
(..., (tensors ? hash_combine(seed, computeTensorDescHash(tensors)) : (void)0)); (..., (tensors ? hash_combine(seed, tensors->seed()) : (void)0));
return seed; return seed;
} }
}; };
......
...@@ -51,10 +51,12 @@ private: ...@@ -51,10 +51,12 @@ private:
std::vector<size_t> _shape; std::vector<size_t> _shape;
std::vector<ptrdiff_t> _strides; std::vector<ptrdiff_t> _strides;
infiniopTensorDescriptor_t _desc; infiniopTensorDescriptor_t _desc;
size_t _seed;
TensorDesc(infiniDtype_t dtype, const std::vector<size_t> &shape, TensorDesc(infiniDtype_t dtype, const std::vector<size_t> &shape,
const std::vector<ptrdiff_t> &strides) : _dtype(dtype), _shape(shape), _strides(strides), _desc(nullptr) {} const std::vector<ptrdiff_t> &strides) : _dtype(dtype), _shape(shape), _strides(strides), _desc(nullptr) { computeTensorDesHash(); }
void resetDesc(); void resetDesc();
void computeTensorDesHash();
public: public:
~TensorDesc(); ~TensorDesc();
...@@ -74,6 +76,7 @@ public: ...@@ -74,6 +76,7 @@ public:
infiniopTensorDescriptor_t desc() const; infiniopTensorDescriptor_t desc() const;
bool isContigous() const; bool isContigous() const;
std::string info() const; std::string info() const;
size_t seed() const { return _seed; }
void dimMerge(size_t dim_start, size_t dim_end); void dimMerge(size_t dim_start, size_t dim_end);
void dimSplit(size_t dim, const std::vector<size_t> &dims); void dimSplit(size_t dim, const std::vector<size_t> &dims);
...@@ -127,10 +130,11 @@ public: ...@@ -127,10 +130,11 @@ public:
void debug(const std::string &filename) const; void debug(const std::string &filename) const;
void debug() const; void debug() const;
std::string info() const; std::string info() const;
size_t seed() const;
std::shared_ptr<Tensor> view(const std::vector<size_t> &new_shape) const; std::shared_ptr<Tensor> view(const std::vector<size_t> &new_shape) const;
std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const;
std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape) const; std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape) const;
std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const;
~Tensor(); ~Tensor();
}; };
......
...@@ -62,6 +62,16 @@ void TensorDesc::resetDesc() { ...@@ -62,6 +62,16 @@ void TensorDesc::resetDesc() {
} }
} }
void TensorDesc::computeTensorDesHash() {
_seed = 0;
for (auto dim : this->shape()) {
hash_combine(_seed, dim);
}
for (auto stride : this->strides()) {
hash_combine(_seed, static_cast<size_t>(stride));
}
}
bool TensorDesc::isContigous() const { bool TensorDesc::isContigous() const {
auto ndim = this->ndim(); auto ndim = this->ndim();
auto shape = this->shape(); auto shape = this->shape();
...@@ -258,6 +268,10 @@ std::string Tensor::info() const { ...@@ -258,6 +268,10 @@ std::string Tensor::info() const {
return this->_desc->info(); return this->_desc->info();
} }
size_t Tensor::seed() const {
return this->_desc->seed();
}
std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const { std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const {
// Calculate total number of elements // Calculate total number of elements
size_t numel = 1; size_t numel = 1;
...@@ -383,18 +397,18 @@ std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const ...@@ -383,18 +397,18 @@ std::shared_ptr<Tensor> Tensor::view(const std::vector<size_t> &new_shape) const
return this->view_as(inferred_shape, new_strides); return this->view_as(inferred_shape, new_strides);
} }
std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const { std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>(); std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage; tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides); tensor->_desc = TensorDesc::create(this->dtype(), new_shape);
tensor->_offset = this->_offset; tensor->_offset = this->_offset;
return tensor; return tensor;
} }
std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape) const { std::shared_ptr<Tensor> Tensor::view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>(); std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_storage = this->_storage; tensor->_storage = this->_storage;
tensor->_desc = TensorDesc::create(this->dtype(), new_shape); tensor->_desc = TensorDesc::create(this->dtype(), new_shape, new_strides);
tensor->_offset = this->_offset; tensor->_offset = this->_offset;
return tensor; return tensor;
} }
......
...@@ -63,6 +63,7 @@ void TensorDesc::dimMerge(size_t dim_start, size_t dim_end) { ...@@ -63,6 +63,7 @@ void TensorDesc::dimMerge(size_t dim_start, size_t dim_end) {
this->_shape = new_shape; this->_shape = new_shape;
this->_strides = new_strides; this->_strides = new_strides;
this->resetDesc(); this->resetDesc();
this->computeTensorDesHash();
} }
std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) { std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
...@@ -89,6 +90,7 @@ void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) { ...@@ -89,6 +90,7 @@ void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) {
this->_shape = new_shape; this->_shape = new_shape;
this->_strides = new_strides; this->_strides = new_strides;
this->resetDesc(); this->resetDesc();
this->computeTensorDesHash();
} }
std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) { std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) {
...@@ -108,6 +110,7 @@ void TensorDesc::permute(const std::vector<size_t> &order) { ...@@ -108,6 +110,7 @@ void TensorDesc::permute(const std::vector<size_t> &order) {
this->_shape = new_shape; this->_shape = new_shape;
this->_strides = new_strides; this->_strides = new_strides;
this->resetDesc(); this->resetDesc();
this->computeTensorDesHash();
} }
std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) { std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
......
...@@ -119,4 +119,9 @@ inline uint16_t f32_to_bf16(float val) { ...@@ -119,4 +119,9 @@ inline uint16_t f32_to_bf16(float val) {
return bf16_bits; return bf16_bits;
} }
// Hash combine utility (similar to boost::hash_combine)
inline void hash_combine(size_t &seed, size_t value) {
seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment