"torchvision/vscode:/vscode.git/clone" did not exist on "4eb9f6601b62ff95e81fbdb5d574ebcf82b530ef"
Unverified Commit 92ded610 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor argument validation (#3632)

parent ff266b15
......@@ -110,58 +110,69 @@ def _frac_delay(delay: torch.Tensor, delay_i: torch.Tensor, delay_filter_length:
return torch.special.sinc(n - delay) * _hann(n - delay, 2 * pad)
def _validate_inputs(
room: torch.Tensor, source: torch.Tensor, mic_array: torch.Tensor, absorption: Union[float, torch.Tensor]
) -> torch.Tensor:
"""Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension.
def _adjust_coeff(dim: int, coeffs: Union[float, torch.Tensor], name: str) -> torch.Tensor:
"""Validates and converts absorption or scattering parameters to a tensor with appropriate shape
Args:
room (torch.Tensor): Room coordinates. The shape of `room` must be `(3,)` which represents
three dimensions of the room.
source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(3,)`.
mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, 3)`.
absorption (float or torch.Tensor): The absorption coefficients of wall materials.
dim (int): The dimension of the simulation. 2 or 3.
coeff (float or torch.Tensor): The absorption coefficients of wall materials.
If the dtype is ``float``, the absorption coefficient is identical for all walls and
all frequencies.
If ``absorption`` is a 1D Tensor, the shape must be `(6,)`, where the values represent
absorption coefficients of ``"west"``, ``"east"``, ``"south"``, ``"north"``, ``"floor"``,
and ``"ceiling"``, respectively.
If ``absorption`` is a 2D Tensor, the shape must be `(7, 6)`, where 7 represents the number of octave bands.
If ``absorption`` is a 1D Tensor, the shape must be `(2*dim,)`,
where the values represent absorption coefficients of ``"west"``, ``"east"``,
``"south"``, ``"north"``, ``"floor"``, and ``"ceiling"``, respectively.
If ``absorption`` is a 2D Tensor, the shape must be `(7, 2*dim)`,
where 7 represents the number of octave bands.
Returns:
(torch.Tensor): The absorption Tensor. The shape is `(1, 6)` for single octave band case,
or `(7, 6)` for multi octave band case.
(torch.Tensor): The expanded coefficient.
The shape is `(1, 2*dim)` for single octave band case, and
`(7, 2*dim)` for multi octave band case.
"""
if room.ndim != 1:
raise ValueError(f"room must be a 1D Tensor. Found {room.shape}.")
D = room.shape[0]
if D != 3:
raise ValueError(f"room must be a 3D room. Found {room.shape}.")
num_wall = 6
if source.shape[0] != D:
raise ValueError(f"The shape of source must be `(3,)`. Found {source.shape}")
if mic_array.ndim != 2:
raise ValueError(f"mic_array must be a 2D Tensor. Found {mic_array.shape}.")
if mic_array.shape[1] != D:
raise ValueError(f"The second dimension of mic_array must be 3. Found {mic_array.shape}.")
if isinstance(absorption, float):
absorption = torch.ones(1, num_wall) * absorption
elif isinstance(absorption, Tensor) and absorption.ndim == 1:
if absorption.shape[0] != num_wall:
num_walls = 2 * dim
if isinstance(coeffs, float):
return torch.full((1, num_walls), coeffs)
if isinstance(coeffs, Tensor):
if coeffs.ndim == 1:
if coeffs.numel() != num_walls:
raise ValueError(
"The shape of absorption must be `(6,)` if it is a 1D Tensor." f"Found the shape {absorption.shape}."
f"The shape of `{name}` must be ({num_walls},) when it is a 1D Tensor."
f"Found the shape {coeffs.shape}."
)
absorption = absorption.unsqueeze(0)
elif isinstance(absorption, Tensor) and absorption.ndim == 2:
if absorption.shape != (7, num_wall):
return coeffs.unsqueeze(0)
if coeffs.ndim == 2:
if coeffs.shape != (7, num_walls):
raise ValueError(
"The shape of absorption must be `(7, 6)` if it is a 2D Tensor."
f"Found the shape of room is {D} and shape of absorption is {absorption.shape}."
f"The shape of `{name}` must be (7, {num_walls}) when it is a 2D Tensor."
f"Found the shape {coeffs.shape}."
)
absorption = absorption
else:
absorption = absorption
return absorption
return coeffs
raise TypeError(f"`{name}` must be float or Tensor.")
def _validate_inputs(
dim: int,
room: torch.Tensor,
source: torch.Tensor,
mic_array: torch.Tensor,
):
"""Validate dimensions of input arguments, and normalize different kinds of absorption into the same dimension.
Args:
dim (int): The dimension of the simulation. 2 or 3.
room (torch.Tensor): The size of the room. width, length (and height)
source (torch.Tensor): Sound source coordinates. Tensor with dimensions `(dim,)`.
mic_array (torch.Tensor): Microphone coordinates. Tensor with dimensions `(channel, dim)`.
"""
if not (room.ndim == 1 and room.numel() == dim):
raise ValueError(f"`room` must be a 1D Tensor with {dim} elements. Found {room.shape}.")
if not (source.ndim == 1 and source.numel() == dim):
raise ValueError(f"`source` must be 1D Tensor with {dim} elements. Found {source.shape}.")
if not (mic_array.ndim == 2 and mic_array.shape[1] == dim):
raise ValueError(f"mic_array must be a 2D Tensor with shape (num_channels, {dim}). Found {mic_array.shape}.")
def simulate_rir_ism(
......@@ -220,7 +231,8 @@ def simulate_rir_ism(
of octave bands are fixed to ``[125.0, 250.0, 500.0, 1000.0, 2000.0, 4000.0, 8000.0]``.
Users need to tune the values of ``absorption`` to the corresponding frequencies.
"""
absorption = _validate_inputs(room, source, mic_array, absorption)
_validate_inputs(3, room, source, mic_array)
absorption = _adjust_coeff(3, absorption, "absorption")
img_location, att = _compute_image_sources(room, source, max_order, absorption)
# compute distances between image sources and microphones
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment