Commit 54d3e2f1 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into ck_tile/moe

parents 199f7f71 b8addae2
...@@ -155,7 +155,12 @@ struct HostTensorDescriptor ...@@ -155,7 +155,12 @@ struct HostTensorDescriptor
return space; return space;
} }
std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
const std::vector<std::size_t>& get_lengths() const { return mLens; } const std::vector<std::size_t>& get_lengths() const { return mLens; }
std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
const std::vector<std::size_t>& get_strides() const { return mStrides; } const std::vector<std::size_t>& get_strides() const { return mStrides; }
template <typename... Is> template <typename... Is>
...@@ -325,8 +330,12 @@ struct HostTensor ...@@ -325,8 +330,12 @@ struct HostTensor
{ {
} }
std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); }
decltype(auto) get_lengths() const { return mDesc.get_lengths(); } decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }
decltype(auto) get_strides() const { return mDesc.get_strides(); } decltype(auto) get_strides() const { return mDesc.get_strides(); }
std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); } std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment