wall_collision.cpp 5.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#include <cstdint>
#include <ostream>

#include <gtest/gtest.h>
#include <torchaudio/csrc/rir/wall.h>

using namespace torchaudio::rir;

struct CollisionTestParam {
  // Input
  torch::Tensor origin;
  torch::Tensor direction;
  // Expected
  torch::Tensor hit_point;
  int next_wall_index;
  float hit_distance;
};

// Builds a CollisionTestParam from plain float lists.
//
// `direction` need not be unit length: it is normalized here, so the test
// tables can write e.g. {1, -1} instead of {1/sqrt(2), -1/sqrt(2)}. It must
// not be the zero vector (normalizing would divide by a zero norm).
//
// `static` gives this file-local helper internal linkage, so it cannot clash
// with a same-named helper in another test translation unit.
static CollisionTestParam par(
    torch::ArrayRef<float> origin,
    torch::ArrayRef<float> direction,
    torch::ArrayRef<float> hit_point,
    int next_wall_index,
    float hit_distance) {
  const auto dir = torch::tensor(direction);
  return {
      torch::tensor(origin),
      dir / dir.norm(),
      torch::tensor(hit_point),
      next_wall_index,
      hit_distance};
}

//////////////////////////////////////////////////////////////////////////////
// 2D test
//////////////////////////////////////////////////////////////////////////////

// Value-parameterized fixture for the 2-D (unit square) collision cases.
struct Simple2DRoomCollisionTest
    : ::testing::TestWithParam<CollisionTestParam> {};

TEST_P(Simple2DRoomCollisionTest, CollisionTest2D) {
  //
  //  ^
  //  |        3
  //  |     ______
  //  |    |      |
  //  |  0 |      | 1
  //  |    |______|
  //  |        2
  // -+---------------->
  //
  // Unit square room: walls 0 / 1 lie on x = 0 / x = 1,
  // walls 2 / 3 lie on y = 0 / y = 1.
  const auto room = torch::tensor({1, 1});

  // GetParam() returns a const reference; binding to it avoids copying the
  // tensor handles held by the parameter (each copy bumps a refcount).
  const auto& param = GetParam();
  const auto [hit_point, next_wall_index, hit_distance] =
      find_collision_wall<float, 2>(room, param.origin, param.direction);

  EXPECT_EQ(param.next_wall_index, next_wall_index);
  EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
  EXPECT_TRUE(torch::allclose(
      param.hit_point, hit_point, /*rtol=*/1e-05, /*atol=*/1e-07))
      << "expected hit point:\n"
      << param.hit_point << "\nactual hit point:\n"
      << hit_point;
}

// 1/sqrt(2): distance a ray travels from a wall midpoint at 45 degrees to the
// midpoint of an adjacent wall in the unit room (sqrt(0.5^2 + 0.5^2)).
// A typed constant rather than a macro: scoped, type-checked, debuggable.
constexpr double ISQRT2 = 0.70710678118654752440;

// Each row: par(origin, direction, expected hit point, expected wall index,
// expected hit distance). Directions are normalized inside par(), so they are
// written unnormalized here. Every ray starts at the midpoint of a wall.
// INSTANTIATE_TEST_SUITE_P supersedes INSTANTIATE_TEST_CASE_P, which
// GoogleTest 1.10+ marks as deprecated.
INSTANTIATE_TEST_SUITE_P(
    Collision2DTests,
    Simple2DRoomCollisionTest,
    ::testing::Values(
        // From 0 (midpoint of wall x = 0)
        par({0.0, 0.5}, {1.0, 0.0}, {1.0, 0.5}, 1, 1.0),
        par({0.0, 0.5}, {1.0, -1.}, {0.5, 0.0}, 2, ISQRT2),
        par({0.0, 0.5}, {1.0, 1.0}, {0.5, 1.0}, 3, ISQRT2),
        // From 1 (midpoint of wall x = 1)
        par({1.0, 0.5}, {-1., 0.0}, {0.0, 0.5}, 0, 1.0),
        par({1.0, 0.5}, {-1., -1.}, {0.5, 0.0}, 2, ISQRT2),
        par({1.0, 0.5}, {-1., 1.0}, {0.5, 1.0}, 3, ISQRT2),
        // From 2 (midpoint of wall y = 0)
        par({0.5, 0.0}, {-1., 1.0}, {0.0, 0.5}, 0, ISQRT2),
        par({0.5, 0.0}, {1.0, 1.0}, {1.0, 0.5}, 1, ISQRT2),
        par({0.5, 0.0}, {0.0, 1.0}, {0.5, 1.0}, 3, 1.0),
        // From 3 (midpoint of wall y = 1)
        par({0.5, 1.0}, {-1., -1.}, {0.0, 0.5}, 0, ISQRT2),
        par({0.5, 1.0}, {1.0, -1.}, {1.0, 0.5}, 1, ISQRT2),
        par({0.5, 1.0}, {0.0, -1.}, {0.5, 0.0}, 2, 1.0)));

//////////////////////////////////////////////////////////////////////////////
// 3D test
//////////////////////////////////////////////////////////////////////////////

// Value-parameterized fixture for the 3-D (unit cube) collision cases.
struct Simple3DRoomCollisionTest
    : ::testing::TestWithParam<CollisionTestParam> {};

TEST_P(Simple3DRoomCollisionTest, CollisionTest3D) {
  //  y                       z
  //  ^                       ^
  //  |        3              |      y
  //  |     ______            |     /
  //  |    |      |           |    /
  //  |  0 |      | 1         |  ______
  //  |    |______|           | /     /  4: floor, 5: ceiling
  //  |        2              |/     /
  // -+----------------> x   -+--------------> x
  //
  // Unit cube room: walls 0 / 1 lie on x = 0 / x = 1, walls 2 / 3 on
  // y = 0 / y = 1, wall 4 (floor) on z = 0 and wall 5 (ceiling) on z = 1.
  const auto room = torch::tensor({1, 1, 1});

  // GetParam() returns a const reference; binding to it avoids copying the
  // tensor handles held by the parameter (each copy bumps a refcount).
  const auto& param = GetParam();
  const auto [hit_point, next_wall_index, hit_distance] =
      find_collision_wall<float, 3>(room, param.origin, param.direction);

  EXPECT_EQ(param.next_wall_index, next_wall_index);
  EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
  EXPECT_TRUE(torch::allclose(
      param.hit_point, hit_point, /*rtol=*/1e-05, /*atol=*/1e-07))
      << "expected hit point:\n"
      << param.hit_point << "\nactual hit point:\n"
      << hit_point;
}

// Each row: par(origin, direction, expected hit point, expected wall index,
// expected hit distance). Directions are normalized inside par(), so they are
// written unnormalized here. Every ray starts at the center of a wall.
// INSTANTIATE_TEST_SUITE_P supersedes INSTANTIATE_TEST_CASE_P, which
// GoogleTest 1.10+ marks as deprecated.
INSTANTIATE_TEST_SUITE_P(
    Collision3DTests,
    Simple3DRoomCollisionTest,
    ::testing::Values(
        // From 0 (center of wall x = 0)
        par({0, .5, .5}, {1.0, 0.0, 0.0}, {1., .5, .5}, 1, 1.0),
        par({0, .5, .5}, {1.0, -1., 0.0}, {.5, .0, .5}, 2, ISQRT2),
        par({0, .5, .5}, {1.0, 1.0, 0.0}, {.5, 1., .5}, 3, ISQRT2),
        par({0, .5, .5}, {1.0, 0.0, -1.}, {.5, .5, .0}, 4, ISQRT2),
        par({0, .5, .5}, {1.0, 0.0, 1.0}, {.5, .5, 1.}, 5, ISQRT2),
        // From 1 (center of wall x = 1)
        par({1, .5, .5}, {-1., 0.0, 0.0}, {.0, .5, .5}, 0, 1.0),
        par({1, .5, .5}, {-1., -1., 0.0}, {.5, .0, .5}, 2, ISQRT2),
        par({1, .5, .5}, {-1., 1.0, 0.0}, {.5, 1., .5}, 3, ISQRT2),
        par({1, .5, .5}, {-1., 0.0, -1.}, {.5, .5, .0}, 4, ISQRT2),
        par({1, .5, .5}, {-1., 0.0, 1.0}, {.5, .5, 1.}, 5, ISQRT2),
        // From 2 (center of wall y = 0)
        par({.5, 0, .5}, {-1., 1.0, 0.0}, {.0, .5, .5}, 0, ISQRT2),
        par({.5, 0, .5}, {1.0, 1.0, 0.0}, {1., .5, .5}, 1, ISQRT2),
        par({.5, 0, .5}, {0.0, 1.0, 0.0}, {.5, 1., .5}, 3, 1.0),
        par({.5, 0, .5}, {0.0, 1.0, -1.}, {.5, .5, .0}, 4, ISQRT2),
        par({.5, 0, .5}, {0.0, 1.0, 1.0}, {.5, .5, 1.}, 5, ISQRT2),
        // From 3 (center of wall y = 1)
        par({.5, 1, .5}, {-1., -1., 0.0}, {.0, .5, .5}, 0, ISQRT2),
        par({.5, 1, .5}, {1.0, -1., 0.0}, {1., .5, .5}, 1, ISQRT2),
        par({.5, 1, .5}, {0.0, -1., 0.0}, {.5, .0, .5}, 2, 1.0),
        par({.5, 1, .5}, {0.0, -1., -1.}, {.5, .5, .0}, 4, ISQRT2),
        par({.5, 1, .5}, {0.0, -1., 1.0}, {.5, .5, 1.}, 5, ISQRT2),
        // From 4 (center of the floor, z = 0)
        par({.5, .5, 0}, {-1., 0.0, 1.0}, {.0, .5, .5}, 0, ISQRT2),
        par({.5, .5, 0}, {1.0, 0.0, 1.0}, {1., .5, .5}, 1, ISQRT2),
        par({.5, .5, 0}, {0.0, -1., 1.0}, {.5, .0, .5}, 2, ISQRT2),
        par({.5, .5, 0}, {0.0, 1.0, 1.0}, {.5, 1., .5}, 3, ISQRT2),
        par({.5, .5, 0}, {0.0, 0.0, 1.0}, {.5, .5, 1.}, 5, 1.0),
        // From 5 (center of the ceiling, z = 1)
        par({.5, .5, 1}, {-1., 0.0, -1.}, {.0, .5, .5}, 0, ISQRT2),
        par({.5, .5, 1}, {1.0, 0.0, -1.}, {1., .5, .5}, 1, ISQRT2),
        par({.5, .5, 1}, {0.0, -1., -1.}, {.5, .0, .5}, 2, ISQRT2),
        par({.5, .5, 1}, {0.0, 1.0, -1.}, {.5, 1., .5}, 3, ISQRT2),
        par({.5, .5, 1}, {0.0, 0.0, -1.}, {.5, .5, .0}, 4, 1.0)));