"...resnet50_tensorflow.git" did not exist on "3e9d886d22c7094c3e926b8f4103ef98116bfbbe"
pack_utils.cpp 1.17 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <algorithm>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <torch/torch.h>

namespace at {
namespace native {

// Implemented in another translation unit of this extension; declared here so
// they can be bound in the Python module at the bottom of this file.
at::Tensor revert_varlen_tensor(const Tensor &input, const Tensor &offsets);
at::Tensor get_offsets(const Tensor &input, const Tensor &lengths);

// Validates the `lengths` argument; defined elsewhere.
// NOTE(review): presumably ATen's PackedSequence.cpp helper, which requires a
// 1-D CPU int64 tensor — confirm against the linked definition.
void checkLongTensor(const Tensor &tensor);

/// @brief Build a time-major 0/1 validity mask from per-sequence lengths.
/// @param _lengths 1-D tensor of int64 lengths, one entry per batch element
///        (validated by checkLongTensor).
/// @return CPU uint8 tensor of shape {seq_length, batch_size} where
///         output[t][i] == 1 if t < _lengths[i] and 0 otherwise.
///         seq_length is the largest entry of `_lengths` (equal to
///         `_lengths[0]` when lengths are sorted in descending order), and 0
///         when `_lengths` is empty.
at::Tensor set_mask_cpp(const Tensor &_lengths) {
  at::native::checkLongTensor(_lengths);

  // The raw-pointer indexing below assumes unit stride; materialize a
  // contiguous tensor if the caller passed a strided view (no-op otherwise).
  const Tensor lengths_contig = _lengths.contiguous();
  const int64_t batch_size = lengths_contig.size(0);
  const int64_t *lengths = lengths_contig.data_ptr<int64_t>();

  // Decide emptiness from the element count, not from a null data pointer: an
  // empty view into non-empty storage has a non-null pointer, and reading
  // lengths[0] there would be out of bounds.
  int64_t seq_length = 0;
  if (batch_size > 0) {
    // Use the maximum rather than lengths[0] so unsorted lengths are not
    // truncated; identical to lengths[0] for descending-sorted input.
    seq_length = *std::max_element(lengths, lengths + batch_size);
  }

  auto output = torch::empty(
      {seq_length, batch_size},
      at::TensorOptions().dtype(at::kByte).device(at::kCPU));
  uint8_t *output_data = output.data_ptr<uint8_t>();
  for (int64_t t = 0; t < seq_length; ++t) {
    // Row t of the freshly allocated (contiguous) output.
    uint8_t *row = output_data + t * batch_size;
    for (int64_t i = 0; i < batch_size; ++i) {
      row[i] = lengths[i] > t ? 1 : 0;
    }
  }
  return output;
}

} // namespace native
} // namespace at

// Python bindings: expose the three native helpers under their C++ names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  using at::native::get_offsets;
  using at::native::revert_varlen_tensor;
  using at::native::set_mask_cpp;

  // Registration order matches the original module layout.
  mod.def("revert_varlen_tensor", &revert_varlen_tensor);
  mod.def("set_mask_cpp", &set_mask_cpp);
  mod.def("get_offsets", &get_offsets);
}