interface.cpp 293 Bytes
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <torch/extension.h>


at::Tensor multi_tensor_l2norm_cuda(
  int chunk_size,
  std::vector<std::vector<at::Tensor>> tensor_lists);


// Python bindings for this extension. TORCH_EXTENSION_NAME is the module name
// that torch.utils.cpp_extension defines when it builds the extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Expose multi_tensor_l2norm_cuda as `l2norm`.
  //
  // The pybind11::arg annotations give the Python-side parameters real names.
  // Callers may pass chunk_size / tensor_lists by keyword, and help() shows
  // meaningful names instead of arg0 / arg1. Positional calls behave exactly
  // as before. pybind11 requires one annotation per C++ parameter, so both
  // are named.
  //
  // NOTE(review): the docstring says "a list of contiguous tensors", but the
  // C++ signature takes a list of lists (std::vector<std::vector<at::Tensor>>).
  // Confirm which inner list is normed and whether contiguity is enforced by
  // the implementation.
  m.def("l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors",
        pybind11::arg("chunk_size"), pybind11::arg("tensor_lists"));
}