OpenDAS / vision · Commits · d367a01a

Unverified commit d367a01a, authored Oct 28, 2021 by Jirka Borovec, committed by GitHub on Oct 28, 2021.

    Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)

    Co-authored-by: Nicolas Hug <nicolashug@fb.com>

parent 50dfe207

Changes: 136 files in total; this page shows 20 changed files, with 130 additions and 134 deletions (+130 -134).
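The diff below is almost entirely mechanical. For orientation, these are the rewrite patterns pyupgrade applies in this commit, collected into one small self-contained sketch (generic Python 3, not code copied from the repo):

file = "run-clang-format.py"

# 1. str.format() with simple arguments becomes an f-string.
assert "{}\t(original)".format(file) == f"{file}\t(original)"

# 2. super(Class, self) becomes the zero-argument super().
class Base:
    def __init__(self):
        self.ok = True

class Child(Base):
    def __init__(self):
        super().__init__()  # previously: super(Child, self).__init__()

# 3. class X(object) becomes class X; object is implicit in Python 3.
class X:
    pass

# 4. IOError has been an alias of OSError since Python 3.3.
assert IOError is OSError

assert Child().ok and X is not None

The 20 files on this page: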
.circleci/unittest/linux/scripts/run-clang-format.py   +13 -14
.pre-commit-config.yaml                                 +17  -8
docs/source/conf.py                                      +0  -1
gallery/plot_scripted_tensor_transforms.py               +1  -1
gallery/plot_video_api.py                                +1  -1
packaging/wheel/relocate.py                             +18 -24
references/classification/train.py                      +10 -10
references/classification/train_quantization.py          +3  -3
references/classification/transforms.py                 +14 -14
references/classification/utils.py                       +7  -8
references/detection/coco_utils.py                       +4  -4
references/detection/engine.py                           +2  -2
references/detection/group_by_aspect_ratio.py            +3  -5
references/detection/train.py                            +4  -5
references/detection/transforms.py                       +5  -5
references/detection/utils.py                            +6  -6
references/segmentation/coco_utils.py                    +2  -2
references/segmentation/train.py                         +4  -5
references/segmentation/transforms.py                    +6  -6
references/segmentation/utils.py                        +10 -10
.circleci/unittest/linux/scripts/run-clang-format.py

@@ -34,7 +34,6 @@ A diff output is produced and a sensible exit code is returned.
 import argparse
 import difflib
 import fnmatch
-import io
 import multiprocessing
 import os
 import signal
@@ -87,20 +86,20 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
 def make_diff(file, original, reformatted):
     return list(
         difflib.unified_diff(
-            original, reformatted, fromfile="{}\t(original)".format(file), tofile="{}\t(reformatted)".format(file), n=3
+            original, reformatted, fromfile=f"{file}\t(original)", tofile=f"{file}\t(reformatted)", n=3
         )
     )
 
 
 class DiffError(Exception):
     def __init__(self, message, errs=None):
-        super(DiffError, self).__init__(message)
+        super().__init__(message)
         self.errs = errs or []
 
 
 class UnexpectedError(Exception):
     def __init__(self, message, exc=None):
-        super(UnexpectedError, self).__init__(message)
+        super().__init__(message)
         self.formatted_traceback = traceback.format_exc()
         self.exc = exc
 
@@ -112,14 +111,14 @@ def run_clang_format_diff_wrapper(args, file):
     except DiffError:
         raise
     except Exception as e:
-        raise UnexpectedError("{}: {}: {}".format(file, e.__class__.__name__, e), e)
+        raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e)
 
 
 def run_clang_format_diff(args, file):
     try:
-        with io.open(file, "r", encoding="utf-8") as f:
+        with open(file, encoding="utf-8") as f:
             original = f.readlines()
-    except IOError as exc:
+    except OSError as exc:
         raise DiffError(str(exc))
     invocation = [args.clang_format_executable, file]
 
@@ -145,7 +144,7 @@ def run_clang_format_diff(args, file):
             invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8"
         )
     except OSError as exc:
-        raise DiffError("Command '{}' failed to start: {}".format(subprocess.list2cmdline(invocation), exc))
+        raise DiffError(f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}")
     proc_stdout = proc.stdout
     proc_stderr = proc.stderr
@@ -203,7 +202,7 @@ def print_trouble(prog, message, use_colors):
     error_text = "error:"
     if use_colors:
         error_text = bold_red(error_text)
-    print("{}: {} {}".format(prog, error_text, message), file=sys.stderr)
+    print(f"{prog}: {error_text} {message}", file=sys.stderr)
 
 
 def main():
@@ -216,7 +215,7 @@ def main():
     )
     parser.add_argument(
         "--extensions",
-        help="comma separated list of file extensions (default: {})".format(DEFAULT_EXTENSIONS),
+        help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})",
         default=DEFAULT_EXTENSIONS,
     )
     parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories")
@@ -227,7 +226,7 @@ def main():
         metavar="N",
         type=int,
         default=0,
-        help="run N clang-format jobs in parallel" " (default number of cpus + 1)",
+        help="run N clang-format jobs in parallel (default number of cpus + 1)",
     )
     parser.add_argument(
         "--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)"
@@ -238,7 +237,7 @@ def main():
         metavar="PATTERN",
         action="append",
         default=[],
-        help="exclude paths matching the given glob-like pattern(s)" " from recursive search",
+        help="exclude paths matching the given glob-like pattern(s) from recursive search",
     )
 
     args = parser.parse_args()
@@ -263,7 +262,7 @@ def main():
     colored_stdout = sys.stdout.isatty()
     colored_stderr = sys.stderr.isatty()
 
-    version_invocation = [args.clang_format_executable, str("--version")]
+    version_invocation = [args.clang_format_executable, "--version"]
     try:
         subprocess.check_call(version_invocation, stdout=DEVNULL)
     except subprocess.CalledProcessError as e:
@@ -272,7 +271,7 @@ def main():
     except OSError as e:
         print_trouble(
             parser.prog,
-            "Command '{}' failed to start: {}".format(subprocess.list2cmdline(version_invocation), e),
+            f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}",
             use_colors=colored_stderr,
         )
         return ExitStatus.TROUBLE
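Two rewrites in this file go beyond f-strings: io.open(file, "r", ...) becomes the builtin open(file, ...), and except IOError becomes except OSError. Both are behavior-preserving on Python 3, as this standalone check illustrates (the io.open identity is a CPython detail, but it is what makes the rewrite safe):

import io

# Since Python 3.3, IOError is kept only as an alias of OSError,
# so catching OSError covers everything IOError did.
assert IOError is OSError

# io.open is the builtin open in Python 3, and mode "r" is the default.
assert io.open is open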
.pre-commit-config.yaml

@@ -1,4 +1,20 @@
 repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.0.1
+    hooks:
+      - id: check-docstring-first
+      - id: check-toml
+      - id: check-yaml
+        exclude: packaging/.*
+      - id: end-of-file-fixer
+
+  # - repo: https://github.com/asottile/pyupgrade
+  #   rev: v2.29.0
+  #   hooks:
+  #     - id: pyupgrade
+  #       args: [--py36-plus]
+  #       name: Upgrade code
+
   - repo: https://github.com/omnilib/ufmt
     rev: v1.3.0
     hooks:
@@ -6,16 +22,9 @@ repos:
         additional_dependencies:
           - black == 21.9b0
           - usort == 0.6.4
+
   - repo: https://gitlab.com/pycqa/flake8
     rev: 3.9.2
     hooks:
       - id: flake8
         args: [--config=setup.cfg]
-
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.0.1
-    hooks:
-      - id: check-docstring-first
-      - id: check-toml
-      - id: check-yaml
-        exclude: packaging/.*
-      - id: end-of-file-fixer
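Note that the pyupgrade hook is added commented out: per the commit title it was applied once rather than enforced on every commit. A hedged sketch of what such a one-off run could look like, assuming git and pyupgrade are installed and on PATH (the flags mirror the commented hook; nothing below is repo code):

import subprocess

# Collect every tracked Python file, then let pyupgrade rewrite them in place.
files = subprocess.run(
    ["git", "ls-files", "*.py"], capture_output=True, text=True, check=True
).stdout.split()
# pyupgrade exits non-zero when it modified files, so no check=True here.
subprocess.run(["pyupgrade", "--py36-plus", *files])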
docs/source/conf.py

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 #
 # PyTorch documentation build configuration file, created by
 # sphinx-quickstart on Fri Dec 23 13:31:47 2016.
gallery/plot_scripted_tensor_transforms.py

@@ -125,7 +125,7 @@ res_scripted = scripted_predictor(batch)
 import json
 
-with open(Path('assets') / 'imagenet_class_index.json', 'r') as labels_file:
+with open(Path('assets') / 'imagenet_class_index.json') as labels_file:
     labels = json.load(labels_file)
 
 for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
gallery/plot_video_api.py

@@ -137,7 +137,7 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au
     if end < start:
         raise ValueError(
             "end time should be larger than start time, got "
-            "start time={} and end time={}".format(start, end)
+            f"start time={start} and end time={end}"
         )
 
     video_frames = torch.empty(0)
packaging/wheel/relocate.py

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 """Helper script to package wheels and relocate binaries."""
 
 import glob
@@ -157,7 +155,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
     rename and copy them into the wheel while updating their respective rpaths.
     """
 
-    print("Relocating {0}".format(binary))
+    print(f"Relocating {binary}")
     binary_path = osp.join(output_library, binary)
 
     ld_tree = lddtree(binary_path)
@@ -173,12 +171,12 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
         print(library)
         if library_info["path"] is None:
-            print("Omitting {0}".format(library))
+            print(f"Omitting {library}")
             continue
 
         if library in ALLOWLIST:
             # Omit glibc/gcc/system libraries
-            print("Omitting {0}".format(library))
+            print(f"Omitting {library}")
             continue
 
         parent_dependencies = binary_dependencies.get(parent, [])
@@ -201,7 +199,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
         if library != binary:
             library_path = binary_paths[library]
             new_library_path = patch_new_path(library_path, new_libraries_path)
-            print("{0} -> {1}".format(library, new_library_path))
+            print(f"{library} -> {new_library_path}")
             shutil.copyfile(library_path, new_library_path)
             new_names[library] = new_library_path
@@ -214,7 +212,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
         new_library_name = new_names[library]
         for dep in library_dependencies:
             new_dep = osp.basename(new_names[dep])
-            print("{0}: {1} -> {2}".format(library, dep, new_dep))
+            print(f"{library}: {dep} -> {new_dep}")
             subprocess.check_output(
                 [patchelf, "--replace-needed", dep, new_dep, new_library_name], cwd=new_libraries_path
             )
@@ -228,7 +226,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
     library_dependencies = binary_dependencies[binary]
     for dep in library_dependencies:
         new_dep = osp.basename(new_names[dep])
-        print("{0}: {1} -> {2}".format(binary, dep, new_dep))
+        print(f"{binary}: {dep} -> {new_dep}")
         subprocess.check_output([patchelf, "--replace-needed", dep, new_dep, binary], cwd=output_library)
 
     print("Update library rpath")
@@ -244,7 +242,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
     Given a shared library, find the transitive closure of its dependencies,
     rename and copy them into the wheel.
     """
-    print("Relocating {0}".format(binary))
+    print(f"Relocating {binary}")
     binary_path = osp.join(output_library, binary)
 
     library_dlls = find_dll_dependencies(dumpbin, binary_path)
@@ -255,18 +253,18 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
     while binary_queue != []:
         library, parent = binary_queue.pop(0)
         if library in WINDOWS_ALLOWLIST or library.startswith("api-ms-win"):
-            print("Omitting {0}".format(library))
+            print(f"Omitting {library}")
             continue
         library_path = find_program(library)
         if library_path is None:
-            print("{0} not found".format(library))
+            print(f"{library} not found")
             continue
 
         if osp.basename(osp.dirname(library_path)) == "system32":
             continue
-        print("{0}: {1}".format(library, library_path))
+        print(f"{library}: {library_path}")
         parent_dependencies = binary_dependencies.get(parent, [])
         parent_dependencies.append(library)
         binary_dependencies[parent] = parent_dependencies
@@ -284,7 +282,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
         if library != binary:
             library_path = binary_paths[library]
             new_library_path = osp.join(package_dir, library)
-            print("{0} -> {1}".format(library, new_library_path))
+            print(f"{library} -> {new_library_path}")
             shutil.copyfile(library_path, new_library_path)
@@ -300,16 +298,16 @@ def compress_wheel(output_dir, wheel, wheel_dir, wheel_name):
                 full_file = osp.join(root, this_file)
                 rel_file = osp.relpath(full_file, output_dir)
                 if full_file == record_file:
-                    f.write("{0},,\n".format(rel_file))
+                    f.write(f"{rel_file},,\n")
                 else:
                     digest, size = rehash(full_file)
-                    f.write("{0},{1},{2}\n".format(rel_file, digest, size))
+                    f.write(f"{rel_file},{digest},{size}\n")
 
     print("Compressing wheel")
     base_wheel_name = osp.join(wheel_dir, wheel_name)
     shutil.make_archive(base_wheel_name, "zip", output_dir)
     os.remove(wheel)
-    shutil.move("{0}.zip".format(base_wheel_name), wheel)
+    shutil.move(f"{base_wheel_name}.zip", wheel)
     shutil.rmtree(output_dir)
@@ -317,9 +315,7 @@ def patch_linux():
     # Get patchelf location
     patchelf = find_program("patchelf")
     if patchelf is None:
-        raise FileNotFoundError(
-            "Patchelf was not found in the system, please" " make sure that is available on the PATH."
-        )
+        raise FileNotFoundError("Patchelf was not found in the system, please make sure that is available on the PATH.")
 
     # Find wheel
     print("Finding wheels...")
@@ -338,7 +334,7 @@ def patch_linux():
     print("Unzipping wheel...")
     wheel_file = osp.basename(wheel)
     wheel_dir = osp.dirname(wheel)
-    print("{0}".format(wheel_file))
+    print(f"{wheel_file}")
     wheel_name, _ = osp.splitext(wheel_file)
     unzip_file(wheel, output_dir)
@@ -355,9 +351,7 @@ def patch_win():
     # Get dumpbin location
     dumpbin = find_program("dumpbin")
     if dumpbin is None:
-        raise FileNotFoundError(
-            "Dumpbin was not found in the system, please" " make sure that is available on the PATH."
-        )
+        raise FileNotFoundError("Dumpbin was not found in the system, please make sure that is available on the PATH.")
 
     # Find wheel
     print("Finding wheels...")
@@ -376,7 +370,7 @@ def patch_win():
     print("Unzipping wheel...")
     wheel_file = osp.basename(wheel)
     wheel_dir = osp.dirname(wheel)
-    print("{0}".format(wheel_file))
+    print(f"{wheel_file}")
     wheel_name, _ = osp.splitext(wheel_file)
     unzip_file(wheel, output_dir)
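A side note on the many "{0}".format(...) rewrites in this file: pyupgrade converts positional templates to f-strings mechanically, so print("{0}".format(wheel_file)) becomes print(f"{wheel_file}") rather than the even simpler print(wheel_file). All three forms print the same thing, as this sketch with a made-up value shows:

wheel_file = "torchvision-0.11.0-cp39-cp39-linux_x86_64.whl"  # hypothetical name

print("{0}".format(wheel_file))  # before
print(f"{wheel_file}")           # after pyupgrade
print(wheel_file)                # what a human might write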
references/classification/train.py

@@ -26,7 +26,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
 
-    header = "Epoch: [{}]".format(epoch)
+    header = f"Epoch: [{epoch}]"
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
@@ -121,7 +121,7 @@ def load_data(traindir, valdir, args):
     cache_path = _get_cache_path(traindir)
     if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
-        print("Loading dataset_train from {}".format(cache_path))
+        print(f"Loading dataset_train from {cache_path}")
         dataset, _ = torch.load(cache_path)
     else:
         auto_augment_policy = getattr(args, "auto_augment", None)
@@ -136,7 +136,7 @@ def load_data(traindir, valdir, args):
             ),
         )
         if args.cache_dataset:
-            print("Saving dataset_train to {}".format(cache_path))
+            print(f"Saving dataset_train to {cache_path}")
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset, traindir), cache_path)
     print("Took", time.time() - st)
@@ -145,7 +145,7 @@ def load_data(traindir, valdir, args):
     cache_path = _get_cache_path(valdir)
     if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
-        print("Loading dataset_test from {}".format(cache_path))
+        print(f"Loading dataset_test from {cache_path}")
         dataset_test, _ = torch.load(cache_path)
     else:
         if not args.weights:
@@ -162,7 +162,7 @@ def load_data(traindir, valdir, args):
             preprocessing,
         )
         if args.cache_dataset:
-            print("Saving dataset_test to {}".format(cache_path))
+            print(f"Saving dataset_test to {cache_path}")
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset_test, valdir), cache_path)
@@ -270,8 +270,8 @@ def main(args):
         main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
     else:
         raise RuntimeError(
-            "Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
-            "are supported.".format(args.lr_scheduler)
+            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
+            "are supported."
         )
 
     if args.lr_warmup_epochs > 0:
@@ -285,7 +285,7 @@ def main(args):
             )
         else:
             raise RuntimeError(
-                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant " "are supported."
+                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
             optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
@@ -351,12 +351,12 @@ def main(args):
             }
             if model_ema:
                 checkpoint["model_ema"] = model_ema.state_dict()
-            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    print("Training time {}".format(total_time_str))
+    print(f"Training time {total_time_str}")
 
 
 def get_args_parser(add_help=True):
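Worth noting in the first hunk: the fmt="{value}" arguments to utils.SmoothedValue are untouched. They are deferred templates that SmoothedValue formats later, not immediate .format() calls, so they cannot become f-strings. A standalone sketch of the distinction (generic Python, not repo code):

value = 0.000125

# Immediate call: pyupgrade can rewrite this to an f-string.
assert "lr: {}".format(value) == f"lr: {value}"

# Deferred template: stored now, formatted later with then-current data,
# so it must stay a plain format string.
fmt = "{value:.6f}"
assert fmt.format(value=value) == "0.000125"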
references/classification/train_quantization.py

@@ -20,7 +20,7 @@ def main(args):
     print(args)
 
     if args.post_training_quantize and args.distributed:
-        raise RuntimeError("Post training quantization example should not be performed " "on distributed mode")
+        raise RuntimeError("Post training quantization example should not be performed on distributed mode")
 
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
@@ -141,13 +141,13 @@ def main(args):
                 "epoch": epoch,
                 "args": args,
             }
-            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
         print("Saving models after epoch ", epoch)
 
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    print("Training time {}".format(total_time_str))
+    print(f"Training time {total_time_str}")
 
 
 def get_args_parser(add_help=True):
references/classification/transforms.py

@@ -39,13 +39,13 @@ class RandomMixup(torch.nn.Module):
             Tensor: Randomly transformed batch.
         """
         if batch.ndim != 4:
-            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
-        elif target.ndim != 1:
-            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
-        elif not batch.is_floating_point():
-            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
-        elif target.dtype != torch.int64:
-            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
+        if target.ndim != 1:
+            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
+        if not batch.is_floating_point():
+            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
+        if target.dtype != torch.int64:
+            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
 
         if not self.inplace:
             batch = batch.clone()
@@ -115,13 +115,13 @@ class RandomCutmix(torch.nn.Module):
             Tensor: Randomly transformed batch.
         """
         if batch.ndim != 4:
-            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
-        elif target.ndim != 1:
-            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
-        elif not batch.is_floating_point():
-            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
-        elif target.dtype != torch.int64:
-            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
+        if target.ndim != 1:
+            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
+        if not batch.is_floating_point():
+            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
+        if target.dtype != torch.int64:
+            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
 
         if not self.inplace:
             batch = batch.clone()
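Besides the f-strings, these hunks also flatten elif chains into plain if statements. Since every branch raises, the chaining added nothing; a minimal sketch of the equivalence (generic Python, not repo code):

def check_flat(batch_ndim, target_ndim):
    # Each branch raises, so a plain `if` behaves exactly like `elif`:
    # reaching the second test already implies the first one passed.
    if batch_ndim != 4:
        raise ValueError(f"Batch ndim should be 4. Got {batch_ndim}")
    if target_ndim != 1:
        raise ValueError(f"Target ndim should be 1. Got {target_ndim}")
    return True

assert check_flat(4, 1)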
references/classification/utils.py

@@ -10,7 +10,7 @@ import torch
 import torch.distributed as dist
 
 
-class SmoothedValue(object):
+class SmoothedValue:
     """Track a series of values and provide access to smoothed values over a
     window or the global series average.
     """
@@ -65,7 +65,7 @@ class SmoothedValue(object):
         )
 
 
-class MetricLogger(object):
+class MetricLogger:
     def __init__(self, delimiter="\t"):
         self.meters = defaultdict(SmoothedValue)
         self.delimiter = delimiter
@@ -82,12 +82,12 @@ class MetricLogger(object):
             return self.meters[attr]
         if attr in self.__dict__:
             return self.__dict__[attr]
-        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
 
     def __str__(self):
         loss_str = []
         for name, meter in self.meters.items():
-            loss_str.append("{}: {}".format(name, str(meter)))
+            loss_str.append(f"{name}: {str(meter)}")
         return self.delimiter.join(loss_str)
 
     def synchronize_between_processes(self):
@@ -152,7 +152,7 @@ class MetricLogger(object):
             end = time.time()
         total_time = time.time() - start_time
         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-        print("{} Total time: {}".format(header, total_time_str))
+        print(f"{header} Total time: {total_time_str}")
 
 
 class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
@@ -270,7 +270,7 @@ def init_distributed_mode(args):
     torch.cuda.set_device(args.gpu)
     args.dist_backend = "nccl"
-    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
+    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
     torch.distributed.init_process_group(
         backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
     )
@@ -307,8 +307,7 @@ def average_checkpoints(inputs):
             params_keys = model_params_keys
         elif params_keys != model_params_keys:
             raise KeyError(
-                "For checkpoint {}, expected list of params: {}, "
-                "but found: {}".format(f, params_keys, model_params_keys)
+                f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
             )
 
         for k in params_keys:
             p = model_params[k]
references/detection/coco_utils.py

@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
 
 
-class FilterAndRemapCocoCategories(object):
+class FilterAndRemapCocoCategories:
     def __init__(self, categories, remap=True):
         self.categories = categories
         self.remap = remap
@@ -44,7 +44,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
     return masks
 
 
-class ConvertCocoPolysToMask(object):
+class ConvertCocoPolysToMask:
     def __call__(self, image, target):
         w, h = image.size
@@ -205,11 +205,11 @@ def get_coco_api_from_dataset(dataset):
 class CocoDetection(torchvision.datasets.CocoDetection):
     def __init__(self, img_folder, ann_file, transforms):
-        super(CocoDetection, self).__init__(img_folder, ann_file)
+        super().__init__(img_folder, ann_file)
         self._transforms = transforms
 
     def __getitem__(self, idx):
-        img, target = super(CocoDetection, self).__getitem__(idx)
+        img, target = super().__getitem__(idx)
         image_id = self.ids[idx]
         target = dict(image_id=image_id, annotations=target)
         if self._transforms is not None:
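The CocoDetection hunk shows the zero-argument super() rewrite on both __init__ and __getitem__. The two spellings resolve to the same method at runtime; a generic sketch of the equivalence (not repo code):

class Base:
    def __getitem__(self, idx):
        return idx * 2

class OldStyle(Base):
    def __getitem__(self, idx):
        return super(OldStyle, self).__getitem__(idx)  # pre-rewrite spelling

class NewStyle(Base):
    def __getitem__(self, idx):
        return super().__getitem__(idx)  # zero-argument form, same MRO lookup

assert OldStyle()[3] == NewStyle()[3] == 6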
references/detection/engine.py

@@ -13,7 +13,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
-    header = "Epoch: [{}]".format(epoch)
+    header = f"Epoch: [{epoch}]"
 
     lr_scheduler = None
     if epoch == 0:
@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
         loss_value = losses_reduced.item()
 
         if not math.isfinite(loss_value):
-            print("Loss is {}, stopping training".format(loss_value))
+            print(f"Loss is {loss_value}, stopping training")
             print(loss_dict_reduced)
             sys.exit(1)
references/detection/group_by_aspect_ratio.py

@@ -36,9 +36,7 @@ class GroupedBatchSampler(BatchSampler):
     def __init__(self, sampler, group_ids, batch_size):
         if not isinstance(sampler, Sampler):
-            raise ValueError(
-                "sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
-            )
+            raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
         self.sampler = sampler
         self.group_ids = group_ids
         self.batch_size = batch_size
@@ -193,6 +191,6 @@ def create_aspect_ratio_groups(dataset, k=0):
     # count number of elements per group
     counts = np.unique(groups, return_counts=True)[1]
     fbins = [0] + bins + [np.inf]
-    print("Using {} as bins for aspect ratio quantization".format(fbins))
-    print("Count of instances per bin: {}".format(counts))
+    print(f"Using {fbins} as bins for aspect ratio quantization")
+    print(f"Count of instances per bin: {counts}")
     return groups
references/detection/train.py

@@ -65,7 +65,7 @@ def get_args_parser(add_help=True):
         "--lr",
         default=0.02,
         type=float,
-        help="initial learning rate, 0.02 is the default value for training " "on 8 gpus and 2 images_per_gpu",
+        help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
     )
     parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
     parser.add_argument(
@@ -197,8 +197,7 @@ def main(args):
         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
     else:
         raise RuntimeError(
-            "Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
-            "are supported.".format(args.lr_scheduler)
+            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
         )
 
     if args.resume:
@@ -227,7 +226,7 @@ def main(args):
                 "args": args,
                 "epoch": epoch,
             }
-            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
+            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
         # evaluate after every epoch
@@ -235,7 +234,7 @@ def main(args):
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    print("Training time {}".format(total_time_str))
+    print(f"Training time {total_time_str}")
 
 
 if __name__ == "__main__":
references/detection/transforms.py

@@ -17,7 +17,7 @@ def _flip_coco_person_keypoints(kps, width):
     return flipped_data
 
 
-class Compose(object):
+class Compose:
     def __init__(self, transforms):
         self.transforms = transforms
@@ -103,7 +103,7 @@ class RandomIoUCrop(nn.Module):
         if isinstance(image, torch.Tensor):
             if image.ndimension() not in {2, 3}:
-                raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension()))
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
             elif image.ndimension() == 2:
                 image = image.unsqueeze(0)
@@ -171,7 +171,7 @@ class RandomZoomOut(nn.Module):
         self.fill = fill
         self.side_range = side_range
         if side_range[0] < 1.0 or side_range[0] > side_range[1]:
-            raise ValueError("Invalid canvas side range provided {}.".format(side_range))
+            raise ValueError(f"Invalid canvas side range provided {side_range}.")
         self.p = p
 
     @torch.jit.unused
@@ -185,7 +185,7 @@ class RandomZoomOut(nn.Module):
     ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         if isinstance(image, torch.Tensor):
             if image.ndimension() not in {2, 3}:
-                raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension()))
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
             elif image.ndimension() == 2:
                 image = image.unsqueeze(0)
@@ -244,7 +244,7 @@ class RandomPhotometricDistort(nn.Module):
     ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         if isinstance(image, torch.Tensor):
             if image.ndimension() not in {2, 3}:
-                raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension()))
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
             elif image.ndimension() == 2:
                 image = image.unsqueeze(0)
references/detection/utils.py

@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 
 
-class SmoothedValue(object):
+class SmoothedValue:
     """Track a series of values and provide access to smoothed values over a
     window or the global series average.
     """
@@ -110,7 +110,7 @@ def reduce_dict(input_dict, average=True):
     return reduced_dict
 
 
-class MetricLogger(object):
+class MetricLogger:
     def __init__(self, delimiter="\t"):
         self.meters = defaultdict(SmoothedValue)
         self.delimiter = delimiter
@@ -127,12 +127,12 @@ class MetricLogger(object):
             return self.meters[attr]
         if attr in self.__dict__:
             return self.__dict__[attr]
-        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
 
     def __str__(self):
         loss_str = []
         for name, meter in self.meters.items():
-            loss_str.append("{}: {}".format(name, str(meter)))
+            loss_str.append(f"{name}: {str(meter)}")
         return self.delimiter.join(loss_str)
 
     def synchronize_between_processes(self):
@@ -197,7 +197,7 @@ class MetricLogger(object):
             end = time.time()
         total_time = time.time() - start_time
         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-        print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
+        print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
 
 
 def collate_fn(batch):
@@ -274,7 +274,7 @@ def init_distributed_mode(args):
     torch.cuda.set_device(args.gpu)
     args.dist_backend = "nccl"
-    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
+    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
     torch.distributed.init_process_group(
         backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
     )
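The log_every hunk above folds an arithmetic expression and a :.4f precision spec into a single f-string placeholder, which is where f-strings earn their keep over .format(). A standalone check with made-up numbers:

total_time, num_iters = 12.3456789, 10

old = "{:.4f} s / it".format(total_time / num_iters)
new = f"{total_time / num_iters:.4f} s / it"
assert old == new == "1.2346 s / it"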
references/segmentation/coco_utils.py

@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
 from transforms import Compose
 
 
-class FilterAndRemapCocoCategories(object):
+class FilterAndRemapCocoCategories:
     def __init__(self, categories, remap=True):
         self.categories = categories
         self.remap = remap
@@ -41,7 +41,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
     return masks
 
 
-class ConvertCocoPolysToMask(object):
+class ConvertCocoPolysToMask:
     def __call__(self, image, anno):
         w, h = image.size
         segmentations = [obj["segmentation"] for obj in anno]
references/segmentation/train.py

@@ -66,7 +66,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
-    header = "Epoch: [{}]".format(epoch)
+    header = f"Epoch: [{epoch}]"
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
         image, target = image.to(device), target.to(device)
         output = model(image)
@@ -152,8 +152,7 @@ def main(args):
             )
         else:
             raise RuntimeError(
-                "Invalid warmup lr method '{}'. Only linear and constant "
-                "are supported.".format(args.lr_warmup_method)
+                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
             )
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
             optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
@@ -188,12 +187,12 @@ def main(args):
             "epoch": epoch,
             "args": args,
         }
-        utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
+        utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    print("Training time {}".format(total_time_str))
+    print(f"Training time {total_time_str}")
 
 
 def get_args_parser(add_help=True):
references/segmentation/transforms.py

@@ -16,7 +16,7 @@ def pad_if_smaller(img, size, fill=0):
     return img
 
 
-class Compose(object):
+class Compose:
     def __init__(self, transforms):
         self.transforms = transforms
@@ -26,7 +26,7 @@ class Compose(object):
         return image, target
 
 
-class RandomResize(object):
+class RandomResize:
     def __init__(self, min_size, max_size=None):
         self.min_size = min_size
         if max_size is None:
@@ -40,7 +40,7 @@ class RandomResize(object):
         return image, target
 
 
-class RandomHorizontalFlip(object):
+class RandomHorizontalFlip:
     def __init__(self, flip_prob):
         self.flip_prob = flip_prob
@@ -51,7 +51,7 @@ class RandomHorizontalFlip(object):
         return image, target
 
 
-class RandomCrop(object):
+class RandomCrop:
     def __init__(self, size):
         self.size = size
@@ -64,7 +64,7 @@ class RandomCrop(object):
         return image, target
 
 
-class CenterCrop(object):
+class CenterCrop:
     def __init__(self, size):
         self.size = size
@@ -90,7 +90,7 @@ class ConvertImageDtype:
         return image, target
 
 
-class Normalize(object):
+class Normalize:
     def __init__(self, mean, std):
         self.mean = mean
         self.std = std
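This file is the pure class X(object): cleanup. In Python 3 every class is new-style with object as the implicit root, so the explicit base adds nothing, as a quick generic sketch shows:

class WithBase(object):
    pass

class WithoutBase:
    pass

# Identical situation either way: object is the implicit base class.
assert WithBase.__mro__[-1] is object
assert WithoutBase.__mro__[-1] is object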
references/segmentation/utils.py

@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 
 
-class SmoothedValue(object):
+class SmoothedValue:
     """Track a series of values and provide access to smoothed values over a
     window or the global series average.
     """
@@ -67,7 +67,7 @@ class SmoothedValue(object):
         )
 
 
-class ConfusionMatrix(object):
+class ConfusionMatrix:
     def __init__(self, num_classes):
         self.num_classes = num_classes
         self.mat = None
@@ -101,15 +101,15 @@ class ConfusionMatrix(object):
     def __str__(self):
         acc_global, acc, iu = self.compute()
-        return ("global correct: {:.1f}\n" "average row correct: {}\n" "IoU: {}\n" "mean IoU: {:.1f}").format(
+        return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
             acc_global.item() * 100,
-            ["{:.1f}".format(i) for i in (acc * 100).tolist()],
-            ["{:.1f}".format(i) for i in (iu * 100).tolist()],
+            [f"{i:.1f}" for i in (acc * 100).tolist()],
+            [f"{i:.1f}" for i in (iu * 100).tolist()],
             iu.mean().item() * 100,
         )
 
 
-class MetricLogger(object):
+class MetricLogger:
     def __init__(self, delimiter="\t"):
         self.meters = defaultdict(SmoothedValue)
         self.delimiter = delimiter
@@ -126,12 +126,12 @@ class MetricLogger(object):
             return self.meters[attr]
         if attr in self.__dict__:
             return self.__dict__[attr]
-        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
 
     def __str__(self):
         loss_str = []
         for name, meter in self.meters.items():
-            loss_str.append("{}: {}".format(name, str(meter)))
+            loss_str.append(f"{name}: {str(meter)}")
         return self.delimiter.join(loss_str)
 
     def synchronize_between_processes(self):
@@ -196,7 +196,7 @@ class MetricLogger(object):
             end = time.time()
         total_time = time.time() - start_time
         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-        print("{} Total time: {}".format(header, total_time_str))
+        print(f"{header} Total time: {total_time_str}")
 
 
 def cat_list(images, fill_value=0):
@@ -287,7 +287,7 @@ def init_distributed_mode(args):
     torch.cuda.set_device(args.gpu)
     args.dist_backend = "nccl"
-    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
+    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
     torch.distributed.init_process_group(
         backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
     )
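The ConfusionMatrix.__str__ hunk is a good example of the tool staying conservative: the outer multi-placeholder .format() call survives (only the implicitly concatenated string pieces are merged into one literal), while the simple one-value calls inside the list comprehensions do become f-strings. Roughly, in standalone Python with made-up values:

acc = [0.91666, 0.5]

# Kept as .format(): several placeholders, arguments supplied separately.
summary = "global correct: {:.1f}\nmean IoU: {:.1f}".format(91.7, 70.0)

# Rewritten: one value, one placeholder, the :.1f spec carries over.
per_class = [f"{i * 100:.1f}" for i in acc]
assert per_class == ["91.7", "50.0"]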