Commit 595dc5d3 authored by John Reese's avatar John Reese Committed by Facebook GitHub Bot
Browse files

[black][codemod] formatting changes from black 22.3.0

Summary:
Applies the black-fbsource codemod with the new build of pyfmt.

paintitblack

Reviewed By: lisroach

Differential Revision: D36324783

fbshipit-source-id: 280c09e88257e5e569ab729691165d8dedd767bc
parent 9877f544
......@@ -65,7 +65,7 @@ def generate_statistics(samples):
if idx % 100 == 0:
logger.info(f"Processed {idx}")
return E_x, (E_x_2 - E_x ** 2) ** 0.5
return E_x, (E_x_2 - E_x**2) ** 0.5
def get_dataset(args):
......
......@@ -120,7 +120,7 @@ class ApplyKmeans(object):
def __init__(self, km_path, device):
self.km_model = joblib.load(km_path)
self.C_np = self.km_model.cluster_centers_.transpose()
self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)
self.Cnorm_np = (self.C_np**2).sum(0, keepdims=True)
self.C = torch.from_numpy(self.C_np).to(device)
self.Cnorm = torch.from_numpy(self.Cnorm_np).to(device)
......
......@@ -206,7 +206,7 @@ def adjust_learning_rate(epoch, optimizer, learning_rate, anneal_steps, anneal_f
if anneal_factor == 0.3:
lr = learning_rate * ((0.1 ** (p // 2)) * (1.0 if p % 2 == 0 else 0.3))
else:
lr = learning_rate * (anneal_factor ** p)
lr = learning_rate * (anneal_factor**p)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
......
......@@ -99,7 +99,7 @@ def collate_factory(args):
if args.loss == "crossentropy":
if args.mulaw:
mulaw_encode = MuLawEncoding(2 ** args.n_bits)
mulaw_encode = MuLawEncoding(2**args.n_bits)
waveform = mulaw_encode(waveform)
target = mulaw_encode(target)
......
......@@ -332,7 +332,7 @@ def main(args):
**loader_validation_params,
)
n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30
n_classes = 2**args.n_bits if args.loss == "crossentropy" else 30
model = WaveRNN(
upsample_scales=args.upsample_scales,
......
......@@ -21,11 +21,11 @@ def normalized_waveform_to_bits(waveform: torch.Tensor, bits: int) -> torch.Tens
r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]"""
assert abs(waveform).max() <= 1.0
waveform = (waveform + 1.0) * (2 ** bits - 1) / 2
return torch.clamp(waveform, 0, 2 ** bits - 1).int()
waveform = (waveform + 1.0) * (2**bits - 1) / 2
return torch.clamp(waveform, 0, 2**bits - 1).int()
def bits_to_normalized_waveform(label: torch.Tensor, bits: int) -> torch.Tensor:
r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1]"""
return 2 * label / (2 ** bits - 1.0) - 1.0
return 2 * label / (2**bits - 1.0) - 1.0
......@@ -71,7 +71,7 @@ def _get_freq_ticks(sample_rate, offset, f_max):
time, freq = [], []
for exp in range(2, 5):
for v in range(1, 10):
f = v * 10 ** exp
f = v * 10**exp
if f < sample_rate // 2:
t = _get_inverse_log_freq(f, sample_rate, offset) / sample_rate
time.append(t)
......
......@@ -16,9 +16,9 @@ class Tester(common_utils.TorchaudioTestCase):
volume = 0.3
waveform = torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate)
waveform.unsqueeze_(0) # (1, 64000)
waveform = (waveform * volume * 2 ** 31).long()
waveform = (waveform * volume * 2**31).long()
def scale(self, waveform, factor=2.0 ** 31):
def scale(self, waveform, factor=2.0**31):
# scales a waveform by a factor
if not waveform.is_floating_point():
waveform = waveform.to(torch.get_default_dtype())
......
......@@ -126,7 +126,7 @@ def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") ->
while True:
# Read by chunk to avoid filling memory
chunk = file_obj.read(1024 ** 2)
chunk = file_obj.read(1024**2)
if not chunk:
break
hash_func.update(chunk)
......
......@@ -1298,7 +1298,7 @@ def riaa_biquad(waveform: Tensor, sample_rate: int) -> Tensor:
a_re = a0 + a1 * math.cos(-y) + a2 * math.cos(-2 * y)
b_im = b1 * math.sin(-y) + b2 * math.sin(-2 * y)
a_im = a1 * math.sin(-y) + a2 * math.sin(-2 * y)
g = 1 / math.sqrt((b_re ** 2 + b_im ** 2) / (a_re ** 2 + a_im ** 2))
g = 1 / math.sqrt((b_re**2 + b_im**2) / (a_re**2 + a_im**2))
b0 *= g
b1 *= g
......
......@@ -1150,18 +1150,18 @@ def sliding_window_cmn(
input_part = specgram[:, window_start : window_end - window_start, :]
cur_sum += torch.sum(input_part, 1)
if norm_vars:
cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
cur_sumsq += torch.cumsum(input_part**2, 1)[:, -1, :]
else:
if window_start > last_window_start:
frame_to_remove = specgram[:, last_window_start, :]
cur_sum -= frame_to_remove
if norm_vars:
cur_sumsq -= frame_to_remove ** 2
cur_sumsq -= frame_to_remove**2
if window_end > last_window_end:
frame_to_add = specgram[:, last_window_end, :]
cur_sum += frame_to_add
if norm_vars:
cur_sumsq += frame_to_add ** 2
cur_sumsq += frame_to_add**2
window_frames = window_end - window_start
last_window_start = window_start
last_window_end = window_end
......@@ -1172,7 +1172,7 @@ def sliding_window_cmn(
else:
variance = cur_sumsq
variance = variance / window_frames
variance -= (cur_sum ** 2) / (window_frames ** 2)
variance -= (cur_sum**2) / (window_frames**2)
variance = torch.pow(variance, -0.5)
cmn_specgram[:, t, :] *= variance
......
......@@ -109,7 +109,7 @@ class MaskGenerator(torch.nn.Module):
self.conv_layers = torch.nn.ModuleList([])
for s in range(num_stacks):
for l in range(num_layers):
multi = 2 ** l
multi = 2**l
self.conv_layers.append(
ConvBlock(
io_channels=num_feats,
......
......@@ -231,7 +231,7 @@ class SelfAttention(Module):
self.dropout = torch.nn.Dropout(dropout)
self.head_dim = head_dim
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
......
......@@ -396,7 +396,7 @@ class WaveRNN(nn.Module):
x = torch.multinomial(posterior, 1).float()
# Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
x = 2 * x / (2 ** self.n_bits - 1.0) - 1.0
x = 2 * x / (2**self.n_bits - 1.0) - 1.0
output.append(x)
......
......@@ -185,8 +185,8 @@ def _load_phonemizer(file, dl_kwargs):
def _unnormalize_waveform(waveform: torch.Tensor, bits: int) -> torch.Tensor:
r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]"""
waveform = torch.clamp(waveform, -1, 1)
waveform = (waveform + 1.0) * (2 ** bits - 1) / 2
return torch.clamp(waveform, 0, 2 ** bits - 1).int()
waveform = (waveform + 1.0) * (2**bits - 1) / 2
return torch.clamp(waveform, 0, 2**bits - 1).int()
def _get_taco_params(n_symbols):
......@@ -219,7 +219,7 @@ def _get_taco_params(n_symbols):
def _get_wrnn_params():
return {
"upsample_scales": [5, 5, 11],
"n_classes": 2 ** 8, # n_bits = 8
"n_classes": 2**8, # n_bits = 8
"hop_length": 275,
"n_res_block": 10,
"n_rnn": 512,
......
Markdown is supported
0% Loading — or attach a file by clicking here.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.