Commit 0cfbefce authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Add Tensor<>::data()

parent 31875367
......@@ -3,13 +3,13 @@
#pragma once
#include <thread>
#include <vector>
#include <numeric>
#include <algorithm>
#include <utility>
#include <cassert>
#include <iostream>
#include <numeric>
#include <thread>
#include <utility>
#include <vector>
#include "ck/utility/data_type.hpp"
......@@ -235,6 +235,8 @@ auto make_ParallelTensorFunctor(F f, Xs... xs)
template <typename T>
struct Tensor
{
using Data = std::vector<T>;
template <typename X>
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
{
......@@ -427,14 +429,18 @@ struct Tensor
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
typename std::vector<T>::iterator begin() { return mData.begin(); }
typename Data::iterator begin() { return mData.begin(); }
typename Data::iterator end() { return mData.end(); }
typename Data::pointer data() { return mData.data(); }
typename std::vector<T>::iterator end() { return mData.end(); }
typename Data::const_iterator begin() const { return mData.begin(); }
typename std::vector<T>::const_iterator begin() const { return mData.begin(); }
typename Data::const_iterator end() const { return mData.end(); }
typename std::vector<T>::const_iterator end() const { return mData.end(); }
typename Data::const_pointer data() const { return mData.data(); }
HostTensorDescriptor mDesc;
std::vector<T> mData;
Data mData;
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment