Unverified Commit fa78fb64 authored by moto's avatar moto Committed by GitHub
Browse files

Add Ray Tracing (#3604) (#2850) (#3655)

Summary:
Revamped version of https://github.com/pytorch/audio/pull/3234
(which was also revamp of https://github.com/pytorch/audio/pull/2850)
parent dde08ba1
......@@ -35,4 +35,5 @@ Room Impulse Response Simulation
:toctree: generated
:nosignatures:
ray_tracing
simulate_rir_ism
......@@ -220,9 +220,13 @@ class RayTracer {
if (NORM(to_mic - dir * impact_distance) < mic_radius + EPS) {
// The length of this last hop
auto travel_dist_at_mic = travel_dist + std::abs(impact_distance);
auto bin_idx = get_bin_idx(travel_dist_at_mic);
if (bin_idx >= histograms.size(1)) {
continue;
}
auto coeff = get_energy_coeff(travel_dist_at_mic, mic_radius_sq);
auto energy = energies / coeff;
histograms[mic_idx][get_bin_idx(travel_dist_at_mic)] += energy;
histograms[mic_idx][bin_idx] += energy;
}
}
}
......@@ -230,7 +234,7 @@ class RayTracer {
travel_dist += hit_distance;
energies *= wall.reflection;
// Let's shoot the scattered ray induced by the rebound on the wall
// Let's shoot the scattered ray induced by the rebound on the wall
if (do_scattering) {
scat_ray(histograms, wall, energies, origin, hit_point, travel_dist);
energies *= (1. - wall.scattering);
......
......@@ -19,7 +19,6 @@ struct Wall {
const torch::Tensor origin;
const torch::Tensor normal;
const torch::Tensor scattering;
const torch::Tensor reflection;
Wall(
......@@ -27,8 +26,8 @@ struct Wall {
const torch::ArrayRef<scalar_t>& normal,
const torch::Tensor& absorption,
const torch::Tensor& scattering)
: origin(torch::tensor(origin)),
normal(torch::tensor(normal)),
: origin(torch::tensor(origin).to(scattering.dtype())),
normal(torch::tensor(normal).to(scattering.dtype())),
scattering(scattering),
reflection(1. - absorption) {}
};
......@@ -137,7 +136,6 @@ std::tuple<torch::Tensor, int, scalar_t> find_collision_wall(
for (unsigned int i = 0; i < 3; ++i) {
auto dir0 = SCALAR(direction[i]);
auto abs_dir0 = std::abs(dir0);
// If the ray is almost parallel to a plane, then we delegate the
// computation to the other planes.
if (abs_dir0 < EPS) {
......@@ -148,6 +146,10 @@ std::tuple<torch::Tensor, int, scalar_t> find_collision_wall(
scalar_t distance = (dir0 < 0.)
? SCALAR(origin[i]) // Going towards origin
: SCALAR(room[i] - origin[i]); // Going away from origin
// sometimes origin is slightly outside of room
if (distance < 0) {
distance = 0.;
}
auto ratio = distance / abs_dir0;
int i_increment = dir0 > 0.;
......
......@@ -7,7 +7,7 @@ from ._dsp import (
oscillator_bank,
sinc_impulse_response,
)
from ._rir import simulate_rir_ism
from ._rir import ray_tracing, simulate_rir_ism
from .functional import barkscale_fbanks, chroma_filterbank
......@@ -20,6 +20,7 @@ __all__ = [
"filter_waveform",
"frequency_impulse_response",
"oscillator_bank",
"ray_tracing",
"sinc_impulse_response",
"simulate_rir_ism",
]
......@@ -133,20 +133,24 @@ def _adjust_coeff(coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor
"""
num_walls = 6
if isinstance(coeffs, float):
if coeffs < 0:
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
return torch.full((1, num_walls), coeffs)
if isinstance(coeffs, Tensor):
if torch.any(coeffs < 0):
raise ValueError(f"`{name}` must be non-negative. Found: {coeffs}")
if coeffs.ndim == 1:
if coeffs.numel() != num_walls:
raise ValueError(
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor."
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor. "
f"Found the shape {coeffs.shape}."
)
return coeffs.unsqueeze(0)
if coeffs.ndim == 2:
if coeffs.shape != (7, num_walls):
if coeffs.shape[1] != num_walls:
raise ValueError(
f"The shape of `{name}` must be (7, {num_walls}) when it is a 2D Tensor."
f"Found the shape {coeffs.shape}."
f"The shape of `{name}` must be (NUM_BANDS, {num_walls}) when it "
f"is a 2D Tensor. Found: {coeffs.shape}."
)
return coeffs
raise TypeError(f"`{name}` must be float or Tensor.")
......@@ -169,7 +173,7 @@ def _validate_inputs(
if not (source.ndim == 1 and source.numel() == 3):
raise ValueError(f"`source` must be 1D Tensor with 3 elements. Found {source.shape}.")
if not (mic_array.ndim == 2 and mic_array.shape[1] == 3):
raise ValueError(f"mic_array must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
raise ValueError(f"`mic_array` must be a 2D Tensor with shape (num_channels, 3). Found {mic_array.shape}.")
def simulate_rir_ism(
......@@ -270,3 +274,106 @@ def simulate_rir_ism(
rir = rir[..., :output_length]
return rir
def ray_tracing(
room: torch.Tensor,
source: torch.Tensor,
mic_array: torch.Tensor,
num_rays: int,
absorption: Union[float, torch.Tensor] = 0.0,
scattering: Union[float, torch.Tensor] = 0.0,
mic_radius: float = 0.5,
sound_speed: float = 343.0,
energy_thres: float = 1e-7,
time_thres: float = 10.0,
hist_bin_size: float = 0.004,
) -> torch.Tensor:
r"""Compute energy histogram via ray tracing.
The implementation is based on *pyroomacoustics* :cite:`scheibler2018pyroomacoustics`.
``num_rays`` rays are casted uniformly in all directions from the source;
when a ray intersects a wall, it is reflected and part of its energy is absorbed.
It is also scattered (sent directly to the microphone(s)) according to the ``scattering``
coefficient.
When a ray is close to the microphone, its current energy is recorded in the output
histogram for that given time slot.
.. devices:: CPU
.. properties:: TorchScript
Args:
room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
three dimensions of the room.
source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
absorption (float or torch.Tensor, optional): The absorption coefficients of wall materials.
(Default: ``0.0``).
If the type is ``float``, the absorption coefficient is identical to all walls and
all frequencies.
If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, representing absorption
coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, and
``"ceiling"``, respectively.
If ``absorption`` is a 2D Tensor, the shape must be `(num_bands, 6)`.
``num_bands`` is the number of frequency bands (usually 7).
scattering(float or torch.Tensor, optional): The scattering coefficients of wall materials. (Default: ``0.0``)
The shape and type of this parameter is the same as for ``absorption``.
mic_radius(float, optional): The radius of the microphone in meters. (Default: 0.5)
sound_speed (float, optional): The speed of sound in meters per second. (Default: ``343.0``)
energy_thres (float, optional): The energy level below which we stop tracing a ray. (Default: ``1e-7``)
The initial energy of each ray is ``2 / num_rays``.
time_thres (float, optional): The maximal duration for which rays are traced. (Unit: seconds) (Default: 10.0)
hist_bin_size (float, optional): The size of each bin in the output histogram. (Unit: seconds) (Default: 0.004)
Returns:
(torch.Tensor): The 3D histogram(s) where the energy of the traced ray is recorded.
Each bin corresponds to a given time slot.
The shape is `(channel, num_bands, num_bins)`, where
``num_bins = ceil(time_thres / hist_bin_size)``.
If both ``absorption`` and ``scattering`` are floats, then ``num_bands == 1``.
"""
if time_thres < hist_bin_size:
raise ValueError(
"`time_thres` must be greater than `hist_bin_size`. "
f"Found: hist_bin_size={hist_bin_size}, time_thres={time_thres}."
)
if room.dtype != source.dtype or source.dtype != mic_array.dtype:
raise ValueError(
"dtype of `room`, `source` and `mic_array` must match. "
f"Found: `room` ({room.dtype}), `source` ({source.dtype}) and "
f"`mic_array` ({mic_array.dtype})"
)
_validate_inputs(room, source, mic_array)
absorption = _adjust_coeff(absorption, "absorption").to(room.dtype)
scattering = _adjust_coeff(scattering, "scattering").to(room.dtype)
# Bring absorption and scattering to the same shape
if absorption.shape[0] == 1 and scattering.shape[0] > 1:
absorption = absorption.expand(scattering.shape)
if scattering.shape[0] == 1 and absorption.shape[0] > 1:
scattering = scattering.expand(absorption.shape)
if absorption.shape != scattering.shape:
raise ValueError(
"`absorption` and `scattering` must be broadcastable to the same number of bands and walls. "
f"Inferred shapes absorption={absorption.shape} and scattering={scattering.shape}"
)
histograms = torch.ops.torchaudio.ray_tracing(
room,
source,
mic_array,
num_rays,
absorption,
scattering,
mic_radius,
sound_speed,
energy_thres,
time_thres,
hist_bin_size,
)
return histograms
......@@ -3,6 +3,8 @@
using namespace torchaudio::rir;
using DTYPE = double;
struct CollisionTestParam {
// Input
torch::Tensor origin;
......@@ -10,15 +12,15 @@ struct CollisionTestParam {
// Expected
torch::Tensor hit_point;
int next_wall_index;
float hit_distance;
DTYPE hit_distance;
};
CollisionTestParam par(
torch::ArrayRef<float> origin,
torch::ArrayRef<float> direction,
torch::ArrayRef<float> hit_point,
torch::ArrayRef<DTYPE> origin,
torch::ArrayRef<DTYPE> direction,
torch::ArrayRef<DTYPE> hit_point,
int next_wall_index,
float hit_distance) {
DTYPE hit_distance) {
auto dir = torch::tensor(direction);
return {
torch::tensor(origin),
......@@ -50,18 +52,22 @@ TEST_P(Simple3DRoomCollisionTest, CollisionTest3D) {
auto param = GetParam();
auto [hit_point, next_wall_index, hit_distance] =
find_collision_wall<float>(room, param.origin, param.direction);
find_collision_wall<DTYPE>(room, param.origin, param.direction);
EXPECT_EQ(param.next_wall_index, next_wall_index);
EXPECT_FLOAT_EQ(param.hit_distance, hit_distance);
EXPECT_TRUE(torch::allclose(
param.hit_point, hit_point, /*rtol*/ 1e-05, /*atol*/ 1e-07));
EXPECT_NEAR(
param.hit_point[0].item<DTYPE>(), hit_point[0].item<DTYPE>(), 1e-5);
EXPECT_NEAR(
param.hit_point[1].item<DTYPE>(), hit_point[1].item<DTYPE>(), 1e-5);
EXPECT_NEAR(
param.hit_point[2].item<DTYPE>(), hit_point[2].item<DTYPE>(), 1e-5);
}
#define ISQRT2 0.70710678118
INSTANTIATE_TEST_CASE_P(
Collision3DTests,
BasicCollisionTests,
Simple3DRoomCollisionTest,
::testing::Values(
// From 0
......@@ -100,3 +106,13 @@ INSTANTIATE_TEST_CASE_P(
par({.5, .5, 1}, {0.0, -1., -1.}, {.5, .0, .5}, 2, ISQRT2),
par({.5, .5, 1}, {0.0, 1.0, -1.}, {.5, 1., .5}, 3, ISQRT2),
par({.5, .5, 1}, {0.0, 0.0, -1.}, {.5, .5, .0}, 4, 1.0)));
INSTANTIATE_TEST_CASE_P(
CornerCollisionTest,
Simple3DRoomCollisionTest,
::testing::Values(
par({1, 1, 0}, {1., 1., 0.}, {1., 1., 0.}, 1, 0.0),
par({1, 1, 0}, {-1., 1., 0.}, {1., 1., 0.}, 3, 0.0),
par({1, 1, 1}, {1., 1., 1.}, {1., 1., 1.}, 1, 0.0),
par({1, 1, 1}, {-1., 1., 1.}, {1., 1., 1.}, 3, 0.0),
par({1, 1, 1}, {-1., -1., 1.}, {1., 1., 1.}, 5, 0.0)));
......@@ -412,6 +412,260 @@ class FunctionalTestImpl(TestBaseMixin):
self.assertEqual(torch_out, torch.tensor(np_out))
@parameterized.expand(
[
# both float
(0.1, 0.2, (2, 1, 2500)),
# Per-wall
((6,), 0.2, (2, 1, 2500)),
(0.1, (6,), (2, 1, 2500)),
((6,), (6,), (2, 1, 2500)),
# Per-band and per-wall
((3, 6), 0.2, (2, 3, 2500)),
(0.1, (5, 6), (2, 5, 2500)),
((7, 6), (7, 6), (2, 7, 2500)),
]
)
def test_ray_tracing_output_shape(self, abs_, scat_, expected_shape):
if isinstance(abs_, float):
absorption = abs_
else:
absorption = torch.rand(abs_, dtype=self.dtype)
if isinstance(scat_, float):
scattering = scat_
else:
scattering = torch.rand(scat_, dtype=self.dtype)
room_dim = torch.tensor([3, 4, 5], dtype=self.dtype)
mic_array = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=self.dtype)
source = torch.tensor([1, 2, 3], dtype=self.dtype)
num_rays = 100
hist = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
)
assert hist.shape == expected_shape
def test_ray_tracing_input_errors(self):
room = torch.tensor([3.0, 4.0, 5.0], dtype=self.dtype)
source = torch.tensor([0.0, 0.0, 0.0], dtype=self.dtype)
mic = torch.tensor([[1.0, 2.0, 3.0]], dtype=self.dtype)
# baseline. This should not raise
_ = F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10)
# invlaid room shape
for invalid in ([[4, 5]], [4, 5, 4, 5]):
invalid = torch.tensor(invalid, dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=invalid, source=source, mic_array=mic, num_rays=10)
error = str(cm.exception)
self.assertIn("`room` must be a 1D Tensor with 3 elements.", error)
self.assertIn(str(invalid.shape), error)
# invalid microphone shape
invalid = torch.tensor([[[3, 4]]], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=invalid, num_rays=10)
error = str(cm.exception)
self.assertIn("`mic_array` must be a 2D Tensor with shape (num_channels, 3).", error)
self.assertIn(str(invalid.shape), error)
# incompatible dtypes
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room.to(torch.float64),
source=source.to(torch.float32),
mic_array=mic.to(torch.float32),
num_rays=10,
)
error = str(cm.exception)
self.assertIn("dtype of `room`, `source` and `mic_array` must match.", error)
self.assertIn("`room` (torch.float64)", error)
self.assertIn("`source` (torch.float32)", error)
self.assertIn("`mic_array` (torch.float32)", error)
# invalid time configuration
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
time_thres=10,
hist_bin_size=11,
)
error = str(cm.exception)
self.assertIn("`time_thres` must be greater than `hist_bin_size`.", error)
self.assertIn("hist_bin_size=11", error)
self.assertIn("time_thres=10", error)
# invalid absorption shape 1D
invalid_abs = torch.tensor([1, 2, 3], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=invalid_abs,
)
error = str(cm.exception)
self.assertIn("The shape of `absorption` must be (6,) when", error)
self.assertIn(str(invalid_abs.shape), error)
# invalid absorption shape 2D
invalid_abs = torch.tensor([[1, 2, 3]], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_abs)
error = str(cm.exception)
self.assertIn("The shape of `absorption` must be (NUM_BANDS, 6) when", error)
self.assertIn(str(invalid_abs.shape), error)
# invalid scattering shape 1D
invalid_scat = torch.tensor([1, 2, 3], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
scattering=invalid_scat,
)
error = str(cm.exception)
self.assertIn("The shape of `scattering` must be (6,) when", error)
self.assertIn(str(invalid_scat.shape), error)
# invalid scattering shape 2D
invalid_scat = torch.tensor([[1, 2, 3]], dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_scat)
error = str(cm.exception)
self.assertIn("The shape of `scattering` must be (NUM_BANDS, 6) when", error)
self.assertIn(str(invalid_scat.shape), error)
# Invalid absorption value
for invalid_val in [-1.0, torch.tensor([i - 1.0 for i in range(6)])]:
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, absorption=invalid_val)
error = str(cm.exception)
self.assertIn("`absorption` must be non-negative`")
# Invalid scattering value
for invalid_val in [-1.0, torch.tensor([i - 1.0 for i in range(6)])]:
with self.assertRaises(ValueError) as cm:
F.ray_tracing(room=room, source=source, mic_array=mic, num_rays=10, scattering=invalid_val)
error = str(cm.exception)
self.assertIn("`scattering` must be non-negative`")
# incompatible scattering and absorption
abs_ = torch.zeros((7, 6), dtype=self.dtype)
scat = torch.zeros((5, 6), dtype=self.dtype)
with self.assertRaises(ValueError) as cm:
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=abs_,
scattering=scat,
)
error = str(cm.exception)
self.assertIn(
"`absorption` and `scattering` must be broadcastable to the same number of bands and walls", error
)
self.assertIn(f"absorption={abs_.shape}", error)
self.assertIn(f"scattering={scat.shape}", error)
# Make sure passing different shapes for absorption or scattering doesn't raise an error
# float and tensor
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=0.1,
scattering=torch.rand((5, 6), dtype=self.dtype),
)
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=torch.rand((7, 6), dtype=self.dtype),
scattering=0.1,
)
# per-wall only and per-band + per-wall
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=torch.rand(6, dtype=self.dtype),
scattering=torch.rand(7, 6, dtype=self.dtype),
)
F.ray_tracing(
room=room,
source=source,
mic_array=mic,
num_rays=10,
absorption=torch.rand(7, 6, dtype=self.dtype),
scattering=torch.rand(6, dtype=self.dtype),
)
def test_ray_tracing_per_band_per_wall_absorption(self):
"""Check that when the value of absorption and scattering are the same
across walls and frequency bands, the output histograms are:
- all equal across frequency bands
- equal to simply passing a float value instead of a (num_bands, D) or
(D,) tensor.
"""
room_dim = torch.tensor([20, 25, 5], dtype=self.dtype)
mic_array = torch.tensor([[2, 2, 0], [8, 8, 0]], dtype=self.dtype)
source = torch.tensor([7, 6, 0], dtype=self.dtype)
num_rays = 1_000
ABS, SCAT = 0.1, 0.2
hist_per_band_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=torch.full(fill_value=ABS, size=(7, 6), dtype=self.dtype),
scattering=torch.full(fill_value=SCAT, size=(7, 6), dtype=self.dtype),
)
hist_per_wall = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=torch.full(fill_value=ABS, size=(6,), dtype=self.dtype),
scattering=torch.full(fill_value=SCAT, size=(6,), dtype=self.dtype),
)
hist_single = F.ray_tracing(
room=room_dim,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=ABS,
scattering=SCAT,
)
self.assertEqual(hist_per_band_per_wall.shape, (2, 7, 2500))
self.assertEqual(hist_per_wall.shape, (2, 1, 2500))
self.assertEqual(hist_single.shape, (2, 1, 2500))
self.assertEqual(hist_single, hist_per_wall)
self.assertEqual(hist_single.expand(hist_per_band_per_wall.shape), hist_per_band_per_wall)
class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
......
import math
import numpy as np
import torch
import torchaudio.prototype.functional as F
......@@ -9,11 +12,62 @@ if _mod_utils.is_module_available("pyroomacoustics"):
import pyroomacoustics as pra
def _pra_ray_tracing(
room_dim,
absorption,
scattering,
num_bands,
mic_array,
source,
num_rays,
energy_thres,
time_thres,
hist_bin_size,
mic_radius,
sound_speed,
):
walls = ["west", "east", "south", "north", "floor", "ceiling"]
absorption = absorption.T.tolist()
scattering = scattering.T.tolist()
freqs = 125 * 2 ** np.arange(num_bands)
room = pra.ShoeBox(
room_dim.tolist(),
ray_tracing=True,
materials={
wall: pra.Material(
energy_absorption={"coeffs": absorp, "center_freqs": freqs},
scattering={"coeffs": scat, "center_freqs": freqs},
)
for wall, absorp, scat in zip(walls, absorption, scattering)
},
air_absorption=False,
max_order=0, # Make sure PRA doesn't use the hybrid method (we just want ray tracing)
)
room.add_microphone_array(mic_array.T.tolist())
room.add_source(source.tolist())
room.set_ray_tracing(
n_rays=num_rays,
energy_thres=energy_thres,
time_thres=time_thres,
hist_bin_size=hist_bin_size,
receiver_radius=mic_radius,
)
room.set_sound_speed(sound_speed)
room.compute_rir()
hist_pra = np.array(room.rt_histograms, dtype=np.float32)[:, 0, 0]
# PRA continues the simulation beyond time threshold, but torchaudio does not.
num_bins = math.ceil(time_thres / hist_bin_size)
return hist_pra[:, :, :num_bins]
@skipIfNoModule("pyroomacoustics")
@skipIfNoRIR
class CompatibilityTest(PytorchTestCase):
dtype = torch.float64
# pyroomacoustics uses float for internal implementations.
dtype = torch.float32
device = torch.device("cpu")
@parameterized.expand([(1,), (4,)])
......@@ -91,3 +145,53 @@ class CompatibilityTest(PytorchTestCase):
expected[i, 0 : room.rir[i][0].shape[0]] = torch.from_numpy(room.rir[i][0])
actual = F.simulate_rir_ism(room_dim, source, mic_array, max_order, absorption)
self.assertEqual(expected, actual, atol=1e-3, rtol=1e-3)
@parameterized.expand(
[
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 130),
]
)
def test_ray_tracing_same_results_as_pyroomacoustics(self, room, source, mic_array, num_rays):
num_bands = 6
energy_thres = 1e-7
time_thres = 10.0
hist_bin_size = 0.004
mic_radius = 0.5
sound_speed = 343.0
absorption = torch.full((num_bands, 6), 0.1, dtype=self.dtype)
scattering = torch.full((num_bands, 6), 0.4, dtype=self.dtype)
room = torch.tensor(room, dtype=self.dtype)
source = torch.tensor(source, dtype=self.dtype)
mic_array = torch.tensor(mic_array, dtype=self.dtype)
hist_pra = _pra_ray_tracing(
room,
absorption,
scattering,
num_bands,
mic_array,
source,
num_rays,
energy_thres,
time_thres,
hist_bin_size,
mic_radius,
sound_speed,
)
hist = F.ray_tracing(
room=room,
source=source,
mic_array=mic_array,
num_rays=num_rays,
absorption=absorption,
scattering=scattering,
sound_speed=sound_speed,
mic_radius=mic_radius,
energy_thres=energy_thres,
time_thres=time_thres,
hist_bin_size=hist_bin_size,
)
self.assertEqual(hist, hist_pra, atol=0.001, rtol=0.001)
......@@ -112,3 +112,42 @@ class TorchScriptConsistencyCPUOnlyTestImpl(TestBaseMixin):
F.simulate_rir_ism,
(room_dim, source, mic_array, max_order, absorption, None, 81, center_frequency, 343.0, 16000.0),
)
@parameterized.expand(
[
([20, 25, 30], [1, 10, 5], [[8, 8, 22]], 500), # 3D with 1 mic
]
)
def test_ray_tracing(self, room_dim, source, mic_array, num_rays):
num_walls = 4 if len(room_dim) == 2 else 6
num_bands = 3
absorption = torch.rand(num_bands, num_walls, dtype=torch.float32)
scattering = torch.rand(num_bands, num_walls, dtype=torch.float32)
energy_thres = 1e-7
time_thres = 10.0
hist_bin_size = 0.004
mic_radius = 0.5
sound_speed = 343.0
room_dim = torch.tensor(room_dim, dtype=self.dtype)
source = torch.tensor(source, dtype=self.dtype)
mic_array = torch.tensor(mic_array, dtype=self.dtype)
self._assert_consistency(
F.ray_tracing,
(
room_dim,
source,
mic_array,
num_rays,
absorption,
scattering,
mic_radius,
sound_speed,
energy_thres,
time_thres,
hist_bin_size,
),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment