#pragma once
#include <torch/types.h>

// Tolerance used for the geometric comparisons below. It expands to a cast to
// `scalar_t`, so it only works where a type named `scalar_t` is in scope (the
// templates in this header). It is #undef'd at the end of this header.
#define EPS ((scalar_t)(1e-5))
// Extracts the single element of a tensor as a `scalar_t` value. Like EPS, it
// requires `scalar_t` to be in scope and is #undef'd at the end of this header.
#define SCALAR(x) ((x).template item<scalar_t>())

namespace torchaudio {
namespace rir {

////////////////////////////////////////////////////////////////////////////////
// Basic Wall implementation
////////////////////////////////////////////////////////////////////////////////

/// Planar wall of a room.
///
/// Stores the wall geometry as a point on its plane (`origin`) and the
/// plane's `normal`, both converted to tensors. It also stores the wall's
/// acoustic coefficients:
/// - `scattering`, kept as given;
/// - `reflection`, derived as `1 - absorption`.
/// The absorption tensor itself is not stored.
///
/// Geometric queries on a wall (`side`, `reflect`, `cosine`) are free
/// functions in this header.
template <typename scalar_t>
struct Wall {
  Wall(
      const torch::ArrayRef<scalar_t>& wall_origin,
      const torch::ArrayRef<scalar_t>& wall_normal,
      const torch::Tensor& absorption,
      const torch::Tensor& wall_scattering)
      : origin(torch::tensor(wall_origin)),
        normal(torch::tensor(wall_normal)),
        scattering(wall_scattering),
        reflection(1. - absorption) {}

  /// A point lying on the wall's plane.
  const torch::Tensor origin;
  /// Normal vector of the wall's plane.
  const torch::Tensor normal;
  /// Scattering coefficients, stored unchanged.
  const torch::Tensor scattering;

  /// Reflection coefficients, computed as `1 - absorption`.
  const torch::Tensor reflection;
};

/// Returns the side (-1, 1 or 0) on which a point lies w.r.t. the wall.
template <typename scalar_t>
int side(const Wall<scalar_t>& wall, const torch::Tensor& pos) {
  auto dot = SCALAR((pos - wall.origin).dot(wall.normal));

  if (dot > EPS) {
    return 1;
  } else if (dot < -EPS) {
    return -1;
  } else {
    return 0;
  }
}

/// Mirrors the ray direction `dir` about the wall's plane, i.e. computes
///   dir - 2 * (dir . n) * n
/// where n is the wall normal. The norm of `dir` is preserved when n has unit
/// length, as the normals built by `make_room` do.
template <typename scalar_t>
torch::Tensor reflect(const Wall<scalar_t>& wall, const torch::Tensor& dir) {
  const torch::Tensor doubled_normal = wall.normal * 2;
  const torch::Tensor normal_component = doubled_normal * dir.dot(wall.normal);
  return dir - normal_component;
}

/// Returns the cosine angle of a ray (dir) with the normal of the wall
template <typename scalar_t>
scalar_t cosine(const Wall<scalar_t>& wall, const torch::Tensor& dir) {
  return SCALAR(dir.dot(wall.normal) / dir.norm());
}

////////////////////////////////////////////////////////////////////////////////
// Room (multiple walls) and interactions
////////////////////////////////////////////////////////////////////////////////

/// Creates a shoebox room consists of multiple walls.
/// Normals are vectors facing *outwards* the room, and origins are arbitrary
/// corners of each wall.
///
/// Note:
/// The wall has to be ordered in the following way:
/// - parallel walls are next (W/E, S/N, and F/C)
/// - The one closer to the origin must come first. (W -> E, S -> N, F -> C)
/// - The order of wall pair must be W/E, S/N, then F/C because
///   `find_collision_wall` will search in the order x, y, z and
///   wall pairs must be distibguishable on these axis.

/// 3D room
template <typename T>
80
const std::array<Wall<T>, 6> make_room(
moto-meta's avatar
moto-meta committed
81
82
83
    const T& w,
    const T& l,
    const T& h,
84
85
86
87
88
89
90
91
92
    const torch::Tensor& abs,
    const torch::Tensor& scat) {
  using namespace torch::indexing;
#define SLICE(x, i) x.index({Slice(), i})
  return {
      Wall<T>({0, l, 0}, {-1, 0, 0}, SLICE(abs, 0), SLICE(scat, 0)), // West
      Wall<T>({w, 0, 0}, {1, 0, 0}, SLICE(abs, 1), SLICE(scat, 1)), // East
      Wall<T>({0, 0, 0}, {0, -1, 0}, SLICE(abs, 2), SLICE(scat, 2)), // South
      Wall<T>({w, l, 0}, {0, 1, 0}, SLICE(abs, 3), SLICE(scat, 3)), // North
moto's avatar
moto committed
93
94
      Wall<T>({w, 0, 0}, {0, 0, -1}, SLICE(abs, 4), SLICE(scat, 4)), // Floor
      Wall<T>({w, 0, h}, {0, 0, 1}, SLICE(abs, 5), SLICE(scat, 5)) // Ceiling
95
96
97
98
99
100
101
102
103
104
105
  };
#undef SLICE
}

/// Find a wall that the given ray hits.
/// The room is assumed to be shoebox room and the walls are constructed
/// in the order used in `make_room`.
/// The room is shoebox-shape and the ray travels infinite distance
/// so that it does hit one of the walls.
/// See also:
/// https://github.com/LCAV/pyroomacoustics/blob/df8af24c88a87b5d51c6123087cd3cd2d361286a/pyroomacoustics/libroom_src/room.cpp#L609-L716
moto's avatar
moto committed
106
template <typename scalar_t>
107
108
109
110
111
112
113
114
115
std::tuple<torch::Tensor, int, scalar_t> find_collision_wall(
    const torch::Tensor& room,
    const torch::Tensor& origin,
    const torch::Tensor& direction // Unit-vector
) {
#define BOOL(x) torch::all(x).template item<bool>()
#define INSIDE(x, y) (BOOL(-EPS < (x)) && BOOL((x) < (y + EPS)))

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
moto's avatar
moto committed
116
117
      3 == room.size(0),
      "Expected room to be 3 dimension, but received ",
118
119
      room.sizes());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
moto's avatar
moto committed
120
121
      3 == origin.size(0),
      "Expected origin to be 3 dimension, but received ",
122
123
      origin.sizes());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
moto's avatar
moto committed
124
125
      3 == direction.size(0),
      "Expected direction to be 3 dimension, but received ",
126
127
128
129
130
131
132
133
134
135
136
      direction.sizes());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      BOOL(room > 0), "Room size should be greater than zero. Found: ", room);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      INSIDE(origin, room),
      "The origin of ray must be inside the room. Origin: ",
      origin,
      ", room: ",
      room);

  // i is the coordinate in the collision is searched.
moto's avatar
moto committed
137
  for (unsigned int i = 0; i < 3; ++i) {
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
    auto dir0 = SCALAR(direction[i]);
    auto abs_dir0 = std::abs(dir0);

    // If the ray is almost parallel to a plane, then we delegate the
    // computation to the other planes.
    if (abs_dir0 < EPS) {
      continue;
    }

    // Check the distance to the facing wall along the coordinate.
    scalar_t distance = (dir0 < 0.)
        ? SCALAR(origin[i]) // Going towards origin
        : SCALAR(room[i] - origin[i]); // Going away from origin
    auto ratio = distance / abs_dir0;
    int i_increment = dir0 > 0.;

    // Compute the intersection of ray and the wall
    auto intersection = origin + ratio * direction;

    // The intersection can be within the room or outside.
    // If it's inside, the collision point is found.
    //      ^
    //      |           |   Not Good
    //   ---+-----------+---x----
    //      |           |  /
    //      |           | /
    //      |           |/
    //      |           x  Found
    //      |          /|
    //      |         / |
    //      |        o  |
    //      |           |
    //   ---+-----------+-------->
    //     O|           |
    //

    if (INSIDE(intersection, room)) {
      int i_wall = 2 * i + i_increment;
      auto dist = SCALAR((intersection - origin).norm());
      return std::make_tuple(intersection, i_wall, dist);
    }
  }
  // This should not happen
  TORCH_INTERNAL_ASSERT(
      false,
      "Failed to find the intersection. room: ",
      room,
      " origin: ",
      origin,
      " direction: ",
      direction);
#undef INSIDE
#undef BOOL
}
} // namespace rir
} // namespace torchaudio

#undef EPS
#undef SCALAR