"tests/python/common/test_subgraph.py" did not exist on "f46080a4d1ebffb2aede1df0bf03a24276036b71"
Unverified Commit 92ded610 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor argument validation (#3632)

parent ff266b15
...@@ -110,58 +110,69 @@ def _frac_delay(delay: torch.Tensor, delay_i: torch.Tensor, delay_filter_length: ...@@ -110,58 +110,69 @@ def _frac_delay(delay: torch.Tensor, delay_i: torch.Tensor, delay_filter_length:
return torch.special.sinc(n - delay) * _hann(n - delay, 2 * pad) return torch.special.sinc(n - delay) * _hann(n - delay, 2 * pad)
def _validate_inputs( def _adjust_coeff(dim: int, coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor:
room: torch.Tensor, source: torch.Tensor, mic_array: torch.Tensor, absorption: Union[float, torch.Tensor] """Validates and converts absorption or scattering parameters to a tensor with appropriate shape
) -> torch.Tensor:
"""Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension.
Args: Args:
room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents dim (int): The dimension of the simulation. 2 or 3.
three dimensions of the room. coeff (float or torch.Tensor): The absorption coefficients of wall materials.
source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
absorption (float or torch.Tensor): The absorption coefficients of wall materials.
If the dtype is ``float``, the absorption coefficient is identical for all walls and If the dtype is ``float``, the absorption coefficient is identical for all walls and
all frequencies. all frequencies.
If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, where the values represent
absorption coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``, If ``absorption`` is a 1D Tensor, the shape must be `(2*dim,)`,
and ``"ceiling"``, respectively. where the values represent absorption coefficients of ``"west"``, ``"east"``,
If ``absorption`` is a 2D Tensor, the shape must be `(7, 6)`, where 7 represents the number of octave bands. ``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
If ``absorption`` is a 2D Tensor, the shape must be `(7, 2*dim)`,
where 7 represents the number of octave bands.
Returns: Returns:
(torch.Tensor): The absorption Tensor. The shape is `(1, 6)` for single octave band case, (torch.Tensor): The expanded coefficient.
or `(7, 6)` for multi octave band case. The shape is `(1, 2*dim)` for single octave band case, and
`(7, 2*dim)` for multi octave band case.
""" """
if room.ndim != 1: num_walls = 2 * dim
raise ValueError(f"room must be a 1D Tensor. Found {room.shape}.") if isinstance(coeffs, float):
D = room.shape[0] return torch.full((1, num_walls), coeffs)
if D != 3: if isinstance(coeffs, Tensor):
raise ValueError(f"room must be a 3D room. Found {room.shape}.") if coeffs.ndim == 1:
num_wall = 6 if coeffs.numel() != num_walls:
if source.shape[0] != D: raise ValueError(
raise ValueError(f"The shape of source must be `(3,)`. Found {source.shape}") f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor."
if mic_array.ndim != 2: f"Found the shape {coeffs.shape}."
raise ValueError(f"mic_array must be a 2D Tensor. Found {mic_array.shape}.") )
if mic_array.shape[1] != D: return coeffs.unsqueeze(0)
raise ValueError(f"The second dimension of mic_array must be 3. Found {mic_array.shape}.") if coeffs.ndim == 2:
if isinstance(absorption, float): if coeffs.shape != (7, num_walls):
absorption = torch.ones(1, num_wall) * absorption raise ValueError(
elif isinstance(absorption, Tensor) and absorption.ndim == 1: f"The shape of `{name}` must be (7, {num_walls}) when it is a 2D Tensor."
if absorption.shape[0] != num_wall: f"Found the shape {coeffs.shape}."
raise ValueError( )
"The shape of absorption must be `(6,)` if it is a 1D Tensor." f"Found the shape {absorption.shape}." return coeffs
) raise TypeError(f"`{name}` must be float or Tensor.")
absorption = absorption.unsqueeze(0)
elif isinstance(absorption, Tensor) and absorption.ndim == 2:
if absorption.shape != (7, num_wall): def _validate_inputs(
raise ValueError( dim: int,
"The shape of absorption must be `(7, 6)` if it is a 2D Tensor." room: torch.Tensor,
f"Found the shape of room is {D} and shape of absorption is {absorption.shape}." source: torch.Tensor,
) mic_array: torch.Tensor,
absorption = absorption ):
else: """Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension.
absorption = absorption
return absorption Args:
dim (int): The dimension of the simulation. 2 or 3.
room (torch.Tensor): The size of the room. width, length (and height)
source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(dim,)`.
mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, dim)`.
"""
if not (room.ndim == 1 and room.numel() == dim):
raise ValueError(f"`room` must be a 1D Tensor with {dim} elements. Found {room.shape}.")
if not (source.ndim == 1 and source.numel() == dim):
raise ValueError(f"`source` must be 1D Tensor with {dim} elements. Found {source.shape}.")
if not (mic_array.ndim == 2 and mic_array.shape[1] == dim):
raise ValueError(f"mic_array must be a 2D Tensor with shape (num_channels, {dim}). Found {mic_array.shape}.")
def simulate_rir_ism( def simulate_rir_ism(
...@@ -220,7 +231,8 @@ def simulate_rir_ism( ...@@ -220,7 +231,8 @@ def simulate_rir_ism(
of octave bands are fixed to ``[125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0]``. of octave bands are fixed to ``[125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0]``.
Users need to tune the values of ``absorption`` to the corresponding frequencies. Users need to tune the values of ``absorption`` to the corresponding frequencies.
""" """
absorption = _validate_inputs(room, source, mic_array, absorption) _validate_inputs(3, room, source, mic_array)
absorption = _adjust_coeff(3, absorption, "absorption")
img_location, att = _compute_image_sources(room, source, max_order, absorption) img_location, att = _compute_image_sources(room, source, max_order, absorption)
# compute distances between image sources and microphones # compute distances between image sources and microphones
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment