transform.cpp 4.92 KB
Newer Older
PanZezhong's avatar
init  
PanZezhong committed
1
2
3
4
5
6
#include "../tensor.hpp"
#include "../utils.hpp"
#include <algorithm>
#include <numeric>
#include <vector>

PanZezhong's avatar
PanZezhong committed
7
std::shared_ptr<Tensor> Tensor::sliceImpl(const std::vector<SliceParams> &slices) const {
PanZezhong's avatar
init  
PanZezhong committed
8
9
    std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();

PanZezhong's avatar
PanZezhong committed
10
    auto new_shape = std::vector<size_t>(this->shape());
PanZezhong's avatar
init  
PanZezhong committed
11
12
13
14
    ptrdiff_t offset = 0;

    for (const auto &slice : slices) {
        ASSERT(slice.len > 0);
PanZezhong's avatar
PanZezhong committed
15
        ASSERT(this->shape()[slice.dim] >= slice.start + slice.len);
PanZezhong's avatar
init  
PanZezhong committed
16
        new_shape[slice.dim] = slice.len;
PanZezhong's avatar
PanZezhong committed
17
        offset += slice.start * this->strides()[slice.dim];
PanZezhong's avatar
init  
PanZezhong committed
18
19
    }

PanZezhong's avatar
PanZezhong committed
20
21
    tensor->_desc = TensorDesc::create(this->dtype(), new_shape, this->strides());
    tensor->_offset = offset * dsize(this->dtype()) + this->_offset;
PanZezhong's avatar
PanZezhong committed
22
    tensor->_storage = this->_storage;
PanZezhong's avatar
init  
PanZezhong committed
23
24
25
26
    return tensor;
}

std::shared_ptr<Tensor> Tensor::slice(size_t dim, size_t start, size_t len) {
PanZezhong's avatar
PanZezhong committed
27
    return this->sliceImpl({{dim, start, len}});
PanZezhong's avatar
init  
PanZezhong committed
28
29
30
}

std::shared_ptr<Tensor const> Tensor::slice(size_t dim, size_t start, size_t len) const {
PanZezhong's avatar
PanZezhong committed
31
    return this->sliceImpl({{dim, start, len}});
PanZezhong's avatar
init  
PanZezhong committed
32
33
34
}

std::shared_ptr<Tensor> Tensor::slice(const std::vector<SliceParams> &slices) {
PanZezhong's avatar
PanZezhong committed
35
    return this->sliceImpl(slices);
PanZezhong's avatar
init  
PanZezhong committed
36
37
38
}

std::shared_ptr<Tensor const> Tensor::slice(const std::vector<SliceParams> &slices) const {
PanZezhong's avatar
PanZezhong committed
39
    return this->sliceImpl(slices);
PanZezhong's avatar
init  
PanZezhong committed
40
41
}

PanZezhong's avatar
PanZezhong committed
42
void TensorDesc::dimMerge(size_t dim_start, size_t dim_end) {
PanZezhong's avatar
init  
PanZezhong committed
43
44
    ASSERT(dim_start <= dim_end && dim_end < this->_shape.size());
    if (dim_start == dim_end) {
PanZezhong's avatar
PanZezhong committed
45
        return;
PanZezhong's avatar
init  
PanZezhong committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
    }

    auto new_shape = std::vector<size_t>();
    auto new_strides = std::vector<ptrdiff_t>();
    for (size_t i = 0; i < dim_start; i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    for (size_t i = dim_start + 1; i <= dim_end; i++) {
        ASSERT_EQ(this->_strides[i - 1], ptrdiff_t(this->_shape[i]) * this->_strides[i]);
    }
    new_shape.push_back(std::accumulate(this->_shape.begin() + dim_start, this->_shape.begin() + dim_end + 1, 1, std::multiplies<size_t>()));
    new_strides.push_back(this->_strides[dim_end]);
    for (size_t i = dim_end + 1; i < this->_shape.size(); i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    this->_shape = new_shape;
    this->_strides = new_strides;
PanZezhong's avatar
PanZezhong committed
65
    this->resetDesc();
66
    this->computeTensorDesHash();
PanZezhong's avatar
PanZezhong committed
67
}
PanZezhong's avatar
init  
PanZezhong committed
68

PanZezhong's avatar
PanZezhong committed
69
std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
wooway777's avatar
wooway777 committed
70
71
72
73
74
75
76
77
    auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
    new_desc->dimMerge(dim_start, dim_end);

    auto tensor = std::make_shared<Tensor>();
    tensor->_storage = _storage;
    tensor->_desc = new_desc;
    tensor->_offset = _offset;
    return tensor;
PanZezhong's avatar
init  
PanZezhong committed
78
79
}

PanZezhong's avatar
PanZezhong committed
80
void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) {
PanZezhong's avatar
init  
PanZezhong committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
    ASSERT_EQ(this->_shape[dim], std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>()));
    auto new_shape = std::vector<size_t>();
    auto new_strides = std::vector<ptrdiff_t>();
    for (size_t i = 0; i < dim; i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    for (size_t i = 0; i < dims.size(); i++) {
        new_shape.push_back(dims[i]);
        new_strides.push_back(this->_strides[dim] * this->_shape[dim] / std::accumulate(dims.begin(), dims.begin() + i + 1, 1, std::multiplies<size_t>()));
    }
    for (size_t i = dim + 1; i < this->_shape.size(); i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    this->_shape = new_shape;
    this->_strides = new_strides;
PanZezhong's avatar
PanZezhong committed
98
    this->resetDesc();
99
    this->computeTensorDesHash();
PanZezhong's avatar
PanZezhong committed
100
101
102
}

std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) {
wooway777's avatar
wooway777 committed
103
104
105
106
107
108
109
110
    auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
    new_desc->dimSplit(dim, dims);

    auto tensor = std::make_shared<Tensor>();
    tensor->_storage = _storage;
    tensor->_desc = new_desc;
    tensor->_offset = _offset;
    return tensor;
PanZezhong's avatar
init  
PanZezhong committed
111
112
}

PanZezhong's avatar
PanZezhong committed
113
void TensorDesc::permute(const std::vector<size_t> &order) {
PanZezhong's avatar
init  
PanZezhong committed
114
115
116
117
118
119
120
121
122
123
    ASSERT_EQ(this->_shape.size(), order.size());
    auto new_shape = std::vector<size_t>(order.size());
    auto new_strides = std::vector<ptrdiff_t>(order.size());
    for (size_t i = 0; i < order.size(); i++) {
        ASSERT(std::find(order.begin(), order.end(), i) != order.end());
        new_shape[i] = this->_shape[order[i]];
        new_strides[i] = this->_strides[order[i]];
    }
    this->_shape = new_shape;
    this->_strides = new_strides;
PanZezhong's avatar
PanZezhong committed
124
    this->resetDesc();
125
    this->computeTensorDesHash();
PanZezhong's avatar
PanZezhong committed
126
127
128
}

std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
wooway777's avatar
wooway777 committed
129
130
131
132
133
134
135
136
    auto new_desc = TensorDesc::create(_desc->dtype(), _desc->shape(), _desc->strides());
    new_desc->permute(order);

    auto tensor = std::make_shared<Tensor>();
    tensor->_storage = _storage;
    tensor->_desc = new_desc;
    tensor->_offset = _offset;
    return tensor;
PanZezhong's avatar
PanZezhong committed
137
}