Unverified Commit d367a01a authored by Jirka Borovec's avatar Jirka Borovec Committed by GitHub
Browse files

Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)


Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 50dfe207
...@@ -34,7 +34,6 @@ A diff output is produced and a sensible exit code is returned. ...@@ -34,7 +34,6 @@ A diff output is produced and a sensible exit code is returned.
import argparse import argparse
import difflib import difflib
import fnmatch import fnmatch
import io
import multiprocessing import multiprocessing
import os import os
import signal import signal
...@@ -87,20 +86,20 @@ def list_files(files, recursive=False, extensions=None, exclude=None): ...@@ -87,20 +86,20 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
def make_diff(file, original, reformatted): def make_diff(file, original, reformatted):
return list( return list(
difflib.unified_diff( difflib.unified_diff(
original, reformatted, fromfile="{}\t(original)".format(file), tofile="{}\t(reformatted)".format(file), n=3 original, reformatted, fromfile=f"{file}\t(original)", tofile=f"{file}\t(reformatted)", n=3
) )
) )
class DiffError(Exception): class DiffError(Exception):
def __init__(self, message, errs=None): def __init__(self, message, errs=None):
super(DiffError, self).__init__(message) super().__init__(message)
self.errs = errs or [] self.errs = errs or []
class UnexpectedError(Exception): class UnexpectedError(Exception):
def __init__(self, message, exc=None): def __init__(self, message, exc=None):
super(UnexpectedError, self).__init__(message) super().__init__(message)
self.formatted_traceback = traceback.format_exc() self.formatted_traceback = traceback.format_exc()
self.exc = exc self.exc = exc
...@@ -112,14 +111,14 @@ def run_clang_format_diff_wrapper(args, file): ...@@ -112,14 +111,14 @@ def run_clang_format_diff_wrapper(args, file):
except DiffError: except DiffError:
raise raise
except Exception as e: except Exception as e:
raise UnexpectedError("{}: {}: {}".format(file, e.__class__.__name__, e), e) raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e)
def run_clang_format_diff(args, file): def run_clang_format_diff(args, file):
try: try:
with io.open(file, "r", encoding="utf-8") as f: with open(file, encoding="utf-8") as f:
original = f.readlines() original = f.readlines()
except IOError as exc: except OSError as exc:
raise DiffError(str(exc)) raise DiffError(str(exc))
invocation = [args.clang_format_executable, file] invocation = [args.clang_format_executable, file]
...@@ -145,7 +144,7 @@ def run_clang_format_diff(args, file): ...@@ -145,7 +144,7 @@ def run_clang_format_diff(args, file):
invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8" invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8"
) )
except OSError as exc: except OSError as exc:
raise DiffError("Command '{}' failed to start: {}".format(subprocess.list2cmdline(invocation), exc)) raise DiffError(f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}")
proc_stdout = proc.stdout proc_stdout = proc.stdout
proc_stderr = proc.stderr proc_stderr = proc.stderr
...@@ -203,7 +202,7 @@ def print_trouble(prog, message, use_colors): ...@@ -203,7 +202,7 @@ def print_trouble(prog, message, use_colors):
error_text = "error:" error_text = "error:"
if use_colors: if use_colors:
error_text = bold_red(error_text) error_text = bold_red(error_text)
print("{}: {} {}".format(prog, error_text, message), file=sys.stderr) print(f"{prog}: {error_text} {message}", file=sys.stderr)
def main(): def main():
...@@ -216,7 +215,7 @@ def main(): ...@@ -216,7 +215,7 @@ def main():
) )
parser.add_argument( parser.add_argument(
"--extensions", "--extensions",
help="comma separated list of file extensions (default: {})".format(DEFAULT_EXTENSIONS), help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})",
default=DEFAULT_EXTENSIONS, default=DEFAULT_EXTENSIONS,
) )
parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories") parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories")
...@@ -227,7 +226,7 @@ def main(): ...@@ -227,7 +226,7 @@ def main():
metavar="N", metavar="N",
type=int, type=int,
default=0, default=0,
help="run N clang-format jobs in parallel" " (default number of cpus + 1)", help="run N clang-format jobs in parallel (default number of cpus + 1)",
) )
parser.add_argument( parser.add_argument(
"--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)" "--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)"
...@@ -238,7 +237,7 @@ def main(): ...@@ -238,7 +237,7 @@ def main():
metavar="PATTERN", metavar="PATTERN",
action="append", action="append",
default=[], default=[],
help="exclude paths matching the given glob-like pattern(s)" " from recursive search", help="exclude paths matching the given glob-like pattern(s) from recursive search",
) )
args = parser.parse_args() args = parser.parse_args()
...@@ -263,7 +262,7 @@ def main(): ...@@ -263,7 +262,7 @@ def main():
colored_stdout = sys.stdout.isatty() colored_stdout = sys.stdout.isatty()
colored_stderr = sys.stderr.isatty() colored_stderr = sys.stderr.isatty()
version_invocation = [args.clang_format_executable, str("--version")] version_invocation = [args.clang_format_executable, "--version"]
try: try:
subprocess.check_call(version_invocation, stdout=DEVNULL) subprocess.check_call(version_invocation, stdout=DEVNULL)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
...@@ -272,7 +271,7 @@ def main(): ...@@ -272,7 +271,7 @@ def main():
except OSError as e: except OSError as e:
print_trouble( print_trouble(
parser.prog, parser.prog,
"Command '{}' failed to start: {}".format(subprocess.list2cmdline(version_invocation), e), f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}",
use_colors=colored_stderr, use_colors=colored_stderr,
) )
return ExitStatus.TROUBLE return ExitStatus.TROUBLE
......
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: check-docstring-first
- id: check-toml
- id: check-yaml
exclude: packaging/.*
- id: end-of-file-fixer
# - repo: https://github.com/asottile/pyupgrade
# rev: v2.29.0
# hooks:
# - id: pyupgrade
# args: [--py36-plus]
# name: Upgrade code
- repo: https://github.com/omnilib/ufmt - repo: https://github.com/omnilib/ufmt
rev: v1.3.0 rev: v1.3.0
hooks: hooks:
...@@ -6,16 +22,9 @@ repos: ...@@ -6,16 +22,9 @@ repos:
additional_dependencies: additional_dependencies:
- black == 21.9b0 - black == 21.9b0
- usort == 0.6.4 - usort == 0.6.4
- repo: https://gitlab.com/pycqa/flake8 - repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2 rev: 3.9.2
hooks: hooks:
- id: flake8 - id: flake8
args: [--config=setup.cfg] args: [--config=setup.cfg]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: check-docstring-first
- id: check-toml
- id: check-yaml
exclude: packaging/.*
- id: end-of-file-fixer
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*-
# #
# PyTorch documentation build configuration file, created by # PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016. # sphinx-quickstart on Fri Dec 23 13:31:47 2016.
......
...@@ -125,7 +125,7 @@ res_scripted = scripted_predictor(batch) ...@@ -125,7 +125,7 @@ res_scripted = scripted_predictor(batch)
import json import json
with open(Path('assets') / 'imagenet_class_index.json', 'r') as labels_file: with open(Path('assets') / 'imagenet_class_index.json') as labels_file:
labels = json.load(labels_file) labels = json.load(labels_file)
for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)): for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
......
...@@ -137,7 +137,7 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au ...@@ -137,7 +137,7 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au
if end < start: if end < start:
raise ValueError( raise ValueError(
"end time should be larger than start time, got " "end time should be larger than start time, got "
"start time={} and end time={}".format(start, end) f"start time={start} and end time={end}"
) )
video_frames = torch.empty(0) video_frames = torch.empty(0)
......
# -*- coding: utf-8 -*-
"""Helper script to package wheels and relocate binaries.""" """Helper script to package wheels and relocate binaries."""
import glob import glob
...@@ -157,7 +155,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): ...@@ -157,7 +155,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
rename and copy them into the wheel while updating their respective rpaths. rename and copy them into the wheel while updating their respective rpaths.
""" """
print("Relocating {0}".format(binary)) print(f"Relocating {binary}")
binary_path = osp.join(output_library, binary) binary_path = osp.join(output_library, binary)
ld_tree = lddtree(binary_path) ld_tree = lddtree(binary_path)
...@@ -173,12 +171,12 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): ...@@ -173,12 +171,12 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
print(library) print(library)
if library_info["path"] is None: if library_info["path"] is None:
print("Omitting {0}".format(library)) print(f"Omitting {library}")
continue continue
if library in ALLOWLIST: if library in ALLOWLIST:
# Omit glibc/gcc/system libraries # Omit glibc/gcc/system libraries
print("Omitting {0}".format(library)) print(f"Omitting {library}")
continue continue
parent_dependencies = binary_dependencies.get(parent, []) parent_dependencies = binary_dependencies.get(parent, [])
...@@ -201,7 +199,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): ...@@ -201,7 +199,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
if library != binary: if library != binary:
library_path = binary_paths[library] library_path = binary_paths[library]
new_library_path = patch_new_path(library_path, new_libraries_path) new_library_path = patch_new_path(library_path, new_libraries_path)
print("{0} -> {1}".format(library, new_library_path)) print(f"{library} -> {new_library_path}")
shutil.copyfile(library_path, new_library_path) shutil.copyfile(library_path, new_library_path)
new_names[library] = new_library_path new_names[library] = new_library_path
...@@ -214,7 +212,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): ...@@ -214,7 +212,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
new_library_name = new_names[library] new_library_name = new_names[library]
for dep in library_dependencies: for dep in library_dependencies:
new_dep = osp.basename(new_names[dep]) new_dep = osp.basename(new_names[dep])
print("{0}: {1} -> {2}".format(library, dep, new_dep)) print(f"{library}: {dep} -> {new_dep}")
subprocess.check_output( subprocess.check_output(
[patchelf, "--replace-needed", dep, new_dep, new_library_name], cwd=new_libraries_path [patchelf, "--replace-needed", dep, new_dep, new_library_name], cwd=new_libraries_path
) )
...@@ -228,7 +226,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): ...@@ -228,7 +226,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
library_dependencies = binary_dependencies[binary] library_dependencies = binary_dependencies[binary]
for dep in library_dependencies: for dep in library_dependencies:
new_dep = osp.basename(new_names[dep]) new_dep = osp.basename(new_names[dep])
print("{0}: {1} -> {2}".format(binary, dep, new_dep)) print(f"{binary}: {dep} -> {new_dep}")
subprocess.check_output([patchelf, "--replace-needed", dep, new_dep, binary], cwd=output_library) subprocess.check_output([patchelf, "--replace-needed", dep, new_dep, binary], cwd=output_library)
print("Update library rpath") print("Update library rpath")
...@@ -244,7 +242,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary): ...@@ -244,7 +242,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
Given a shared library, find the transitive closure of its dependencies, Given a shared library, find the transitive closure of its dependencies,
rename and copy them into the wheel. rename and copy them into the wheel.
""" """
print("Relocating {0}".format(binary)) print(f"Relocating {binary}")
binary_path = osp.join(output_library, binary) binary_path = osp.join(output_library, binary)
library_dlls = find_dll_dependencies(dumpbin, binary_path) library_dlls = find_dll_dependencies(dumpbin, binary_path)
...@@ -255,18 +253,18 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary): ...@@ -255,18 +253,18 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
while binary_queue != []: while binary_queue != []:
library, parent = binary_queue.pop(0) library, parent = binary_queue.pop(0)
if library in WINDOWS_ALLOWLIST or library.startswith("api-ms-win"): if library in WINDOWS_ALLOWLIST or library.startswith("api-ms-win"):
print("Omitting {0}".format(library)) print(f"Omitting {library}")
continue continue
library_path = find_program(library) library_path = find_program(library)
if library_path is None: if library_path is None:
print("{0} not found".format(library)) print(f"{library} not found")
continue continue
if osp.basename(osp.dirname(library_path)) == "system32": if osp.basename(osp.dirname(library_path)) == "system32":
continue continue
print("{0}: {1}".format(library, library_path)) print(f"{library}: {library_path}")
parent_dependencies = binary_dependencies.get(parent, []) parent_dependencies = binary_dependencies.get(parent, [])
parent_dependencies.append(library) parent_dependencies.append(library)
binary_dependencies[parent] = parent_dependencies binary_dependencies[parent] = parent_dependencies
...@@ -284,7 +282,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary): ...@@ -284,7 +282,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
if library != binary: if library != binary:
library_path = binary_paths[library] library_path = binary_paths[library]
new_library_path = osp.join(package_dir, library) new_library_path = osp.join(package_dir, library)
print("{0} -> {1}".format(library, new_library_path)) print(f"{library} -> {new_library_path}")
shutil.copyfile(library_path, new_library_path) shutil.copyfile(library_path, new_library_path)
...@@ -300,16 +298,16 @@ def compress_wheel(output_dir, wheel, wheel_dir, wheel_name): ...@@ -300,16 +298,16 @@ def compress_wheel(output_dir, wheel, wheel_dir, wheel_name):
full_file = osp.join(root, this_file) full_file = osp.join(root, this_file)
rel_file = osp.relpath(full_file, output_dir) rel_file = osp.relpath(full_file, output_dir)
if full_file == record_file: if full_file == record_file:
f.write("{0},,\n".format(rel_file)) f.write(f"{rel_file},,\n")
else: else:
digest, size = rehash(full_file) digest, size = rehash(full_file)
f.write("{0},{1},{2}\n".format(rel_file, digest, size)) f.write(f"{rel_file},{digest},{size}\n")
print("Compressing wheel") print("Compressing wheel")
base_wheel_name = osp.join(wheel_dir, wheel_name) base_wheel_name = osp.join(wheel_dir, wheel_name)
shutil.make_archive(base_wheel_name, "zip", output_dir) shutil.make_archive(base_wheel_name, "zip", output_dir)
os.remove(wheel) os.remove(wheel)
shutil.move("{0}.zip".format(base_wheel_name), wheel) shutil.move(f"{base_wheel_name}.zip", wheel)
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
...@@ -317,9 +315,7 @@ def patch_linux(): ...@@ -317,9 +315,7 @@ def patch_linux():
# Get patchelf location # Get patchelf location
patchelf = find_program("patchelf") patchelf = find_program("patchelf")
if patchelf is None: if patchelf is None:
raise FileNotFoundError( raise FileNotFoundError("Patchelf was not found in the system, please make sure that is available on the PATH.")
"Patchelf was not found in the system, please" " make sure that is available on the PATH."
)
# Find wheel # Find wheel
print("Finding wheels...") print("Finding wheels...")
...@@ -338,7 +334,7 @@ def patch_linux(): ...@@ -338,7 +334,7 @@ def patch_linux():
print("Unzipping wheel...") print("Unzipping wheel...")
wheel_file = osp.basename(wheel) wheel_file = osp.basename(wheel)
wheel_dir = osp.dirname(wheel) wheel_dir = osp.dirname(wheel)
print("{0}".format(wheel_file)) print(f"{wheel_file}")
wheel_name, _ = osp.splitext(wheel_file) wheel_name, _ = osp.splitext(wheel_file)
unzip_file(wheel, output_dir) unzip_file(wheel, output_dir)
...@@ -355,9 +351,7 @@ def patch_win(): ...@@ -355,9 +351,7 @@ def patch_win():
# Get dumpbin location # Get dumpbin location
dumpbin = find_program("dumpbin") dumpbin = find_program("dumpbin")
if dumpbin is None: if dumpbin is None:
raise FileNotFoundError( raise FileNotFoundError("Dumpbin was not found in the system, please make sure that is available on the PATH.")
"Dumpbin was not found in the system, please" " make sure that is available on the PATH."
)
# Find wheel # Find wheel
print("Finding wheels...") print("Finding wheels...")
...@@ -376,7 +370,7 @@ def patch_win(): ...@@ -376,7 +370,7 @@ def patch_win():
print("Unzipping wheel...") print("Unzipping wheel...")
wheel_file = osp.basename(wheel) wheel_file = osp.basename(wheel)
wheel_dir = osp.dirname(wheel) wheel_dir = osp.dirname(wheel)
print("{0}".format(wheel_file)) print(f"{wheel_file}")
wheel_name, _ = osp.splitext(wheel_file) wheel_name, _ = osp.splitext(wheel_file)
unzip_file(wheel, output_dir) unzip_file(wheel, output_dir)
......
...@@ -26,7 +26,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg ...@@ -26,7 +26,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}")) metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
header = "Epoch: [{}]".format(epoch) header = f"Epoch: [{epoch}]"
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
start_time = time.time() start_time = time.time()
image, target = image.to(device), target.to(device) image, target = image.to(device), target.to(device)
...@@ -121,7 +121,7 @@ def load_data(traindir, valdir, args): ...@@ -121,7 +121,7 @@ def load_data(traindir, valdir, args):
cache_path = _get_cache_path(traindir) cache_path = _get_cache_path(traindir)
if args.cache_dataset and os.path.exists(cache_path): if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached! # Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path)) print(f"Loading dataset_train from {cache_path}")
dataset, _ = torch.load(cache_path) dataset, _ = torch.load(cache_path)
else: else:
auto_augment_policy = getattr(args, "auto_augment", None) auto_augment_policy = getattr(args, "auto_augment", None)
...@@ -136,7 +136,7 @@ def load_data(traindir, valdir, args): ...@@ -136,7 +136,7 @@ def load_data(traindir, valdir, args):
), ),
) )
if args.cache_dataset: if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path)) print(f"Saving dataset_train to {cache_path}")
utils.mkdir(os.path.dirname(cache_path)) utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path) utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st) print("Took", time.time() - st)
...@@ -145,7 +145,7 @@ def load_data(traindir, valdir, args): ...@@ -145,7 +145,7 @@ def load_data(traindir, valdir, args):
cache_path = _get_cache_path(valdir) cache_path = _get_cache_path(valdir)
if args.cache_dataset and os.path.exists(cache_path): if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached! # Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path)) print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path) dataset_test, _ = torch.load(cache_path)
else: else:
if not args.weights: if not args.weights:
...@@ -162,7 +162,7 @@ def load_data(traindir, valdir, args): ...@@ -162,7 +162,7 @@ def load_data(traindir, valdir, args):
preprocessing, preprocessing,
) )
if args.cache_dataset: if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path)) print(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path)) utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path) utils.save_on_master((dataset_test, valdir), cache_path)
...@@ -270,8 +270,8 @@ def main(args): ...@@ -270,8 +270,8 @@ def main(args):
main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
else: else:
raise RuntimeError( raise RuntimeError(
"Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR " f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
"are supported.".format(args.lr_scheduler) "are supported."
) )
if args.lr_warmup_epochs > 0: if args.lr_warmup_epochs > 0:
...@@ -285,7 +285,7 @@ def main(args): ...@@ -285,7 +285,7 @@ def main(args):
) )
else: else:
raise RuntimeError( raise RuntimeError(
f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant " "are supported." f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
) )
lr_scheduler = torch.optim.lr_scheduler.SequentialLR( lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs] optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
...@@ -351,12 +351,12 @@ def main(args): ...@@ -351,12 +351,12 @@ def main(args):
} }
if model_ema: if model_ema:
checkpoint["model_ema"] = model_ema.state_dict() checkpoint["model_ema"] = model_ema.state_dict()
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str)) print(f"Training time {total_time_str}")
def get_args_parser(add_help=True): def get_args_parser(add_help=True):
......
...@@ -20,7 +20,7 @@ def main(args): ...@@ -20,7 +20,7 @@ def main(args):
print(args) print(args)
if args.post_training_quantize and args.distributed: if args.post_training_quantize and args.distributed:
raise RuntimeError("Post training quantization example should not be performed " "on distributed mode") raise RuntimeError("Post training quantization example should not be performed on distributed mode")
# Set backend engine to ensure that quantized model runs on the correct kernels # Set backend engine to ensure that quantized model runs on the correct kernels
if args.backend not in torch.backends.quantized.supported_engines: if args.backend not in torch.backends.quantized.supported_engines:
...@@ -141,13 +141,13 @@ def main(args): ...@@ -141,13 +141,13 @@ def main(args):
"epoch": epoch, "epoch": epoch,
"args": args, "args": args,
} }
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
print("Saving models after epoch ", epoch) print("Saving models after epoch ", epoch)
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str)) print(f"Training time {total_time_str}")
def get_args_parser(add_help=True): def get_args_parser(add_help=True):
......
...@@ -39,13 +39,13 @@ class RandomMixup(torch.nn.Module): ...@@ -39,13 +39,13 @@ class RandomMixup(torch.nn.Module):
Tensor: Randomly transformed batch. Tensor: Randomly transformed batch.
""" """
if batch.ndim != 4: if batch.ndim != 4:
raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
elif target.ndim != 1: if target.ndim != 1:
raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
elif not batch.is_floating_point(): if not batch.is_floating_point():
raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype)) raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
elif target.dtype != torch.int64: if target.dtype != torch.int64:
raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype)) raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
if not self.inplace: if not self.inplace:
batch = batch.clone() batch = batch.clone()
...@@ -115,13 +115,13 @@ class RandomCutmix(torch.nn.Module): ...@@ -115,13 +115,13 @@ class RandomCutmix(torch.nn.Module):
Tensor: Randomly transformed batch. Tensor: Randomly transformed batch.
""" """
if batch.ndim != 4: if batch.ndim != 4:
raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
elif target.ndim != 1: if target.ndim != 1:
raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
elif not batch.is_floating_point(): if not batch.is_floating_point():
raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype)) raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
elif target.dtype != torch.int64: if target.dtype != torch.int64:
raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype)) raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
if not self.inplace: if not self.inplace:
batch = batch.clone() batch = batch.clone()
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
class SmoothedValue(object): class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a """Track a series of values and provide access to smoothed values over a
window or the global series average. window or the global series average.
""" """
...@@ -65,7 +65,7 @@ class SmoothedValue(object): ...@@ -65,7 +65,7 @@ class SmoothedValue(object):
) )
class MetricLogger(object): class MetricLogger:
def __init__(self, delimiter="\t"): def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue) self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter self.delimiter = delimiter
...@@ -82,12 +82,12 @@ class MetricLogger(object): ...@@ -82,12 +82,12 @@ class MetricLogger(object):
return self.meters[attr] return self.meters[attr]
if attr in self.__dict__: if attr in self.__dict__:
return self.__dict__[attr] return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self): def __str__(self):
loss_str = [] loss_str = []
for name, meter in self.meters.items(): for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter))) loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str) return self.delimiter.join(loss_str)
def synchronize_between_processes(self): def synchronize_between_processes(self):
...@@ -152,7 +152,7 @@ class MetricLogger(object): ...@@ -152,7 +152,7 @@ class MetricLogger(object):
end = time.time() end = time.time()
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {}".format(header, total_time_str)) print(f"{header} Total time: {total_time_str}")
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
...@@ -270,7 +270,7 @@ def init_distributed_mode(args): ...@@ -270,7 +270,7 @@ def init_distributed_mode(args):
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl" args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
) )
...@@ -307,8 +307,7 @@ def average_checkpoints(inputs): ...@@ -307,8 +307,7 @@ def average_checkpoints(inputs):
params_keys = model_params_keys params_keys = model_params_keys
elif params_keys != model_params_keys: elif params_keys != model_params_keys:
raise KeyError( raise KeyError(
"For checkpoint {}, expected list of params: {}, " f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
"but found: {}".format(f, params_keys, model_params_keys)
) )
for k in params_keys: for k in params_keys:
p = model_params[k] p = model_params[k]
......
...@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask ...@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
from pycocotools.coco import COCO from pycocotools.coco import COCO
class FilterAndRemapCocoCategories(object): class FilterAndRemapCocoCategories:
def __init__(self, categories, remap=True): def __init__(self, categories, remap=True):
self.categories = categories self.categories = categories
self.remap = remap self.remap = remap
...@@ -44,7 +44,7 @@ def convert_coco_poly_to_mask(segmentations, height, width): ...@@ -44,7 +44,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return masks return masks
class ConvertCocoPolysToMask(object): class ConvertCocoPolysToMask:
def __call__(self, image, target): def __call__(self, image, target):
w, h = image.size w, h = image.size
...@@ -205,11 +205,11 @@ def get_coco_api_from_dataset(dataset): ...@@ -205,11 +205,11 @@ def get_coco_api_from_dataset(dataset):
class CocoDetection(torchvision.datasets.CocoDetection): class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms): def __init__(self, img_folder, ann_file, transforms):
super(CocoDetection, self).__init__(img_folder, ann_file) super().__init__(img_folder, ann_file)
self._transforms = transforms self._transforms = transforms
def __getitem__(self, idx): def __getitem__(self, idx):
img, target = super(CocoDetection, self).__getitem__(idx) img, target = super().__getitem__(idx)
image_id = self.ids[idx] image_id = self.ids[idx]
target = dict(image_id=image_id, annotations=target) target = dict(image_id=image_id, annotations=target)
if self._transforms is not None: if self._transforms is not None:
......
...@@ -13,7 +13,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): ...@@ -13,7 +13,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
header = "Epoch: [{}]".format(epoch) header = f"Epoch: [{epoch}]"
lr_scheduler = None lr_scheduler = None
if epoch == 0: if epoch == 0:
...@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): ...@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
loss_value = losses_reduced.item() loss_value = losses_reduced.item()
if not math.isfinite(loss_value): if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value)) print(f"Loss is {loss_value}, stopping training")
print(loss_dict_reduced) print(loss_dict_reduced)
sys.exit(1) sys.exit(1)
......
...@@ -36,9 +36,7 @@ class GroupedBatchSampler(BatchSampler): ...@@ -36,9 +36,7 @@ class GroupedBatchSampler(BatchSampler):
def __init__(self, sampler, group_ids, batch_size): def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler): if not isinstance(sampler, Sampler):
raise ValueError( raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler self.sampler = sampler
self.group_ids = group_ids self.group_ids = group_ids
self.batch_size = batch_size self.batch_size = batch_size
...@@ -193,6 +191,6 @@ def create_aspect_ratio_groups(dataset, k=0): ...@@ -193,6 +191,6 @@ def create_aspect_ratio_groups(dataset, k=0):
# count number of elements per group # count number of elements per group
counts = np.unique(groups, return_counts=True)[1] counts = np.unique(groups, return_counts=True)[1]
fbins = [0] + bins + [np.inf] fbins = [0] + bins + [np.inf]
print("Using {} as bins for aspect ratio quantization".format(fbins)) print(f"Using {fbins} as bins for aspect ratio quantization")
print("Count of instances per bin: {}".format(counts)) print(f"Count of instances per bin: {counts}")
return groups return groups
...@@ -65,7 +65,7 @@ def get_args_parser(add_help=True): ...@@ -65,7 +65,7 @@ def get_args_parser(add_help=True):
"--lr", "--lr",
default=0.02, default=0.02,
type=float, type=float,
help="initial learning rate, 0.02 is the default value for training " "on 8 gpus and 2 images_per_gpu", help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
) )
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument( parser.add_argument(
...@@ -197,8 +197,7 @@ def main(args): ...@@ -197,8 +197,7 @@ def main(args):
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
else: else:
raise RuntimeError( raise RuntimeError(
"Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR " f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
"are supported.".format(args.lr_scheduler)
) )
if args.resume: if args.resume:
...@@ -227,7 +226,7 @@ def main(args): ...@@ -227,7 +226,7 @@ def main(args):
"args": args, "args": args,
"epoch": epoch, "epoch": epoch,
} }
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
# evaluate after every epoch # evaluate after every epoch
...@@ -235,7 +234,7 @@ def main(args): ...@@ -235,7 +234,7 @@ def main(args):
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str)) print(f"Training time {total_time_str}")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -17,7 +17,7 @@ def _flip_coco_person_keypoints(kps, width): ...@@ -17,7 +17,7 @@ def _flip_coco_person_keypoints(kps, width):
return flipped_data return flipped_data
class Compose(object): class Compose:
def __init__(self, transforms): def __init__(self, transforms):
self.transforms = transforms self.transforms = transforms
...@@ -103,7 +103,7 @@ class RandomIoUCrop(nn.Module): ...@@ -103,7 +103,7 @@ class RandomIoUCrop(nn.Module):
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}: if image.ndimension() not in {2, 3}:
raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension())) raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2: elif image.ndimension() == 2:
image = image.unsqueeze(0) image = image.unsqueeze(0)
...@@ -171,7 +171,7 @@ class RandomZoomOut(nn.Module): ...@@ -171,7 +171,7 @@ class RandomZoomOut(nn.Module):
self.fill = fill self.fill = fill
self.side_range = side_range self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]: if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError("Invalid canvas side range provided {}.".format(side_range)) raise ValueError(f"Invalid canvas side range provided {side_range}.")
self.p = p self.p = p
@torch.jit.unused @torch.jit.unused
...@@ -185,7 +185,7 @@ class RandomZoomOut(nn.Module): ...@@ -185,7 +185,7 @@ class RandomZoomOut(nn.Module):
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}: if image.ndimension() not in {2, 3}:
raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension())) raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2: elif image.ndimension() == 2:
image = image.unsqueeze(0) image = image.unsqueeze(0)
...@@ -244,7 +244,7 @@ class RandomPhotometricDistort(nn.Module): ...@@ -244,7 +244,7 @@ class RandomPhotometricDistort(nn.Module):
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}: if image.ndimension() not in {2, 3}:
raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension())) raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2: elif image.ndimension() == 2:
image = image.unsqueeze(0) image = image.unsqueeze(0)
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
class SmoothedValue(object): class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a """Track a series of values and provide access to smoothed values over a
window or the global series average. window or the global series average.
""" """
...@@ -110,7 +110,7 @@ def reduce_dict(input_dict, average=True): ...@@ -110,7 +110,7 @@ def reduce_dict(input_dict, average=True):
return reduced_dict return reduced_dict
class MetricLogger(object): class MetricLogger:
def __init__(self, delimiter="\t"): def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue) self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter self.delimiter = delimiter
...@@ -127,12 +127,12 @@ class MetricLogger(object): ...@@ -127,12 +127,12 @@ class MetricLogger(object):
return self.meters[attr] return self.meters[attr]
if attr in self.__dict__: if attr in self.__dict__:
return self.__dict__[attr] return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self): def __str__(self):
loss_str = [] loss_str = []
for name, meter in self.meters.items(): for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter))) loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str) return self.delimiter.join(loss_str)
def synchronize_between_processes(self): def synchronize_between_processes(self):
...@@ -197,7 +197,7 @@ class MetricLogger(object): ...@@ -197,7 +197,7 @@ class MetricLogger(object):
end = time.time() end = time.time()
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable))) print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
def collate_fn(batch): def collate_fn(batch):
...@@ -274,7 +274,7 @@ def init_distributed_mode(args): ...@@ -274,7 +274,7 @@ def init_distributed_mode(args):
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl" args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
) )
......
...@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask ...@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
from transforms import Compose from transforms import Compose
class FilterAndRemapCocoCategories(object): class FilterAndRemapCocoCategories:
def __init__(self, categories, remap=True): def __init__(self, categories, remap=True):
self.categories = categories self.categories = categories
self.remap = remap self.remap = remap
...@@ -41,7 +41,7 @@ def convert_coco_poly_to_mask(segmentations, height, width): ...@@ -41,7 +41,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return masks return masks
class ConvertCocoPolysToMask(object): class ConvertCocoPolysToMask:
def __call__(self, image, anno): def __call__(self, image, anno):
w, h = image.size w, h = image.size
segmentations = [obj["segmentation"] for obj in anno] segmentations = [obj["segmentation"] for obj in anno]
......
...@@ -66,7 +66,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi ...@@ -66,7 +66,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
header = "Epoch: [{}]".format(epoch) header = f"Epoch: [{epoch}]"
for image, target in metric_logger.log_every(data_loader, print_freq, header): for image, target in metric_logger.log_every(data_loader, print_freq, header):
image, target = image.to(device), target.to(device) image, target = image.to(device), target.to(device)
output = model(image) output = model(image)
...@@ -152,8 +152,7 @@ def main(args): ...@@ -152,8 +152,7 @@ def main(args):
) )
else: else:
raise RuntimeError( raise RuntimeError(
"Invalid warmup lr method '{}'. Only linear and constant " f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
"are supported.".format(args.lr_warmup_method)
) )
lr_scheduler = torch.optim.lr_scheduler.SequentialLR( lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters] optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
...@@ -188,12 +187,12 @@ def main(args): ...@@ -188,12 +187,12 @@ def main(args):
"epoch": epoch, "epoch": epoch,
"args": args, "args": args,
} }
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str)) print(f"Training time {total_time_str}")
def get_args_parser(add_help=True): def get_args_parser(add_help=True):
......
...@@ -16,7 +16,7 @@ def pad_if_smaller(img, size, fill=0): ...@@ -16,7 +16,7 @@ def pad_if_smaller(img, size, fill=0):
return img return img
class Compose(object): class Compose:
def __init__(self, transforms): def __init__(self, transforms):
self.transforms = transforms self.transforms = transforms
...@@ -26,7 +26,7 @@ class Compose(object): ...@@ -26,7 +26,7 @@ class Compose(object):
return image, target return image, target
class RandomResize(object): class RandomResize:
def __init__(self, min_size, max_size=None): def __init__(self, min_size, max_size=None):
self.min_size = min_size self.min_size = min_size
if max_size is None: if max_size is None:
...@@ -40,7 +40,7 @@ class RandomResize(object): ...@@ -40,7 +40,7 @@ class RandomResize(object):
return image, target return image, target
class RandomHorizontalFlip(object): class RandomHorizontalFlip:
def __init__(self, flip_prob): def __init__(self, flip_prob):
self.flip_prob = flip_prob self.flip_prob = flip_prob
...@@ -51,7 +51,7 @@ class RandomHorizontalFlip(object): ...@@ -51,7 +51,7 @@ class RandomHorizontalFlip(object):
return image, target return image, target
class RandomCrop(object): class RandomCrop:
def __init__(self, size): def __init__(self, size):
self.size = size self.size = size
...@@ -64,7 +64,7 @@ class RandomCrop(object): ...@@ -64,7 +64,7 @@ class RandomCrop(object):
return image, target return image, target
class CenterCrop(object): class CenterCrop:
def __init__(self, size): def __init__(self, size):
self.size = size self.size = size
...@@ -90,7 +90,7 @@ class ConvertImageDtype: ...@@ -90,7 +90,7 @@ class ConvertImageDtype:
return image, target return image, target
class Normalize(object): class Normalize:
def __init__(self, mean, std): def __init__(self, mean, std):
self.mean = mean self.mean = mean
self.std = std self.std = std
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
class SmoothedValue(object): class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a """Track a series of values and provide access to smoothed values over a
window or the global series average. window or the global series average.
""" """
...@@ -67,7 +67,7 @@ class SmoothedValue(object): ...@@ -67,7 +67,7 @@ class SmoothedValue(object):
) )
class ConfusionMatrix(object): class ConfusionMatrix:
def __init__(self, num_classes): def __init__(self, num_classes):
self.num_classes = num_classes self.num_classes = num_classes
self.mat = None self.mat = None
...@@ -101,15 +101,15 @@ class ConfusionMatrix(object): ...@@ -101,15 +101,15 @@ class ConfusionMatrix(object):
def __str__(self): def __str__(self):
acc_global, acc, iu = self.compute() acc_global, acc, iu = self.compute()
return ("global correct: {:.1f}\n" "average row correct: {}\n" "IoU: {}\n" "mean IoU: {:.1f}").format( return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
acc_global.item() * 100, acc_global.item() * 100,
["{:.1f}".format(i) for i in (acc * 100).tolist()], [f"{i:.1f}" for i in (acc * 100).tolist()],
["{:.1f}".format(i) for i in (iu * 100).tolist()], [f"{i:.1f}" for i in (iu * 100).tolist()],
iu.mean().item() * 100, iu.mean().item() * 100,
) )
class MetricLogger(object): class MetricLogger:
def __init__(self, delimiter="\t"): def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue) self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter self.delimiter = delimiter
...@@ -126,12 +126,12 @@ class MetricLogger(object): ...@@ -126,12 +126,12 @@ class MetricLogger(object):
return self.meters[attr] return self.meters[attr]
if attr in self.__dict__: if attr in self.__dict__:
return self.__dict__[attr] return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self): def __str__(self):
loss_str = [] loss_str = []
for name, meter in self.meters.items(): for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter))) loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str) return self.delimiter.join(loss_str)
def synchronize_between_processes(self): def synchronize_between_processes(self):
...@@ -196,7 +196,7 @@ class MetricLogger(object): ...@@ -196,7 +196,7 @@ class MetricLogger(object):
end = time.time() end = time.time()
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {}".format(header, total_time_str)) print(f"{header} Total time: {total_time_str}")
def cat_list(images, fill_value=0): def cat_list(images, fill_value=0):
...@@ -287,7 +287,7 @@ def init_distributed_mode(args): ...@@ -287,7 +287,7 @@ def init_distributed_mode(args):
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl" args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment