"git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "d35ade751ea66cd3b44e978470f3ce5b9691fb2e"
Commit e1623106 authored by John Reese, committed by Facebook GitHub Bot

formatting changes from black 22.3.0

Summary:
Applies the black-fbsource codemod with the new build of pyfmt.

paintitblack

Reviewed By: lisroach

Differential Revision: D36324783

fbshipit-source-id: 280c09e88257e5e569ab729691165d8dedd767bc
parent 5b09d662
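
For context: the only change in every hunk below is the spacing of the power operator. Black's 22.x stable style (which this codebase picks up via black 22.3.0) removes the spaces around ** when both operands are simple (names, numeric literals, or attribute accesses) and keeps them otherwise. A minimal runnable sketch of the rule, reusing expressions from this diff (the variable names here are illustrative):

    total_size, out_prob, gamma = 3 * 1024 * 1024, 0.9, 2.0

    # Pre-22.x black kept spaces: total_size / 1024 ** 2
    # 22.x hugs the operator when both operands are simple:
    total_mib = total_size / 1024**2

    # A parenthesized operand is not "simple", so the spaces stay; this is
    # why ((1 - out_prob) ** gamma) is left untouched in the matcher hunk.
    pos_weight = (1 - out_prob) ** gamma

    print(total_mib, pos_weight)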
@@ -63,7 +63,7 @@ class DiskCachedDatasetFromList(data.Dataset):
         total_size = sum(len(x) for x in self._lst)
         # TODO: only enabling DiskCachedDataset for large enough dataset
         logger.info(
-            "Serialized dataset takes {:.2f} MiB".format(total_size / 1024 ** 2)
+            "Serialized dataset takes {:.2f} MiB".format(total_size / 1024**2)
         )
         self._initialize_diskcache()

@@ -128,7 +128,7 @@ class DiskCachedDatasetFromList(data.Dataset):
         comm.synchronize()
         logger.info(
             "Finished writing to local disk, db size: {:.2f} MiB".format(
-                self._cache.cache.volume() / 1024 ** 2
+                self._cache.cache.volume() / 1024**2
             )
         )
         # Optional sync for some strategies

@@ -158,7 +158,7 @@ class DiskCachedDatasetFromList(data.Dataset):
             "({:.2f}%) Wrote {} elements to local disk cache, db size: {:.2f} MiB".format(
                 percentage,
                 len(self._cache.cache),
-                self._cache.cache.volume() / 1024 ** 2,
+                self._cache.cache.volume() / 1024**2,
             ),
             n=10,
         )

@@ -8,7 +8,7 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 def get_lt_trainer(output_dir: str, cfg):
     checkpoint_callback = ModelCheckpoint(dirpath=output_dir, save_last=True)
     return pl.Trainer(
-        max_epochs=10 ** 8,
+        max_epochs=10**8,
         max_steps=cfg.SOLVER.MAX_ITER,
         val_check_interval=cfg.TEST.EVAL_PERIOD
         if cfg.TEST.EVAL_PERIOD > 0

@@ -159,7 +159,7 @@ class DeformableTransformer(nn.Module):
             # grid shape (bs, H_l, W_l, 2). Value could be > 1
             grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
             # wh shape (bs, H_l, W_l, 2)
-            wh = torch.ones_like(grid) * base_object_scale * (2.0 ** lvl)
+            wh = torch.ones_like(grid) * base_object_scale * (2.0**lvl)
             # proposal shape (bs, H_l * W_l, 4)
             proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
             proposals.append(proposal)

@@ -84,7 +84,7 @@ class HungarianMatcher(nn.Module):
         alpha = 0.25
         gamma = 2.0
         neg_cost_class = (
-            (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
+            (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
         )
         pos_cost_class = (
             alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())

@@ -204,7 +204,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
         test_out = model.eval()(test_in)
         self.assertGreater(
-            (test_out ** 2).sum(), 0.03, "With the given seend, L2^2 should be > 0.03."
+            (test_out**2).sum(), 0.03, "With the given seend, L2^2 should be > 0.03."
         )
         base_out = qat.quantized.eval()(test_in)

@@ -330,7 +330,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
         test_out = model.eval()(test_in)
         self.assertGreater(
-            (test_out ** 2).sum(), 0.03, "With the given seend, L2^2 should be > 0.03."
+            (test_out**2).sum(), 0.03, "With the given seend, L2^2 should be > 0.03."
         )
         base_out = qat.quantized.eval()(test_in)

@@ -369,7 +369,7 @@ class TestQuantizationAwareTraining(unittest.TestCase):
         test_out = model.eval()(test_in)
         self.assertGreater(
-            (test_out ** 2).sum(), 0.03, "With the given seend, L2^2 should be > 0.03."
+            (test_out**2).sum(), 0.03, "With the given seend, L2^2 should be > 0.03."
         )
         base_out = qat.quantized.eval()(test_in)
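
The internal black-fbsource codemod and pyfmt build are not public, but the same rewrite can be reproduced with the open-source black API. A minimal sketch, assuming black==22.3.0 is installed (SRC is a made-up two-line input):

    import black

    SRC = "total_mib = total_size / 1024 ** 2\nneg = (1 - alpha) * (out_prob ** gamma)\n"

    # Mode() with defaults applies the stable style, including the
    # power-operator hugging shown in the hunks above.
    print(black.format_str(SRC, mode=black.Mode()), end="")
    # Output:
    #   total_mib = total_size / 1024**2
    #   neg = (1 - alpha) * (out_prob**gamma)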