transform.cpp 4.86 KB
Newer Older
PanZezhong's avatar
init  
PanZezhong committed
1
2
3
4
5
6
#include "../tensor.hpp"
#include "../utils.hpp"
#include <algorithm>
#include <numeric>
#include <vector>

PanZezhong's avatar
PanZezhong committed
7
std::shared_ptr<Tensor> Tensor::sliceImpl(const std::vector<SliceParams> &slices) const {
PanZezhong's avatar
init  
PanZezhong committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
    std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();

    auto new_shape = std::vector<size_t>(this->_shape);
    ptrdiff_t offset = 0;

    for (const auto &slice : slices) {
        ASSERT(slice.len > 0);
        ASSERT(this->_shape[slice.dim] >= slice.start + slice.len);
        new_shape[slice.dim] = slice.len;
        offset += slice.start * this->_strides[slice.dim];
    }

    tensor->_dtype = this->_dtype;
    tensor->_shape = new_shape;
    tensor->_strides = std::vector<ptrdiff_t>(this->_strides);
    tensor->_offset = offset * dsize(this->_dtype);
    tensor->_data = static_cast<char *>(this->_data) + tensor->_offset;

    tensor->_size = std::accumulate(new_shape.begin(), new_shape.end(),
                                    dsize(this->_dtype), std::multiplies<size_t>());
    tensor->storage = this->storage;
    infiniopCreateTensorDescriptor(&tensor->_desc, tensor->_shape.size(), tensor->_shape.data(),
                                   tensor->_strides.data(), tensor->_dtype);
    return tensor;
}

std::shared_ptr<Tensor> Tensor::slice(size_t dim, size_t start, size_t len) {
PanZezhong's avatar
PanZezhong committed
35
    return this->sliceImpl({{dim, start, len}});
PanZezhong's avatar
init  
PanZezhong committed
36
37
38
}

std::shared_ptr<Tensor const> Tensor::slice(size_t dim, size_t start, size_t len) const {
PanZezhong's avatar
PanZezhong committed
39
    return this->sliceImpl({{dim, start, len}});
PanZezhong's avatar
init  
PanZezhong committed
40
41
42
}

std::shared_ptr<Tensor> Tensor::slice(const std::vector<SliceParams> &slices) {
PanZezhong's avatar
PanZezhong committed
43
    return this->sliceImpl(slices);
PanZezhong's avatar
init  
PanZezhong committed
44
45
46
}

std::shared_ptr<Tensor const> Tensor::slice(const std::vector<SliceParams> &slices) const {
PanZezhong's avatar
PanZezhong committed
47
    return this->sliceImpl(slices);
PanZezhong's avatar
init  
PanZezhong committed
48
49
}

PanZezhong's avatar
PanZezhong committed
50
std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
PanZezhong's avatar
init  
PanZezhong committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    ASSERT(dim_start <= dim_end && dim_end < this->_shape.size());
    if (dim_start == dim_end) {
        return shared_from_this();
    }

    auto new_shape = std::vector<size_t>();
    auto new_strides = std::vector<ptrdiff_t>();
    for (size_t i = 0; i < dim_start; i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    for (size_t i = dim_start + 1; i <= dim_end; i++) {
        ASSERT_EQ(this->_strides[i - 1], ptrdiff_t(this->_shape[i]) * this->_strides[i]);
    }
    new_shape.push_back(std::accumulate(this->_shape.begin() + dim_start, this->_shape.begin() + dim_end + 1, 1, std::multiplies<size_t>()));
    new_strides.push_back(this->_strides[dim_end]);
    for (size_t i = dim_end + 1; i < this->_shape.size(); i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    this->_shape = new_shape;
    this->_strides = new_strides;
    infiniopDestroyTensorDescriptor(this->_desc);
    infiniopCreateTensorDescriptor(&this->_desc, this->_shape.size(), this->_shape.data(),
                                   this->_strides.data(), this->_dtype);

    return shared_from_this();
}

PanZezhong's avatar
PanZezhong committed
80
std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) {
PanZezhong's avatar
init  
PanZezhong committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
    ASSERT_EQ(this->_shape[dim], std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>()));
    auto new_shape = std::vector<size_t>();
    auto new_strides = std::vector<ptrdiff_t>();
    for (size_t i = 0; i < dim; i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    for (size_t i = 0; i < dims.size(); i++) {
        new_shape.push_back(dims[i]);
        new_strides.push_back(this->_strides[dim] * this->_shape[dim] / std::accumulate(dims.begin(), dims.begin() + i + 1, 1, std::multiplies<size_t>()));
    }
    for (size_t i = dim + 1; i < this->_shape.size(); i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    this->_shape = new_shape;
    this->_strides = new_strides;
    infiniopDestroyTensorDescriptor(this->_desc);
    infiniopCreateTensorDescriptor(&this->_desc, this->_shape.size(), this->_shape.data(),
                                   this->_strides.data(), this->_dtype);
    return shared_from_this();
}

/// Reorder the dimensions of this tensor in place: new dimension i takes the
/// extent and stride of old dimension order[i]. The tensor descriptor is rebuilt.
///
/// @param order  A permutation of 0..rank-1 (asserted).
/// @return shared_from_this() — the same tensor, now permuted.
std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
    const size_t ndim = this->_shape.size();
    ASSERT_EQ(ndim, order.size());

    // Validate the whole permutation before using any entry as an index, so an
    // out-of-range axis is reported instead of reading past _shape/_strides.
    std::vector<char> seen(ndim, 0);
    for (size_t axis : order) {
        ASSERT(axis < ndim);
        ASSERT(!seen[axis]);
        seen[axis] = 1;
    }

    auto new_shape = std::vector<size_t>(ndim);
    auto new_strides = std::vector<ptrdiff_t>(ndim);
    for (size_t i = 0; i < ndim; i++) {
        new_shape[i] = this->_shape[order[i]];
        new_strides[i] = this->_strides[order[i]];
    }
    this->_shape = new_shape;
    this->_strides = new_strides;
    infiniopDestroyTensorDescriptor(this->_desc);
    infiniopCreateTensorDescriptor(&this->_desc, this->_shape.size(), this->_shape.data(),
                                   this->_strides.data(), this->_dtype);
    return shared_from_this();
}