transform.cpp 4.87 KB
Newer Older
PanZezhong's avatar
init  
PanZezhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include "../tensor.hpp"
#include "../utils.hpp"
#include <algorithm>
#include <numeric>
#include <vector>

/// Build a view of this tensor narrowed along one or more dimensions.
///
/// The view shares `storage` with this tensor and keeps this tensor's strides;
/// only the extents of the sliced dimensions, the data pointer and the byte
/// offset differ. This tensor itself is not modified.
///
/// @param slices one entry per sliced dimension; each entry narrows `dim` to
///               [start, start + len) and must satisfy dim < ndim, len > 0 and
///               start + len <= shape[dim] (checked with ASSERT).
/// @return a newly allocated Tensor describing the view, with its own infiniop
///         tensor descriptor.
std::shared_ptr<Tensor> Tensor::slice_impl(const std::vector<SliceParams> &slices) const {
    std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();

    auto new_shape = std::vector<size_t>(this->_shape);
    // Offset, in elements, of the view's first element relative to this->_data.
    ptrdiff_t offset = 0;

    for (const auto &slice : slices) {
        // Check the dimension index before it is used to index _shape/_strides.
        ASSERT(slice.dim < this->_shape.size());
        ASSERT(slice.len > 0);
        ASSERT(this->_shape[slice.dim] >= slice.start + slice.len);
        new_shape[slice.dim] = slice.len;
        // Signed multiply so a negative stride yields a negative offset directly
        // instead of relying on size_t wrap-around.
        offset += ptrdiff_t(slice.start) * this->_strides[slice.dim];
    }
    const ptrdiff_t byte_offset = offset * ptrdiff_t(dsize(this->_dtype));

    tensor->_dtype = this->_dtype;
    tensor->_shape = new_shape;
    tensor->_strides = std::vector<ptrdiff_t>(this->_strides);
    // NOTE(review): assumes _offset is the byte offset of _data from the start of
    // the shared `storage`. Adding this->_offset keeps that true when slicing a
    // tensor that is itself a slice. this->_data already includes the parent's
    // offset, so only the relative byte_offset is applied to the data pointer.
    tensor->_offset = this->_offset + byte_offset;
    tensor->_data = static_cast<char *>(this->_data) + byte_offset;

    // Size in bytes of the view's elements (dense element count * element size).
    tensor->_size = std::accumulate(new_shape.begin(), new_shape.end(),
                                    size_t(dsize(this->_dtype)), std::multiplies<size_t>());
    tensor->storage = this->storage;
    infiniopCreateTensorDescriptor(&tensor->_desc, tensor->_shape.size(), tensor->_shape.data(),
                                   tensor->_strides.data(), tensor->_dtype);
    return tensor;
}

std::shared_ptr<Tensor> Tensor::slice(size_t dim, size_t start, size_t len) {
    return this->slice_impl({{dim, start, len}});
}

std::shared_ptr<Tensor const> Tensor::slice(size_t dim, size_t start, size_t len) const {
    return this->slice_impl({{dim, start, len}});
}

std::shared_ptr<Tensor> Tensor::slice(const std::vector<SliceParams> &slices) {
    return this->slice_impl(slices);
}

std::shared_ptr<Tensor const> Tensor::slice(const std::vector<SliceParams> &slices) const {
    return this->slice_impl(slices);
}

/// Merge dimensions [dim_start, dim_end] (inclusive) into a single dimension, in place.
///
/// The merged extent is the product of the merged extents and its stride is the
/// stride of `dim_end`; dimensions outside the range are kept unchanged. The
/// merged dimensions must be mutually contiguous (checked with ASSERT_EQ). The
/// infiniop tensor descriptor is destroyed and recreated for the new layout.
///
/// @param dim_start first dimension to merge (inclusive)
/// @param dim_end   last dimension to merge (inclusive); must be < ndim
/// @return shared_from_this(), i.e. this tensor after mutation (so this tensor
///         must be owned by a std::shared_ptr)
std::shared_ptr<Tensor> Tensor::dim_merge(size_t dim_start, size_t dim_end) {
    ASSERT(dim_start <= dim_end && dim_end < this->_shape.size());
    if (dim_start == dim_end) {
        return shared_from_this();
    }

    // Each dimension in the range must step exactly over the full extent of the
    // next one; otherwise the merged dimension has no single uniform stride.
    for (size_t i = dim_start + 1; i <= dim_end; i++) {
        ASSERT_EQ(this->_strides[i - 1], ptrdiff_t(this->_shape[i]) * this->_strides[i]);
    }

    const size_t new_ndim = this->_shape.size() - (dim_end - dim_start);
    auto new_shape = std::vector<size_t>();
    auto new_strides = std::vector<ptrdiff_t>();
    new_shape.reserve(new_ndim);
    new_strides.reserve(new_ndim);

    for (size_t i = 0; i < dim_start; i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    // The init value fixes std::accumulate's accumulator type: size_t(1) keeps
    // the product in size_t, whereas a plain `1` would truncate it to int.
    new_shape.push_back(std::accumulate(this->_shape.begin() + dim_start,
                                        this->_shape.begin() + dim_end + 1,
                                        size_t(1), std::multiplies<size_t>()));
    new_strides.push_back(this->_strides[dim_end]);
    for (size_t i = dim_end + 1; i < this->_shape.size(); i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }

    this->_shape.swap(new_shape);
    this->_strides.swap(new_strides);
    infiniopDestroyTensorDescriptor(this->_desc);
    infiniopCreateTensorDescriptor(&this->_desc, this->_shape.size(), this->_shape.data(),
                                   this->_strides.data(), this->_dtype);

    return shared_from_this();
}

std::shared_ptr<Tensor> Tensor::dim_split(size_t dim, const std::vector<size_t> &dims) {
    ASSERT_EQ(this->_shape[dim], std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>()));
    auto new_shape = std::vector<size_t>();
    auto new_strides = std::vector<ptrdiff_t>();
    for (size_t i = 0; i < dim; i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    for (size_t i = 0; i < dims.size(); i++) {
        new_shape.push_back(dims[i]);
        new_strides.push_back(this->_strides[dim] * this->_shape[dim] / std::accumulate(dims.begin(), dims.begin() + i + 1, 1, std::multiplies<size_t>()));
    }
    for (size_t i = dim + 1; i < this->_shape.size(); i++) {
        new_shape.push_back(this->_shape[i]);
        new_strides.push_back(this->_strides[i]);
    }
    this->_shape = new_shape;
    this->_strides = new_strides;
    infiniopDestroyTensorDescriptor(this->_desc);
    infiniopCreateTensorDescriptor(&this->_desc, this->_shape.size(), this->_shape.data(),
                                   this->_strides.data(), this->_dtype);
    return shared_from_this();
}

/// Reorder this tensor's dimensions in place: new dimension i is old dimension order[i].
///
/// Only the shape/stride metadata is permuted; the data pointer is untouched.
/// The infiniop tensor descriptor is destroyed and recreated for the new layout.
///
/// @param order a permutation of [0, ndim): same length as the shape, every
///              entry in range and distinct (checked with ASSERT)
/// @return shared_from_this(), i.e. this tensor after mutation
std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
    ASSERT_EQ(this->_shape.size(), order.size());
    const size_t ndim = this->_shape.size();

    // Validate the whole permutation before any entry is used as an index, so
    // an out-of-range entry is reported rather than read out of bounds.
    // (Relies on ASSERT aborting on failure, as the rest of this file does.)
    std::vector<char> seen(ndim, 0);
    for (size_t axis : order) {
        ASSERT(axis < ndim && !seen[axis]);
        seen[axis] = 1;
    }

    auto new_shape = std::vector<size_t>(ndim);
    auto new_strides = std::vector<ptrdiff_t>(ndim);
    for (size_t i = 0; i < ndim; i++) {
        new_shape[i] = this->_shape[order[i]];
        new_strides[i] = this->_strides[order[i]];
    }

    this->_shape.swap(new_shape);
    this->_strides.swap(new_strides);
    infiniopDestroyTensorDescriptor(this->_desc);
    infiniopCreateTensorDescriptor(&this->_desc, this->_shape.size(), this->_shape.data(),
                                   this->_strides.data(), this->_dtype);
    return shared_from_this();
}