Commit b58786f2 authored by Anthony Chang's avatar Anthony Chang
Browse files

serialize tensor object in readable format

parent d1567094
......@@ -1057,3 +1057,13 @@ struct NumericLimits<int4_t>
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
} // namespace ck
namespace std {
inline std::ostream& operator<<(std::ostream& os, const ck::half_t& p)
{
os << static_cast<float>(p);
return os;
}
} // namespace std
......@@ -470,3 +470,49 @@ struct Tensor
Descriptor mDesc;
Data mData;
};
template <typename T>
void SerializeTensor(std::ostream& os,
const Tensor<T>& tensor,
std::vector<size_t>& idx,
size_t rank)
{
if(rank == tensor.mDesc.GetNumOfDimension() - 1)
{
os << "(";
for(size_t i = 0; i < rank; i++)
{
os << idx[i] << (i == rank - 1 ? ", x) : " : ", ");
}
size_t dimz = tensor.mDesc.GetLengths()[rank];
os << "[";
for(size_t i = 0; i < dimz; i++)
{
idx[rank] = i;
os << tensor(idx) << (i == dimz - 1 ? "]" : ", ");
}
os << "\n";
return;
}
for(size_t i = 0; i < tensor.mDesc.GetLengths()[rank]; i++)
{
idx[rank] = i;
SerializeTensor(os, tensor, idx, rank + 1);
}
}
// Example format for Tensor(2, 2, 3):
// (0, 0, x) : [0, 1, 2]
// (0, 1, x) : [3, 4, 5]
// (1, 0, x) : [6, 7, 8]
// (1, 1, x) : [9, 10, 11]
template <typename T>
std::ostream& operator<<(std::ostream& os, const Tensor<T>& tensor)
{
std::vector<size_t> idx(tensor.mDesc.GetNumOfDimension(), 0);
SerializeTensor(os, tensor, idx, 0);
return os;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment