Unverified commit 00efbe61, authored by cpuhrsch, committed by GitHub

Merge pull request #105 from jamarshon/T44497670

Migrate audio transform computations into functional.py
parents acdedc4a ec0b29f5

functional.py (new file)
import numpy as np
import torch
__all__ = [
'scale',
'pad_trim',
'downmix_mono',
'LC2CL',
'spectrogram',
'create_fb_matrix',
'mel_scale',
'spectrogram_to_DB',
'create_dct',
'MFCC',
'BLC2CBL',
'mu_law_encoding',
'mu_law_expanding'
]
def scale(tensor, factor):
# type: (Tensor, int) -> Tensor
"""Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
to a floating point number between -1.0 and 1.0. Note the 16-bit number is
called the "bit depth" or "precision", not to be confused with "bit rate".
Inputs:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
factor (int): Maximum value of input tensor
Outputs:
Tensor: Scaled by the scale factor
"""
if not tensor.dtype.is_floating_point:
tensor = tensor.to(torch.float32)
return tensor / factor
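# Example (illustrative sketch): scaling a 16-bit signal by its full-scale
# factor 2 ** 15 maps it into [-1., 1.].
# >>> sig = torch.randint(-2 ** 15, 2 ** 15, (1000, 2), dtype=torch.int16)
# >>> bool(scale(sig, 2 ** 15).abs().max() <= 1.)
# True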
def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
# type: (Tensor, int, int, int, float) -> Tensor
"""Pad/Trim a 2d-Tensor (Signal or Labels)
Inputs:
tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
ch_dim (int): Dimension of channel (not size)
max_len (int): Length to which the tensor will be padded
len_dim (int): Dimension of length (not size)
fill_value (float): Value to fill in
Outputs:
Tensor: Padded/trimmed tensor
"""
if max_len > tensor.size(len_dim):
# tuple of (padding_left, padding_right, padding_top, padding_bottom)
# so pad similar to append (aka only right/bottom) and do not pad
# the length dimension. assumes equal sizes of padding.
padding = [max_len - tensor.size(len_dim)
if (i % 2 == 1) and (i // 2 != len_dim)
else 0
for i in range(4)]
with torch.no_grad():
tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value)
elif max_len < tensor.size(len_dim):
tensor = tensor.narrow(len_dim, 0, max_len)
return tensor
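# Example (illustrative sketch, channels-first layout with ch_dim=0 and
# len_dim=1): padding appends on the right; trimming keeps the head.
# >>> x = torch.zeros(2, 1000)
# >>> pad_trim(x, 0, 1600, 1, 0.).size()
# torch.Size([2, 1600])
# >>> pad_trim(x, 0, 500, 1, 0.).size()
# torch.Size([2, 500])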
def downmix_mono(tensor, ch_dim):
# type: (Tensor, int) -> Tensor
"""Downmix any stereo signals to mono.
Inputs:
tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
ch_dim (int): Dimension of channel (not size)
Outputs:
Tensor: Mono signal
"""
if not tensor.dtype.is_floating_point:
tensor = tensor.to(torch.float32)
tensor = torch.mean(tensor, ch_dim, True)
return tensor
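# Example (illustrative sketch): averaging a stereo signal over the channel
# dimension keeps that dimension, with size 1.
# >>> downmix_mono(torch.randn(2, 1000), 0).size()
# torch.Size([1, 1000])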
def LC2CL(tensor):
# type: (Tensor) -> Tensor
"""Permute a 2d tensor from samples (n x c) to (c x n)
Inputs:
tensor (Tensor): Tensor of audio signal with shape (LxC)
Outputs:
Tensor: Tensor of audio signal with shape (CxL)
"""
return tensor.transpose(0, 1).contiguous()
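# Example (illustrative sketch):
# >>> LC2CL(torch.randn(1000, 2)).size()
# torch.Size([2, 1000])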
def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
"""Create a spectrogram from a raw audio signal
Inputs:
sig (Tensor): Tensor of audio of size (c, n)
pad (int): two sided padding of signal
window (Tensor): window_tensor
n_fft (int): size of fft
hop (int): length of hop between STFT windows
ws (int): window size
        power (int > 0): Exponent for the magnitude spectrogram,
e.g., 1 for energy, 2 for power, etc.
normalize (bool) : whether to normalize by magnitude after stft
Outputs:
Tensor: channels x hops x n_fft (c, l, f), where channels
is unchanged, hops is the number of hops, and n_fft is the
number of fourier bins, which should be the window size divided
by 2 plus 1.
"""
assert sig.dim() == 2
if pad > 0:
with torch.no_grad():
sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
window = window.to(sig.device)
# default values are consistent with librosa.core.spectrum._spectrogram
spec_f = torch.stft(sig, n_fft, hop, ws,
window, center=True,
normalized=False, onesided=True,
pad_mode='reflect').transpose(1, 2)
if normalize:
spec_f /= window.pow(2).sum().sqrt()
    # magnitude of the "complex" tensor, raised to `power` (c, l, n_fft)
    spec_f = spec_f.pow(2).sum(-1).pow(power / 2.)
return spec_f
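# Example (illustrative sketch; relies on the real-valued torch.stft of this
# era, which returns a (..., 2) tensor rather than a complex one). For a
# 16000-sample mono clip with hop=200 and center=True there are
# 16000 // 200 + 1 = 81 hops, and a 400-point FFT gives 400 // 2 + 1 = 201 bins.
# >>> sig = torch.randn(1, 16000)
# >>> spectrogram(sig, 0, torch.hann_window(400), 400, 200, 400, 2, False).size()
# torch.Size([1, 81, 201])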
def create_fb_matrix(n_stft, f_min, f_max, n_mels):
# type: (int, float, float, int) -> Tensor
""" Create a frequency bin conversion matrix.
Inputs:
n_stft (int): number of filter banks from spectrogram
f_min (float): minimum frequency
f_max (float): maximum frequency
n_mels (int): number of mel bins
Outputs:
Tensor: triangular filter banks (fb matrix)
"""
def _hertz_to_mel(f):
# type: (float) -> Tensor
return 2595. * torch.log10(torch.tensor(1.) + (f / 700.))
def _mel_to_hertz(mel):
# type: (Tensor) -> Tensor
return 700. * (10**(mel / 2595.) - 1.)
# get stft freq bins
stft_freqs = torch.linspace(f_min, f_max, n_stft)
# calculate mel freq bins
m_min = 0. if f_min == 0 else _hertz_to_mel(f_min)
m_max = _hertz_to_mel(f_max)
m_pts = torch.linspace(m_min, m_max, n_mels + 2)
f_pts = _mel_to_hertz(m_pts)
# calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) # (n_stft, n_mels + 2)
# create overlapping triangles
z = torch.tensor(0.)
down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_stft, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_stft, n_mels)
fb = torch.max(z, torch.min(down_slopes, up_slopes))
return fb
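# Example (illustrative sketch): one triangular filter per mel bin, defined
# over the 201 frequency bins of a 400-point FFT.
# >>> create_fb_matrix(201, 0., 8000., 128).size()
# torch.Size([201, 128])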
def mel_scale(spec_f, f_min, f_max, n_mels, fb=None):
# type: (Tensor, float, float, int, Optional[Tensor]) -> Tuple[Tensor, Tensor]
""" This turns a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks.
Inputs:
spec_f (Tensor): normal STFT
f_min (float): minimum frequency
f_max (float): maximum frequency
n_mels (int): number of mel bins
fb (Optional[Tensor]): triangular filter banks (fb matrix)
Outputs:
Tuple[Tensor, Tensor]: triangular filter banks (fb matrix) and mel frequency STFT
"""
if fb is None:
fb = create_fb_matrix(spec_f.size(2), f_min, f_max, n_mels).to(spec_f.device)
else:
# need to ensure same device for dot product
fb = fb.to(spec_f.device)
spec_m = torch.matmul(spec_f, fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return fb, spec_m
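# Example (illustrative sketch): applying the filter bank collapses the
# frequency axis from n_fft bins to n_mels bins.
# >>> fb, spec_m = mel_scale(torch.rand(1, 81, 201), 0., 8000., 128)
# >>> spec_m.size()
# torch.Size([1, 81, 128])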
def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.
Inputs:
spec (Tensor): normal STFT
multiplier (float): use 10. for power and 20. for amplitude
amin (float): number to clamp spec
db_multiplier (float): log10(max(reference value and amin))
top_db (Optional[float]): minimum negative cut-off in decibels. A reasonable number
is 80.
Outputs:
Tensor: spectrogram in DB
"""
spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin))
spec_db -= multiplier * db_multiplier
if top_db is not None:
spec_db = torch.max(spec_db, spec_db.new_full((1,), spec_db.max() - top_db))
return spec_db
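# Example (illustrative sketch): a power spectrogram with reference value 1.,
# so db_multiplier = log10(max(amin, 1.)) = 0., and an 80 dB floor below the peak.
# >>> spec_db = spectrogram_to_DB(torch.rand(1, 81, 201), 10., 1e-10, 0., top_db=80.)
# >>> bool(spec_db.max() - spec_db.min() <= 80.)
# True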
def create_dct(n_mfcc, n_mels, norm):
    # type: (int, int, str) -> Tensor
"""
Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
normalized depending on norm
Inputs:
n_mfcc (int) : number of mfc coefficients to retain
n_mels (int): number of MEL bins
norm (string) : norm to use
Outputs:
Tensor: The transformation matrix, to be right-multiplied to row-wise data.
"""
outdim = n_mfcc
dim = n_mels
# http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
n = np.arange(dim)
k = np.arange(outdim)[:, np.newaxis]
dct = np.cos(np.pi / dim * (n + 0.5) * k)
if norm == 'ortho':
dct[0] *= 1.0 / np.sqrt(2)
dct *= np.sqrt(2.0 / dim)
else:
dct *= 2
return torch.Tensor(dct.T)
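# Example (illustrative sketch): the matrix comes out shaped (n_mels, n_mfcc)
# for right-multiplication against (..., n_mels) data.
# >>> create_dct(40, 128, 'ortho').size()
# torch.Size([128, 40])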
def MFCC(sig, mel_spect, log_mels, s2db, dct_mat):
# type: (Tensor, MelSpectrogram, bool, SpectrogramToDB, Tensor) -> Tensor
"""Create the Mel-frequency cepstrum coefficients from an audio signal
By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
This is not the textbook implementation, but is implemented here to
give consistency with librosa.
This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.
Inputs:
sig (Tensor): Tensor of audio of size (channels [c], samples [n])
mel_spect (MelSpectrogram): melspectrogram of sig
log_mels (bool): whether to use log-mel spectrograms instead of db-scaled
s2db (SpectrogramToDB): a SpectrogramToDB instance
dct_mat (Tensor): The transformation matrix (dct matrix), to be
right-multiplied to row-wise data
Outputs:
Tensor: Mel-frequency cepstrum coefficients
"""
if log_mels:
log_offset = 1e-6
mel_spect = torch.log(mel_spect + log_offset)
else:
mel_spect = s2db(mel_spect)
mfcc = torch.matmul(mel_spect, dct_mat.to(mel_spect.device))
return mfcc
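# Example (illustrative sketch): `sig` is not used in the computation itself,
# and `s2db` only needs to be a callable, so a lambda over spectrogram_to_DB
# stands in for a SpectrogramToDB instance here.
# >>> dct_mat = create_dct(40, 128, 'ortho')
# >>> s2db = lambda m: spectrogram_to_DB(m, 10., 1e-10, 0., 80.)
# >>> MFCC(None, torch.rand(1, 81, 128), False, s2db, dct_mat).size()
# torch.Size([1, 81, 40])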
def BLC2CBL(tensor):
# type: (Tensor) -> Tensor
"""Permute a 3d tensor from Bands x Sample length x Channels to Channels x
    Bands x Sample length
Inputs:
tensor (Tensor): Tensor of spectrogram with shape (BxLxC)
Outputs:
Tensor: Tensor of spectrogram with shape (CxBxL)
"""
return tensor.permute(2, 0, 1).contiguous()
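# Example (illustrative sketch):
# >>> BLC2CBL(torch.randn(81, 1000, 2)).size()
# torch.Size([2, 81, 1000])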
def mu_law_encoding(x, qc):
    # type: (Union[Tensor, ndarray], int) -> Union[Tensor, ndarray]
"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This algorithm assumes the signal has been scaled to between -1 and 1 and
returns a signal encoded with values from 0 to quantization_channels - 1
Inputs:
x (Tensor): Input tensor
qc (int): Number of channels (i.e. quantization channels)
Outputs:
Tensor: Input after mu-law companding
"""
mu = qc - 1.
if isinstance(x, np.ndarray):
x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
elif isinstance(x, torch.Tensor):
if not x.dtype.is_floating_point:
x = x.to(torch.float)
mu = torch.tensor(mu, dtype=x.dtype)
x_mu = torch.sign(x) * torch.log1p(mu *
torch.abs(x)) / torch.log1p(mu)
x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
return x_mu
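# Example (illustrative sketch): a signal in [-1., 1.] is quantized into
# integer values in [0, qc - 1].
# >>> x_mu = mu_law_encoding(torch.linspace(-1., 1., steps=100), 256)
# >>> int(x_mu.min()), int(x_mu.max())
# (0, 255)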
def mu_law_expanding(x_mu, qc):
    # type: (Union[Tensor, ndarray], int) -> Union[Tensor, ndarray]
"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This expects an input with values between 0 and quantization_channels - 1
and returns a signal scaled between -1 and 1.
Inputs:
x_mu (Tensor): Input tensor
qc (int): Number of channels (i.e. quantization channels)
Outputs:
Tensor: Input after decoding
"""
mu = qc - 1.
if isinstance(x_mu, np.ndarray):
x = ((x_mu) / mu) * 2 - 1.
x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
elif isinstance(x_mu, torch.Tensor):
if not x_mu.dtype.is_floating_point:
x_mu = x_mu.to(torch.float)
mu = torch.tensor(mu, dtype=x_mu.dtype)
x = ((x_mu) / mu) * 2 - 1.
x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
return x
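# Example (illustrative round trip): expanding an encoded signal recovers it
# up to quantization error, which is largest near full scale (a bit over 0.02
# for qc=256).
# >>> x = torch.linspace(-1., 1., steps=100)
# >>> x_hat = mu_law_expanding(mu_law_encoding(x, 256), 256)
# >>> bool((x_hat - x).abs().max() < 0.025)
# True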
transforms.py

@@ -2,6 +2,7 @@ from __future__ import division, print_function
 from warnings import warn
 import torch
 import numpy as np
+from . import functional as F


 class Compose(object):

@@ -57,17 +58,14 @@ class Scale(object):
             Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
         """
-        if not tensor.dtype.is_floating_point:
-            tensor = tensor.to(torch.float32)
-
-        return tensor / self.factor
+        return F.scale(tensor, self.factor)

     def __repr__(self):
         return self.__class__.__name__ + '()'


 class PadTrim(object):
-    """Pad/Trim a 1d-Tensor (Signal or Labels)
+    """Pad/Trim a 2d-Tensor (Signal or Labels)

     Args:
         tensor (Tensor): Tensor of audio of size (n x c) or (c x n)

@@ -88,18 +86,7 @@ class PadTrim(object):
             Tensor: (c x n) or (n x c)
         """
-        assert tensor.size(self.ch_dim) < 128, \
-            "Too many channels ({}) detected, see channels_first param.".format(tensor.size(self.ch_dim))
-        if self.max_len > tensor.size(self.len_dim):
-            padding = [self.max_len - tensor.size(self.len_dim)
-                       if (i % 2 == 1) and (i // 2 != self.len_dim)
-                       else 0
-                       for i in range(4)]
-            with torch.no_grad():
-                tensor = torch.nn.functional.pad(tensor, padding, "constant", self.fill_value)
-        elif self.max_len < tensor.size(self.len_dim):
-            tensor = tensor.narrow(self.len_dim, 0, self.max_len)
-        return tensor
+        return F.pad_trim(tensor, self.ch_dim, self.max_len, self.len_dim, self.fill_value)

     def __repr__(self):
         return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)

@@ -122,11 +109,7 @@ class DownmixMono(object):
         self.ch_dim = int(not channels_first)

     def __call__(self, tensor):
-        if not tensor.dtype.is_floating_point:
-            tensor = tensor.to(torch.float32)
-
-        tensor = torch.mean(tensor, self.ch_dim, True)
-        return tensor
+        return F.downmix_mono(tensor, self.ch_dim)

     def __repr__(self):
         return self.__class__.__name__ + '()'

@@ -145,7 +128,7 @@ class LC2CL(object):
         Returns:
             tensor (Tensor): Tensor of audio signal with shape (CxL)
         """
-        return tensor.transpose(0, 1).contiguous()
+        return F.LC2CL(tensor)

     def __repr__(self):
         return self.__class__.__name__ + '()'

@@ -196,22 +179,8 @@ class Spectrogram(object):
                 by 2 plus 1.
         """
-        assert sig.dim() == 2
-
-        if self.pad > 0:
-            with torch.no_grad():
-                sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
-        self.window = self.window.to(sig.device)
-
-        # default values are consistent with librosa.core.spectrum._spectrogram
-        spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
-                            self.window, center=True,
-                            normalized=False, onesided=True,
-                            pad_mode='reflect').transpose(1, 2)
-        if self.normalize:
-            spec_f /= self.window.pow(2).sum().sqrt()
-        spec_f = spec_f.pow(self.power).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
-        return spec_f
+        return F.spectrogram(sig, self.pad, self.window, self.n_fft, self.hop,
+                             self.ws, self.power, self.normalize)


 def F2M(*args, **kwargs):

@@ -236,47 +205,13 @@ class MelScale(object):
         self.sr = sr
         self.f_max = f_max if f_max is not None else sr // 2
         self.f_min = f_min
-        self.fb = self._create_fb_matrix(n_stft) if n_stft is not None else n_stft
+        self.fb = F.create_fb_matrix(
+            n_stft, self.f_min, self.f_max, self.n_mels) if n_stft is not None else n_stft

     def __call__(self, spec_f):
-        if self.fb is None:
-            self.fb = self._create_fb_matrix(spec_f.size(2)).to(spec_f.device)
-        else:
-            # need to ensure same device for dot product
-            self.fb = self.fb.to(spec_f.device)
-        spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
+        self.fb, spec_m = F.mel_scale(spec_f, self.f_min, self.f_max, self.n_mels, self.fb)
         return spec_m

-    def _create_fb_matrix(self, n_stft):
-        """ Create a frequency bin conversion matrix.
-
-        Args:
-            n_stft (int): number of filter banks from spectrogram
-        """
-        # get stft freq bins
-        stft_freqs = torch.linspace(self.f_min, self.f_max, n_stft)
-        # calculate mel freq bins
-        m_min = 0. if self.f_min == 0 else self._hertz_to_mel(self.f_min)
-        m_max = self._hertz_to_mel(self.f_max)
-        m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
-        f_pts = self._mel_to_hertz(m_pts)
-        # calculate the difference between each mel point and each stft freq point in hertz
-        f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
-        slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_stft, n_mels + 2)
-        # create overlapping triangles
-        z = torch.tensor(0.)
-        down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_stft, n_mels)
-        up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_stft, n_mels)
-        fb = torch.max(z, torch.min(down_slopes, up_slopes))
-        return fb
-
-    def _hertz_to_mel(self, f):
-        return 2595. * torch.log10(torch.tensor(1.) + (f / 700.))
-
-    def _mel_to_hertz(self, mel):
-        return 700. * (10**(mel / 2595.) - 1.)


 class SpectrogramToDB(object):
     """Turns a spectrogram from the power/amplitude scale to the decibel scale.

@@ -304,12 +239,7 @@ class SpectrogramToDB(object):
     def __call__(self, spec):
         # numerically stable implementation from librosa
         # https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
-        spec_db = self.multiplier * torch.log10(torch.clamp(spec, min=self.amin))
-        spec_db -= self.multiplier * self.db_multiplier
-
-        if self.top_db is not None:
-            spec_db = torch.max(spec_db, spec_db.new_full((1,), spec_db.max() - self.top_db))
-        return spec_db
+        return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db)


 class MFCC(object):

@@ -352,29 +282,9 @@ class MFCC(object):
         if self.n_mfcc > self.MelSpectrogram.n_mels:
             raise ValueError('Cannot select more MFCC coefficients than # mel bins')
-        self.dct_mat = self.create_dct()
+        self.dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
         self.log_mels = log_mels

-    def create_dct(self):
-        """
-        Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
-        normalized depending on self.norm
-
-        Returns:
-            The transformation matrix, to be right-multiplied to row-wise data.
-        """
-        outdim = self.n_mfcc
-        dim = self.MelSpectrogram.n_mels
-        # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
-        n = np.arange(dim)
-        k = np.arange(outdim)[:, np.newaxis]
-        dct = np.cos(np.pi / dim * (n + 0.5) * k)
-        if self.norm == 'ortho':
-            dct[0] *= 1.0 / np.sqrt(2)
-            dct *= np.sqrt(2.0 / dim)
-        else:
-            dct *= 2
-        return torch.Tensor(dct.T)

     def __call__(self, sig):
         """
         Args:

@@ -385,14 +295,7 @@
                 is unchanged, hops is the number of hops, and n_mels is the
                 number of mel bins.
         """
-        mel_spect = self.MelSpectrogram(sig)
-        if self.log_mels:
-            log_offset = 1e-6
-            mel_spect = torch.log(mel_spect + log_offset)
-        else:
-            mel_spect = self.s2db(mel_spect)
-        mfcc = torch.matmul(mel_spect, self.dct_mat.to(mel_spect.device))
-        return mfcc
+        return F.MFCC(sig, self.MelSpectrogram(sig), self.log_mels, self.s2db, self.dct_mat)


 class MelSpectrogram(object):

@@ -475,8 +378,7 @@ class BLC2CBL(object):
             tensor (Tensor): Tensor of spectrogram with shape (CxBxL)
         """
-
-        return tensor.permute(2, 0, 1).contiguous()
+        return F.BLC2CBL(tensor)

     def __repr__(self):
         return self.__class__.__name__ + '()'

@@ -507,18 +409,7 @@ class MuLawEncoding(object):
             x_mu (LongTensor or ndarray)
         """
-        mu = self.qc - 1.
-        if isinstance(x, np.ndarray):
-            x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
-            x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
-        elif isinstance(x, torch.Tensor):
-            if not x.dtype.is_floating_point:
-                x = x.to(torch.float)
-            mu = torch.tensor(mu, dtype=x.dtype)
-            x_mu = torch.sign(x) * torch.log1p(mu *
-                                               torch.abs(x)) / torch.log1p(mu)
-            x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
-        return x_mu
+        return F.mu_law_encoding(x, self.qc)

     def __repr__(self):
         return self.__class__.__name__ + '()'

@@ -549,17 +440,7 @@ class MuLawExpanding(object):
             x (FloatTensor or ndarray)
         """
-        mu = self.qc - 1.
-        if isinstance(x_mu, np.ndarray):
-            x = ((x_mu) / mu) * 2 - 1.
-            x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
-        elif isinstance(x_mu, torch.Tensor):
-            if not x_mu.dtype.is_floating_point:
-                x_mu = x_mu.to(torch.float)
-            mu = torch.tensor(mu, dtype=x_mu.dtype)
-            x = ((x_mu) / mu) * 2 - 1.
-            x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
-        return x
+        return F.mu_law_expanding(x_mu, self.qc)

     def __repr__(self):
         return self.__class__.__name__ + '()'