#include <gtest/gtest.h>
#include <libtorchaudio/rir/wall.h>

using namespace torchaudio::rir;

// Scalar type used for the expected values in this file and as the template
// argument passed to find_collision_wall.
using DTYPE = double;

struct CollisionTestParam {
  // Input
  torch::Tensor origin;
  torch::Tensor direction;
  // Expected
  torch::Tensor hit_point;
  int next_wall_index;
moto's avatar
moto committed
15
  DTYPE hit_distance;
16
17
18
};

// Builds a CollisionTestParam from plain lists of numbers.
//
// `origin`, `direction` and `hit_point` are converted to tensors with
// torch::tensor. `direction` does not have to be unit length: it is divided
// by its norm before being stored, so the test tables can use convenient
// values such as {1, -1, 0}.
// NOTE(review): a zero-length `direction` would divide by a zero norm; none of
// the tables in this file pass one.
CollisionTestParam par(
    torch::ArrayRef<DTYPE> origin,
    torch::ArrayRef<DTYPE> direction,
    torch::ArrayRef<DTYPE> hit_point,
    int next_wall_index,
    DTYPE hit_distance) {
  auto dir = torch::tensor(direction);
  return {
      torch::tensor(origin),
      dir / dir.norm(),
      torch::tensor(hit_point),
      next_wall_index,
      hit_distance};
}

//////////////////////////////////////////////////////////////////////////////
// 3D test
//////////////////////////////////////////////////////////////////////////////

// Value-parameterized fixture: every instantiation supplies one
// CollisionTestParam, which the test body reads through GetParam().
class Simple3DRoomCollisionTest
    : public ::testing::TestWithParam<CollisionTestParam> {
};

// Runs find_collision_wall on the parameter's ray inside a 1 x 1 x 1 room and
// checks the reported wall index, hit distance and hit point against the
// expected values carried by the parameter.
TEST_P(Simple3DRoomCollisionTest, CollisionTest3D) {
  //  y                       z
  //  ^                       ^
  //  |        3              |      y
  //  |     ______            |     /
  //  |    |      |           |    /
  //  |  0 |      | 1         |  ______
  //  |    |______|           | /     /  4: floor, 5: ceiling
  //  |        2              |/     /
  // -+----------------> x   -+--------------> x
  //
  auto room = torch::tensor({1, 1, 1});

  // GetParam() hands back a reference; bind to it instead of copying the
  // tensors held by the parameter.
  const auto& param = GetParam();
  auto [hit_point, next_wall_index, hit_distance] =
      find_collision_wall<DTYPE>(room, param.origin, param.direction);

  EXPECT_EQ(param.next_wall_index, next_wall_index);
  // The distance is compared with the single-precision ULP-based matcher,
  // which is looser than an exact double comparison.
  EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
  // Compare the hit point coordinate by coordinate (x, y, z).
  for (int axis = 0; axis < 3; ++axis) {
    EXPECT_NEAR(
        param.hit_point[axis].item<DTYPE>(),
        hit_point[axis].item<DTYPE>(),
        1e-5)
        << "hit_point mismatch on axis " << axis;
  }
}

// 1 / sqrt(2), written out to full double precision. In the test tables this
// is the expected distance between the centres of two adjacent walls of the
// unit room (e.g. from {0, .5, .5} to {.5, 0, .5}).
// A typed constant is used instead of a preprocessor macro so the name obeys
// normal scoping rules.
constexpr double ISQRT2 = 0.70710678118654752440;

// Rays that start at the centre of one wall of the unit room and end at the
// centre of another wall. Each row is
//   par(origin, direction, expected hit point, expected wall index,
//       expected hit distance).
// Wall indices as used by the rows below: 0 -> x = 0, 1 -> x = 1,
// 2 -> y = 0, 3 -> y = 1, 4 -> z = 0 (floor), 5 -> z = 1 (ceiling).
// NOTE(review): GoogleTest >= 1.10 deprecates INSTANTIATE_TEST_CASE_P in
// favour of INSTANTIATE_TEST_SUITE_P; confirm the pinned version before
// migrating.
INSTANTIATE_TEST_CASE_P(
    BasicCollisionTests,
    Simple3DRoomCollisionTest,
    ::testing::Values(
        // From 0
        par({0, .5, .5}, {1.0, 0.0, 0.0}, {1., .5, .5}, 1, 1.0),
        par({0, .5, .5}, {1.0, -1., 0.0}, {.5, .0, .5}, 2, ISQRT2),
        par({0, .5, .5}, {1.0, 1.0, 0.0}, {.5, 1., .5}, 3, ISQRT2),
        par({0, .5, .5}, {1.0, 0.0, -1.}, {.5, .5, .0}, 4, ISQRT2),
        par({0, .5, .5}, {1.0, 0.0, 1.0}, {.5, .5, 1.}, 5, ISQRT2),
        // From 1
        par({1, .5, .5}, {-1., 0.0, 0.0}, {.0, .5, .5}, 0, 1.0),
        par({1, .5, .5}, {-1., -1., 0.0}, {.5, .0, .5}, 2, ISQRT2),
        par({1, .5, .5}, {-1., 1.0, 0.0}, {.5, 1., .5}, 3, ISQRT2),
        par({1, .5, .5}, {-1., 0.0, -1.}, {.5, .5, .0}, 4, ISQRT2),
        par({1, .5, .5}, {-1., 0.0, 1.0}, {.5, .5, 1.}, 5, ISQRT2),
        // From 2
        par({.5, 0, .5}, {-1., 1.0, 0.0}, {.0, .5, .5}, 0, ISQRT2),
        par({.5, 0, .5}, {1.0, 1.0, 0.0}, {1., .5, .5}, 1, ISQRT2),
        par({.5, 0, .5}, {0.0, 1.0, 0.0}, {.5, 1., .5}, 3, 1.0),
        par({.5, 0, .5}, {0.0, 1.0, -1.}, {.5, .5, .0}, 4, ISQRT2),
        par({.5, 0, .5}, {0.0, 1.0, 1.0}, {.5, .5, 1.}, 5, ISQRT2),
        // From 3
        par({.5, 1, .5}, {-1., -1., 0.0}, {.0, .5, .5}, 0, ISQRT2),
        par({.5, 1, .5}, {1.0, -1., 0.0}, {1., .5, .5}, 1, ISQRT2),
        par({.5, 1, .5}, {0.0, -1., 0.0}, {.5, .0, .5}, 2, 1.0),
        par({.5, 1, .5}, {0.0, -1., -1.}, {.5, .5, .0}, 4, ISQRT2),
        par({.5, 1, .5}, {0.0, -1., 1.0}, {.5, .5, 1.}, 5, ISQRT2),
        // From 4
        par({.5, .5, 0}, {-1., 0.0, 1.0}, {.0, .5, .5}, 0, ISQRT2),
        par({.5, .5, 0}, {1.0, 0.0, 1.0}, {1., .5, .5}, 1, ISQRT2),
        par({.5, .5, 0}, {0.0, -1., 1.0}, {.5, .0, .5}, 2, ISQRT2),
        par({.5, .5, 0}, {0.0, 1.0, 1.0}, {.5, 1., .5}, 3, ISQRT2),
        par({.5, .5, 0}, {0.0, 0.0, 1.0}, {.5, .5, 1.}, 5, 1.0),
        // From 5
        par({.5, .5, 1}, {-1., 0.0, -1.}, {.0, .5, .5}, 0, ISQRT2),
        par({.5, .5, 1}, {1.0, 0.0, -1.}, {1., .5, .5}, 1, ISQRT2),
        par({.5, .5, 1}, {0.0, -1., -1.}, {.5, .0, .5}, 2, ISQRT2),
        par({.5, .5, 1}, {0.0, 1.0, -1.}, {.5, 1., .5}, 3, ISQRT2),
        par({.5, .5, 1}, {0.0, 0.0, -1.}, {.5, .5, .0}, 4, 1.0)));

// Rays that start exactly on a corner of the unit room. In every row the
// expected hit point equals the origin and the expected distance is 0.0.
// The expected wall index in each row is the x = 1 / y = 1 / z = 1 wall
// (1 / 3 / 5) of the first axis, in x, y, z order, along which the direction
// points out of the room.
// NOTE(review): this presumably mirrors the tie-breaking order inside
// find_collision_wall - confirm against its implementation.
INSTANTIATE_TEST_CASE_P(
    CornerCollisionTest,
    Simple3DRoomCollisionTest,
    ::testing::Values(
        // Corner (1, 1, 0): +x leaves the room first -> wall 1.
        par({1, 1, 0}, {1., 1., 0.}, {1., 1., 0.}, 1, 0.0),
        // Corner (1, 1, 0): x points inward, +y leaves the room -> wall 3.
        par({1, 1, 0}, {-1., 1., 0.}, {1., 1., 0.}, 3, 0.0),
        // Corner (1, 1, 1): all components point outward -> wall 1.
        par({1, 1, 1}, {1., 1., 1.}, {1., 1., 1.}, 1, 0.0),
        // Corner (1, 1, 1): x points inward, +y leaves the room -> wall 3.
        par({1, 1, 1}, {-1., 1., 1.}, {1., 1., 1.}, 3, 0.0),
        // Corner (1, 1, 1): only +z leaves the room -> wall 5.
        par({1, 1, 1}, {-1., -1., 1.}, {1., 1., 1.}, 5, 0.0)));