Unverified Commit 2bce4898 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Inline typing transforms (#487)

* add inline typing

* update type hinting

* update typing

* sync the docstrings typing

* reorder imports, add typing to missing method

* add missing parenthesis
parent d069fb9f
# -*- coding: utf-8 -*-
from warnings import warn
import math
from typing import Callable, Optional
from warnings import warn
import torch
from typing import Optional
from torch import Tensor
from torchaudio import functional as F
from torchaudio.compliance import kaldi
......@@ -35,19 +37,25 @@ class Spectrogram(torch.nn.Module):
win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
pad (int, optional): Two sided padding of signal. (Default: ``0``)
window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
power (float or None, optional): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead. (Default: ``2``)
normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
"""
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
def __init__(self, n_fft=400, win_length=None, hop_length=None,
pad=0, window_fn=torch.hann_window,
power=2., normalized=False, wkwargs=None):
def __init__(self,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: Optional[float] = 2.,
normalized: bool = False,
wkwargs: Optional[dict] = None) -> None:
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
......@@ -60,13 +68,13 @@ class Spectrogram(torch.nn.Module):
self.power = power
self.normalized = normalized
def forward(self, waveform):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
waveform (Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: Dimension (..., freq, time), where freq is
Tensor: Dimension (..., freq, time), where freq is
``n_fft // 2 + 1`` where ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
......@@ -96,12 +104,12 @@ class GriffinLim(torch.nn.Module):
n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``)
win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
power (float, optional): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
momentum (float, optional): The momentum parameter for fast Griffin-Lim.
Setting this to 0 recovers the original Griffin-Lim method.
Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
......@@ -111,9 +119,18 @@ class GriffinLim(torch.nn.Module):
__constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
'length', 'momentum', 'rand_init']
def __init__(self, n_fft=400, n_iter=32, win_length=None, hop_length=None,
window_fn=torch.hann_window, power=2., normalized=False, wkwargs=None,
momentum=0.99, length=None, rand_init=True):
def __init__(self,
n_fft: int = 400,
n_iter: int = 32,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: float = 2.,
normalized: bool = False,
wkwargs: Optional[dict] = None,
momentum: float = 0.99,
length: Optional[int] = None,
rand_init: bool = True) -> None:
super(GriffinLim, self).__init__()
assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
......@@ -131,8 +148,16 @@ class GriffinLim(torch.nn.Module):
self.momentum = momentum / (1 + momentum)
self.rand_init = rand_init
def forward(self, S):
return F.griffinlim(S, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
def forward(self, specgram: Tensor) -> Tensor:
r"""
Args:
specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
where freq is ``n_fft // 2 + 1``.
Returns:
Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
"""
return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)
......@@ -151,7 +176,7 @@ class AmplitudeToDB(torch.nn.Module):
"""
__constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']
def __init__(self, stype='power', top_db=None):
def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
super(AmplitudeToDB, self).__init__()
self.stype = stype
if top_db is not None and top_db < 0:
......@@ -162,15 +187,15 @@ class AmplitudeToDB(torch.nn.Module):
self.ref_value = 1.0
self.db_multiplier = math.log10(max(self.amin, self.ref_value))
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
r"""Numerically stable implementation from Librosa.
https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
Args:
x (torch.Tensor): Input tensor before being converted to decibel scale.
x (Tensor): Input tensor before being converted to decibel scale.
Returns:
torch.Tensor: Output tensor in decibel scale.
Tensor: Output tensor in decibel scale.
"""
return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
......@@ -191,7 +216,12 @@ class MelScale(torch.nn.Module):
"""
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
def __init__(self,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.,
f_max: Optional[float] = None,
n_stft: Optional[int] = None) -> None:
super(MelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
......@@ -204,13 +234,13 @@ class MelScale(torch.nn.Module):
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.register_buffer('fb', fb)
def forward(self, specgram):
def forward(self, specgram: Tensor) -> Tensor:
r"""
Args:
specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time).
specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).
Returns:
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
"""
# pack batch
......@@ -242,20 +272,28 @@ class InverseMelScale(torch.nn.Module):
Args:
n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
n_mels (int): Number of mel filterbanks. (Default: ``128``)
sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
f_min (float): Minimum frequency. (Default: ``0.``)
f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
max_iter (int): Maximum number of optimization iterations.
tolerance_loss (float): Value of loss to stop optimization at.
tolerance_change (float): Difference in losses to stop optimization at.
sgdargs (dict): Arguments for the SGD optimizer.
n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
f_min (float, optional): Minimum frequency. (Default: ``0.``)
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
"""
__constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
'tolerance_change', 'sgdargs']
def __init__(self, n_stft, n_mels=128, sample_rate=16000, f_min=0., f_max=None, max_iter=100000,
tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None):
def __init__(self,
n_stft: int,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.,
f_max: Optional[float] = None,
max_iter: int = 100000,
tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None) -> None:
super(InverseMelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
......@@ -271,13 +309,13 @@ class InverseMelScale(torch.nn.Module):
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.register_buffer('fb', fb)
def forward(self, melspec):
def forward(self, melspec: Tensor) -> Tensor:
r"""
Args:
melspec (torch.Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)
melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)
Returns:
torch.Tensor: Linear scale spectrogram of size (..., freq, time)
Tensor: Linear scale spectrogram of size (..., freq, time)
"""
# pack batch
shape = melspec.size()
......@@ -335,7 +373,7 @@ class MelSpectrogram(torch.nn.Module):
f_max (float or None, optional): Maximum frequency. (Default: ``None``)
pad (int, optional): Two sided padding of signal. (Default: ``0``)
n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
......@@ -345,8 +383,17 @@ class MelSpectrogram(torch.nn.Module):
"""
__constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']
def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
def __init__(self,
sample_rate: int = 16000,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
f_min: float = 0.,
f_max: Optional[float] = None,
pad: int = 0,
n_mels: int = 128,
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None) -> None:
super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
......@@ -362,13 +409,13 @@ class MelSpectrogram(torch.nn.Module):
normalized=False, wkwargs=wkwargs)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)
def forward(self, waveform):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
waveform (Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
"""
specgram = self.spectrogram(waveform)
mel_specgram = self.mel_scale(specgram)
......@@ -392,12 +439,17 @@ class MFCC(torch.nn.Module):
dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
norm (str, optional): norm to use. (Default: ``'ortho'``)
log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)
melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
"""
__constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
melkwargs=None):
def __init__(self,
sample_rate: int = 16000,
n_mfcc: int = 40,
dct_type: int = 2,
norm: str = 'ortho',
log_mels: bool = False,
melkwargs: Optional[dict] = None) -> None:
super(MFCC, self).__init__()
supported_dct_types = [2]
if dct_type not in supported_dct_types:
......@@ -420,13 +472,13 @@ class MFCC(torch.nn.Module):
self.register_buffer('dct_mat', dct_mat)
self.log_mels = log_mels
def forward(self, waveform):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
waveform (Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
"""
# pack batch
......@@ -461,17 +513,17 @@ class MuLawEncoding(torch.nn.Module):
"""
__constants__ = ['quantization_channels']
def __init__(self, quantization_channels=256):
def __init__(self, quantization_channels: int = 256) -> None:
super(MuLawEncoding, self).__init__()
self.quantization_channels = quantization_channels
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
r"""
Args:
x (torch.Tensor): A signal to be encoded.
x (Tensor): A signal to be encoded.
Returns:
x_mu (torch.Tensor): An encoded signal.
x_mu (Tensor): An encoded signal.
"""
return F.mu_law_encoding(x, self.quantization_channels)
......@@ -488,17 +540,17 @@ class MuLawDecoding(torch.nn.Module):
"""
__constants__ = ['quantization_channels']
def __init__(self, quantization_channels=256):
def __init__(self, quantization_channels: int = 256) -> None:
super(MuLawDecoding, self).__init__()
self.quantization_channels = quantization_channels
def forward(self, x_mu):
def forward(self, x_mu: Tensor) -> Tensor:
r"""
Args:
x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded.
x_mu (Tensor): A mu-law encoded signal which needs to be decoded.
Returns:
torch.Tensor: The signal decoded.
Tensor: The signal decoded.
"""
return F.mu_law_decoding(x_mu, self.quantization_channels)
......@@ -512,19 +564,22 @@ class Resample(torch.nn.Module):
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
"""
def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
def __init__(self,
orig_freq: int = 16000,
new_freq: int = 16000,
resampling_method: str = 'sinc_interpolation') -> None:
super(Resample, self).__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
self.resampling_method = resampling_method
def forward(self, waveform):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
waveform (Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: Output signal of dimension (..., time).
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
......@@ -550,11 +605,11 @@ class ComplexNorm(torch.nn.Module):
"""
__constants__ = ['power']
def __init__(self, power=1.0):
def __init__(self, power: float = 1.0) -> None:
super(ComplexNorm, self).__init__()
self.power = power
def forward(self, complex_tensor):
def forward(self, complex_tensor: Tensor) -> Tensor:
r"""
Args:
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.
......@@ -576,18 +631,18 @@ class ComputeDeltas(torch.nn.Module):
"""
__constants__ = ['win_length']
def __init__(self, win_length=5, mode="replicate"):
def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
super(ComputeDeltas, self).__init__()
self.win_length = win_length
self.mode = mode
def forward(self, specgram):
def forward(self, specgram: Tensor) -> Tensor:
r"""
Args:
specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time).
specgram (Tensor): Tensor of audio of dimension (..., freq, time).
Returns:
deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time).
Tensor: Tensor of deltas of dimension (..., freq, time).
"""
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
......@@ -603,7 +658,10 @@ class TimeStretch(torch.nn.Module):
"""
__constants__ = ['fixed_rate']
def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
def __init__(self,
hop_length: Optional[int] = None,
n_freq: int = 201,
fixed_rate: Optional[float] = None) -> None:
super(TimeStretch, self).__init__()
self.fixed_rate = fixed_rate
......@@ -612,8 +670,7 @@ class TimeStretch(torch.nn.Module):
hop_length = hop_length if hop_length is not None else n_fft // 2
self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])
def forward(self, complex_specgrams, overriding_rate=None):
# type: (Tensor, Optional[float]) -> Tensor
def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
r"""
Args:
complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
......@@ -621,7 +678,7 @@ class TimeStretch(torch.nn.Module):
If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
Returns:
(Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
"""
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
......@@ -648,27 +705,28 @@ class Fade(torch.nn.Module):
fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
"half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
"""
def __init__(self, fade_in_len=0, fade_out_len=0, fade_shape="linear"):
def __init__(self,
fade_in_len: int = 0,
fade_out_len: int = 0,
fade_shape: str = "linear") -> None:
super(Fade, self).__init__()
self.fade_in_len = fade_in_len
self.fade_out_len = fade_out_len
self.fade_shape = fade_shape
def forward(self, waveform):
# type: (Tensor) -> Tensor
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
waveform (Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: Tensor of audio of dimension (..., time).
Tensor: Tensor of audio of dimension (..., time).
"""
waveform_length = waveform.size()[-1]
return self._fade_in(waveform_length) * self._fade_out(waveform_length) * waveform
def _fade_in(self, waveform_length):
# type: (int) -> Tensor
def _fade_in(self, waveform_length: int) -> Tensor:
fade = torch.linspace(0, 1, self.fade_in_len)
ones = torch.ones(waveform_length - self.fade_in_len)
......@@ -689,8 +747,7 @@ class Fade(torch.nn.Module):
return torch.cat((fade, ones)).clamp_(0, 1)
def _fade_out(self, waveform_length):
# type: (int) -> Tensor
def _fade_out(self, waveform_length: int) -> Tensor:
fade = torch.linspace(0, 1, self.fade_out_len)
ones = torch.ones(waveform_length - self.fade_out_len)
......@@ -722,22 +779,21 @@ class _AxisMasking(torch.nn.Module):
"""
__constants__ = ['mask_param', 'axis', 'iid_masks']
def __init__(self, mask_param, axis, iid_masks):
def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
super(_AxisMasking, self).__init__()
self.mask_param = mask_param
self.axis = axis
self.iid_masks = iid_masks
def forward(self, specgram, mask_value=0.):
# type: (Tensor, float) -> Tensor
def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
r"""
Args:
specgram (torch.Tensor): Tensor of dimension (..., freq, time).
specgram (Tensor): Tensor of dimension (..., freq, time).
mask_value (float): Value to assign to the masked columns.
Returns:
torch.Tensor: Masked spectrogram of dimensions (..., freq, time).
Tensor: Masked spectrogram of dimensions (..., freq, time).
"""
# if iid_masks flag marked and specgram has a batch dimension
......@@ -757,7 +813,7 @@ class FrequencyMasking(_AxisMasking):
the examples/channels in the batch. (Default: ``False``)
"""
def __init__(self, freq_mask_param, iid_masks=False):
def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)
......@@ -771,7 +827,7 @@ class TimeMasking(_AxisMasking):
the examples/channels in the batch. (Default: ``False``)
"""
def __init__(self, time_mask_param, iid_masks=False):
def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
......@@ -786,7 +842,7 @@ class Vol(torch.nn.Module):
gain_type (str, optional): Type of gain. One of: ‘amplitude’, ‘power’, ‘db’ (Default: ``"amplitude"``)
"""
def __init__(self, gain, gain_type='amplitude'):
def __init__(self, gain: float, gain_type: str = 'amplitude'):
super(Vol, self).__init__()
self.gain = gain
self.gain_type = gain_type
......@@ -794,14 +850,13 @@ class Vol(torch.nn.Module):
if gain_type in ['amplitude', 'power'] and gain < 0:
raise ValueError("If gain_type = amplitude or power, gain must be positive.")
def forward(self, waveform):
# type: (Tensor) -> Tensor
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
waveform (Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: Tensor of audio of dimension (..., time).
Tensor: Tensor of audio of dimension (..., time).
"""
if self.gain_type == "amplitude":
waveform = waveform * self.gain
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment