// nerfacc.cpp
// This file contains only the Python bindings: the functions declared below are
// defined in other translation units and are exposed to Python at the bottom.
#include "include/data_spec.hpp"

#include <torch/extension.h>


// scan
// Chunk-based scans over `inputs`. Every op receives a `chunk_starts` /
// `chunk_cnts` pair describing how `inputs` is split into chunks (presumably
// the start offset and element count of each chunk -- confirm against the
// implementation). Only declarations live here.

// Sum scans; `normalize` / `backward` semantics are defined by the kernels.
torch::Tensor inclusive_sum(torch::Tensor chunk_starts, torch::Tensor chunk_cnts,
                            torch::Tensor inputs, bool normalize, bool backward);
torch::Tensor exclusive_sum(torch::Tensor chunk_starts, torch::Tensor chunk_cnts,
                            torch::Tensor inputs, bool normalize, bool backward);

// Product scans, split into a forward op and a backward op. The backward op
// additionally takes `outputs` and `grad_outputs` (presumably the forward
// result and the upstream gradient -- verify).
torch::Tensor inclusive_prod_forward(torch::Tensor chunk_starts, torch::Tensor chunk_cnts,
                                     torch::Tensor inputs);
torch::Tensor inclusive_prod_backward(torch::Tensor chunk_starts, torch::Tensor chunk_cnts,
                                      torch::Tensor inputs, torch::Tensor outputs,
                                      torch::Tensor grad_outputs);
torch::Tensor exclusive_prod_forward(torch::Tensor chunk_starts, torch::Tensor chunk_cnts,
                                     torch::Tensor inputs);
torch::Tensor exclusive_prod_backward(torch::Tensor chunk_starts, torch::Tensor chunk_cnts,
                                      torch::Tensor inputs, torch::Tensor outputs,
                                      torch::Tensor grad_outputs);

// scan (`_cub` variants)
// These take a single `indices` tensor in place of chunk_starts / chunk_cnts
// (presumably a per-element segment/ray index; the `_cub` suffix suggests a
// CUB-backed implementation -- confirm in the implementation file).
// Parameter names in declarations do not affect linkage; they are spelled
// uniformly here for readability.
torch::Tensor inclusive_sum_cub(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward);
torch::Tensor exclusive_sum_cub(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward);
torch::Tensor inclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs);
torch::Tensor inclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs);
torch::Tensor exclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs);
torch::Tensor exclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs);

// grid

// Intersects rays with axis-aligned boxes (presumably every ray against every
// box -- confirm). The meaning and order of the returned tensors are defined
// by the implementation; only the declaration lives here.
std::vector<torch::Tensor> ray_aabb_intersect(const torch::Tensor rays_o,  // [n_rays, 3]
                                              const torch::Tensor rays_d,  // [n_rays, 3]
                                              const torch::Tensor aabbs,   // [n_aabbs, 6]
                                              const float near_plane,
                                              const float far_plane,
                                              const float miss_value);
74
std::tuple<RaySegmentsSpec, RaySegmentsSpec, torch::Tensor> traverse_grids(
75
76
77
    // rays
    const torch::Tensor rays_o, // [n_rays, 3]
    const torch::Tensor rays_d, // [n_rays, 3]
78
    const torch::Tensor rays_mask,   // [n_rays]
79
80
81
82
    // grids
    const torch::Tensor binaries,  // [n_grids, resx, resy, resz]
    const torch::Tensor aabbs,     // [n_grids, 6]
    // intersections
83
84
    const torch::Tensor t_sorted,  // [n_rays, n_grids * 2]
    const torch::Tensor t_indices,  // [n_rays, n_grids * 2]
85
86
87
88
89
90
91
    const torch::Tensor hits,    // [n_rays, n_grids]
    // options
    const torch::Tensor near_planes,
    const torch::Tensor far_planes,
    const float step_size,
    const float cone_angle,
    const bool compute_intervals,
92
93
94
95
    const bool compute_samples,
    const bool compute_terminate_planes,
    const int32_t traverse_steps_limit, // <= 0 means no limit
    const bool over_allocate); // over allocate the memory for intervals and samples
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

// pdf
std::vector<RaySegmentsSpec> importance_sampling(
    RaySegmentsSpec ray_segments,
    torch::Tensor cdfs,                 
    torch::Tensor n_intervels_per_ray,  
    bool stratified);
std::vector<RaySegmentsSpec> importance_sampling(
    RaySegmentsSpec ray_segments,
    torch::Tensor cdfs,                  
    int64_t n_intervels_per_ray,
    bool stratified);
std::vector<torch::Tensor> searchsorted(
    RaySegmentsSpec query,
    RaySegmentsSpec key);

Ruilong Li(李瑞龙)'s avatar
Ruilong Li(李瑞龙) committed
112
113
114
115
116
117
118
119
120
121
122
123
124
// cameras
// OpenCV-model lens undistortion (standard and fisheye variants); only the
// declarations live here. The eps / iteration-count arguments presumably bound
// an iterative solver -- confirm against the implementation.
torch::Tensor opencv_lens_undistortion(
    const torch::Tensor& uv,      // [..., 2]
    const torch::Tensor& params,  // [..., 6]
    const float eps,
    const int max_iterations);
torch::Tensor opencv_lens_undistortion_fisheye(
    const torch::Tensor& uv,      // [..., 2]
    const torch::Tensor& params,  // [..., 4]
    const float criteria_eps,
    const int criteria_iters);

// Python module definition; the module name comes from TORCH_EXTENSION_NAME.
// Every function is exposed under its C++ name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Register RaySegmentsSpec before any function that takes or returns it:
    // pybind11 renders signatures when .def() runs, so registering the class
    // first makes the docstrings show "RaySegmentsSpec" rather than the raw
    // C++ type name.
    py::class_<RaySegmentsSpec>(m, "RaySegmentsSpec")
        .def(py::init<>())
        .def_readwrite("vals", &RaySegmentsSpec::vals)
        .def_readwrite("is_left", &RaySegmentsSpec::is_left)
        .def_readwrite("is_right", &RaySegmentsSpec::is_right)
        .def_readwrite("is_valid", &RaySegmentsSpec::is_valid)
        .def_readwrite("chunk_starts", &RaySegmentsSpec::chunk_starts)
        .def_readwrite("chunk_cnts", &RaySegmentsSpec::chunk_cnts)
        .def_readwrite("ray_indices", &RaySegmentsSpec::ray_indices);

    // Local helper macro. A leading underscore followed by an uppercase
    // letter is reserved for the implementation, hence the prefixed name.
#define NERFACC_REG_FUNC(funname) m.def(#funname, &funname)
    // scan
    NERFACC_REG_FUNC(inclusive_sum);
    NERFACC_REG_FUNC(exclusive_sum);
    NERFACC_REG_FUNC(inclusive_prod_forward);
    NERFACC_REG_FUNC(inclusive_prod_backward);
    NERFACC_REG_FUNC(exclusive_prod_forward);
    NERFACC_REG_FUNC(exclusive_prod_backward);

    // scan (`_cub` variants)
    NERFACC_REG_FUNC(inclusive_sum_cub);
    NERFACC_REG_FUNC(exclusive_sum_cub);
    NERFACC_REG_FUNC(inclusive_prod_cub_forward);
    NERFACC_REG_FUNC(inclusive_prod_cub_backward);
    NERFACC_REG_FUNC(exclusive_prod_cub_forward);
    NERFACC_REG_FUNC(exclusive_prod_cub_backward);

    // grid / pdf
    NERFACC_REG_FUNC(ray_aabb_intersect);
    NERFACC_REG_FUNC(traverse_grids);
    NERFACC_REG_FUNC(searchsorted);

    // cameras
    NERFACC_REG_FUNC(opencv_lens_undistortion);
    NERFACC_REG_FUNC(opencv_lens_undistortion_fisheye);  // TODO: check this function.
#undef NERFACC_REG_FUNC

    // importance_sampling is overloaded in C++, so each overload is bound
    // explicitly under the same Python name; pybind11 tries them in
    // registration order.
    m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, torch::Tensor, bool>(&importance_sampling));
    m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, int64_t, bool>(&importance_sampling));
}