Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d367a01a
Unverified
Commit
d367a01a
authored
Oct 28, 2021
by
Jirka Borovec
Committed by
GitHub
Oct 28, 2021
Browse files
Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)
Co-authored-by:
Nicolas Hug
<
nicolashug@fb.com
>
parent
50dfe207
Changes
136
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
130 additions
and
134 deletions
+130
-134
.circleci/unittest/linux/scripts/run-clang-format.py
.circleci/unittest/linux/scripts/run-clang-format.py
+13
-14
.pre-commit-config.yaml
.pre-commit-config.yaml
+17
-8
docs/source/conf.py
docs/source/conf.py
+0
-1
gallery/plot_scripted_tensor_transforms.py
gallery/plot_scripted_tensor_transforms.py
+1
-1
gallery/plot_video_api.py
gallery/plot_video_api.py
+1
-1
packaging/wheel/relocate.py
packaging/wheel/relocate.py
+18
-24
references/classification/train.py
references/classification/train.py
+10
-10
references/classification/train_quantization.py
references/classification/train_quantization.py
+3
-3
references/classification/transforms.py
references/classification/transforms.py
+14
-14
references/classification/utils.py
references/classification/utils.py
+7
-8
references/detection/coco_utils.py
references/detection/coco_utils.py
+4
-4
references/detection/engine.py
references/detection/engine.py
+2
-2
references/detection/group_by_aspect_ratio.py
references/detection/group_by_aspect_ratio.py
+3
-5
references/detection/train.py
references/detection/train.py
+4
-5
references/detection/transforms.py
references/detection/transforms.py
+5
-5
references/detection/utils.py
references/detection/utils.py
+6
-6
references/segmentation/coco_utils.py
references/segmentation/coco_utils.py
+2
-2
references/segmentation/train.py
references/segmentation/train.py
+4
-5
references/segmentation/transforms.py
references/segmentation/transforms.py
+6
-6
references/segmentation/utils.py
references/segmentation/utils.py
+10
-10
No files found.
.circleci/unittest/linux/scripts/run-clang-format.py
View file @
d367a01a
...
@@ -34,7 +34,6 @@ A diff output is produced and a sensible exit code is returned.
...
@@ -34,7 +34,6 @@ A diff output is produced and a sensible exit code is returned.
import
argparse
import
argparse
import
difflib
import
difflib
import
fnmatch
import
fnmatch
import
io
import
multiprocessing
import
multiprocessing
import
os
import
os
import
signal
import
signal
...
@@ -87,20 +86,20 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
...
@@ -87,20 +86,20 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
def
make_diff
(
file
,
original
,
reformatted
):
def
make_diff
(
file
,
original
,
reformatted
):
return
list
(
return
list
(
difflib
.
unified_diff
(
difflib
.
unified_diff
(
original
,
reformatted
,
fromfile
=
"{}
\t
(original)"
.
format
(
file
)
,
tofile
=
"{}
\t
(reformatted)"
.
format
(
file
)
,
n
=
3
original
,
reformatted
,
fromfile
=
f
"
{
file
}
\t
(original)"
,
tofile
=
f
"
{
file
}
\t
(reformatted)"
,
n
=
3
)
)
)
)
class
DiffError
(
Exception
):
class
DiffError
(
Exception
):
def
__init__
(
self
,
message
,
errs
=
None
):
def
__init__
(
self
,
message
,
errs
=
None
):
super
(
DiffError
,
self
).
__init__
(
message
)
super
().
__init__
(
message
)
self
.
errs
=
errs
or
[]
self
.
errs
=
errs
or
[]
class
UnexpectedError
(
Exception
):
class
UnexpectedError
(
Exception
):
def
__init__
(
self
,
message
,
exc
=
None
):
def
__init__
(
self
,
message
,
exc
=
None
):
super
(
UnexpectedError
,
self
).
__init__
(
message
)
super
().
__init__
(
message
)
self
.
formatted_traceback
=
traceback
.
format_exc
()
self
.
formatted_traceback
=
traceback
.
format_exc
()
self
.
exc
=
exc
self
.
exc
=
exc
...
@@ -112,14 +111,14 @@ def run_clang_format_diff_wrapper(args, file):
...
@@ -112,14 +111,14 @@ def run_clang_format_diff_wrapper(args, file):
except
DiffError
:
except
DiffError
:
raise
raise
except
Exception
as
e
:
except
Exception
as
e
:
raise
UnexpectedError
(
"{}: {
}: {}"
.
format
(
file
,
e
.
__class__
.
__name__
,
e
)
,
e
)
raise
UnexpectedError
(
f
"
{
file
}
:
{
e
.
__class__
.
__name__
}
:
{
e
}
"
,
e
)
def
run_clang_format_diff
(
args
,
file
):
def
run_clang_format_diff
(
args
,
file
):
try
:
try
:
with
io
.
open
(
file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
file
,
encoding
=
"utf-8"
)
as
f
:
original
=
f
.
readlines
()
original
=
f
.
readlines
()
except
I
OError
as
exc
:
except
O
S
Error
as
exc
:
raise
DiffError
(
str
(
exc
))
raise
DiffError
(
str
(
exc
))
invocation
=
[
args
.
clang_format_executable
,
file
]
invocation
=
[
args
.
clang_format_executable
,
file
]
...
@@ -145,7 +144,7 @@ def run_clang_format_diff(args, file):
...
@@ -145,7 +144,7 @@ def run_clang_format_diff(args, file):
invocation
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
universal_newlines
=
True
,
encoding
=
"utf-8"
invocation
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
universal_newlines
=
True
,
encoding
=
"utf-8"
)
)
except
OSError
as
exc
:
except
OSError
as
exc
:
raise
DiffError
(
"Command '{
}' failed to start: {}"
.
format
(
subprocess
.
list2cmdline
(
invocation
)
,
exc
)
)
raise
DiffError
(
f
"Command '
{
subprocess
.
list2cmdline
(
invocation
)
}
' failed to start:
{
exc
}
"
)
proc_stdout
=
proc
.
stdout
proc_stdout
=
proc
.
stdout
proc_stderr
=
proc
.
stderr
proc_stderr
=
proc
.
stderr
...
@@ -203,7 +202,7 @@ def print_trouble(prog, message, use_colors):
...
@@ -203,7 +202,7 @@ def print_trouble(prog, message, use_colors):
error_text
=
"error:"
error_text
=
"error:"
if
use_colors
:
if
use_colors
:
error_text
=
bold_red
(
error_text
)
error_text
=
bold_red
(
error_text
)
print
(
"{}: {
} {}"
.
format
(
prog
,
error_text
,
message
)
,
file
=
sys
.
stderr
)
print
(
f
"
{
prog
}
:
{
error_text
}
{
message
}
"
,
file
=
sys
.
stderr
)
def
main
():
def
main
():
...
@@ -216,7 +215,7 @@ def main():
...
@@ -216,7 +215,7 @@ def main():
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--extensions"
,
"--extensions"
,
help
=
"comma separated list of file extensions (default: {
})"
.
format
(
DEFAULT_EXTENSIONS
)
,
help
=
f
"comma separated list of file extensions (default:
{
DEFAULT_EXTENSIONS
}
)"
,
default
=
DEFAULT_EXTENSIONS
,
default
=
DEFAULT_EXTENSIONS
,
)
)
parser
.
add_argument
(
"-r"
,
"--recursive"
,
action
=
"store_true"
,
help
=
"run recursively over directories"
)
parser
.
add_argument
(
"-r"
,
"--recursive"
,
action
=
"store_true"
,
help
=
"run recursively over directories"
)
...
@@ -227,7 +226,7 @@ def main():
...
@@ -227,7 +226,7 @@ def main():
metavar
=
"N"
,
metavar
=
"N"
,
type
=
int
,
type
=
int
,
default
=
0
,
default
=
0
,
help
=
"run N clang-format jobs in parallel
"
"
(default number of cpus + 1)"
,
help
=
"run N clang-format jobs in parallel (default number of cpus + 1)"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--color"
,
default
=
"auto"
,
choices
=
[
"auto"
,
"always"
,
"never"
],
help
=
"show colored diff (default: auto)"
"--color"
,
default
=
"auto"
,
choices
=
[
"auto"
,
"always"
,
"never"
],
help
=
"show colored diff (default: auto)"
...
@@ -238,7 +237,7 @@ def main():
...
@@ -238,7 +237,7 @@ def main():
metavar
=
"PATTERN"
,
metavar
=
"PATTERN"
,
action
=
"append"
,
action
=
"append"
,
default
=
[],
default
=
[],
help
=
"exclude paths matching the given glob-like pattern(s)
"
"
from recursive search"
,
help
=
"exclude paths matching the given glob-like pattern(s) from recursive search"
,
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -263,7 +262,7 @@ def main():
...
@@ -263,7 +262,7 @@ def main():
colored_stdout
=
sys
.
stdout
.
isatty
()
colored_stdout
=
sys
.
stdout
.
isatty
()
colored_stderr
=
sys
.
stderr
.
isatty
()
colored_stderr
=
sys
.
stderr
.
isatty
()
version_invocation
=
[
args
.
clang_format_executable
,
str
(
"--version"
)
]
version_invocation
=
[
args
.
clang_format_executable
,
"--version"
]
try
:
try
:
subprocess
.
check_call
(
version_invocation
,
stdout
=
DEVNULL
)
subprocess
.
check_call
(
version_invocation
,
stdout
=
DEVNULL
)
except
subprocess
.
CalledProcessError
as
e
:
except
subprocess
.
CalledProcessError
as
e
:
...
@@ -272,7 +271,7 @@ def main():
...
@@ -272,7 +271,7 @@ def main():
except
OSError
as
e
:
except
OSError
as
e
:
print_trouble
(
print_trouble
(
parser
.
prog
,
parser
.
prog
,
"Command '{
}' failed to start: {}"
.
format
(
subprocess
.
list2cmdline
(
version_invocation
)
,
e
)
,
f
"Command '
{
subprocess
.
list2cmdline
(
version_invocation
)
}
' failed to start:
{
e
}
"
,
use_colors
=
colored_stderr
,
use_colors
=
colored_stderr
,
)
)
return
ExitStatus
.
TROUBLE
return
ExitStatus
.
TROUBLE
...
...
.pre-commit-config.yaml
View file @
d367a01a
repos
:
repos
:
-
repo
:
https://github.com/pre-commit/pre-commit-hooks
rev
:
v4.0.1
hooks
:
-
id
:
check-docstring-first
-
id
:
check-toml
-
id
:
check-yaml
exclude
:
packaging/.*
-
id
:
end-of-file-fixer
# - repo: https://github.com/asottile/pyupgrade
# rev: v2.29.0
# hooks:
# - id: pyupgrade
# args: [--py36-plus]
# name: Upgrade code
-
repo
:
https://github.com/omnilib/ufmt
-
repo
:
https://github.com/omnilib/ufmt
rev
:
v1.3.0
rev
:
v1.3.0
hooks
:
hooks
:
...
@@ -6,16 +22,9 @@ repos:
...
@@ -6,16 +22,9 @@ repos:
additional_dependencies
:
additional_dependencies
:
-
black == 21.9b0
-
black == 21.9b0
-
usort == 0.6.4
-
usort == 0.6.4
-
repo
:
https://gitlab.com/pycqa/flake8
-
repo
:
https://gitlab.com/pycqa/flake8
rev
:
3.9.2
rev
:
3.9.2
hooks
:
hooks
:
-
id
:
flake8
-
id
:
flake8
args
:
[
--config=setup.cfg
]
args
:
[
--config=setup.cfg
]
-
repo
:
https://github.com/pre-commit/pre-commit-hooks
rev
:
v4.0.1
hooks
:
-
id
:
check-docstring-first
-
id
:
check-toml
-
id
:
check-yaml
exclude
:
packaging/.*
-
id
:
end-of-file-fixer
docs/source/conf.py
View file @
d367a01a
#!/usr/bin/env python3
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
#
# PyTorch documentation build configuration file, created by
# PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
...
...
gallery/plot_scripted_tensor_transforms.py
View file @
d367a01a
...
@@ -125,7 +125,7 @@ res_scripted = scripted_predictor(batch)
...
@@ -125,7 +125,7 @@ res_scripted = scripted_predictor(batch)
import
json
import
json
with
open
(
Path
(
'assets'
)
/
'imagenet_class_index.json'
,
'r'
)
as
labels_file
:
with
open
(
Path
(
'assets'
)
/
'imagenet_class_index.json'
)
as
labels_file
:
labels
=
json
.
load
(
labels_file
)
labels
=
json
.
load
(
labels_file
)
for
i
,
(
pred
,
pred_scripted
)
in
enumerate
(
zip
(
res
,
res_scripted
)):
for
i
,
(
pred
,
pred_scripted
)
in
enumerate
(
zip
(
res
,
res_scripted
)):
...
...
gallery/plot_video_api.py
View file @
d367a01a
...
@@ -137,7 +137,7 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au
...
@@ -137,7 +137,7 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au
if
end
<
start
:
if
end
<
start
:
raise
ValueError
(
raise
ValueError
(
"end time should be larger than start time, got "
"end time should be larger than start time, got "
"start time={} and end time={
}"
.
format
(
start
,
end
)
f
"start time=
{
start
}
and end time=
{
end
}
"
)
)
video_frames
=
torch
.
empty
(
0
)
video_frames
=
torch
.
empty
(
0
)
...
...
packaging/wheel/relocate.py
View file @
d367a01a
# -*- coding: utf-8 -*-
"""Helper script to package wheels and relocate binaries."""
"""Helper script to package wheels and relocate binaries."""
import
glob
import
glob
...
@@ -157,7 +155,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
...
@@ -157,7 +155,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
rename and copy them into the wheel while updating their respective rpaths.
rename and copy them into the wheel while updating their respective rpaths.
"""
"""
print
(
"Relocating {
0}"
.
format
(
binary
)
)
print
(
f
"Relocating
{
binary
}
"
)
binary_path
=
osp
.
join
(
output_library
,
binary
)
binary_path
=
osp
.
join
(
output_library
,
binary
)
ld_tree
=
lddtree
(
binary_path
)
ld_tree
=
lddtree
(
binary_path
)
...
@@ -173,12 +171,12 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
...
@@ -173,12 +171,12 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
print
(
library
)
print
(
library
)
if
library_info
[
"path"
]
is
None
:
if
library_info
[
"path"
]
is
None
:
print
(
"Omitting {
0}"
.
format
(
library
)
)
print
(
f
"Omitting
{
library
}
"
)
continue
continue
if
library
in
ALLOWLIST
:
if
library
in
ALLOWLIST
:
# Omit glibc/gcc/system libraries
# Omit glibc/gcc/system libraries
print
(
"Omitting {
0}"
.
format
(
library
)
)
print
(
f
"Omitting
{
library
}
"
)
continue
continue
parent_dependencies
=
binary_dependencies
.
get
(
parent
,
[])
parent_dependencies
=
binary_dependencies
.
get
(
parent
,
[])
...
@@ -201,7 +199,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
...
@@ -201,7 +199,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
if
library
!=
binary
:
if
library
!=
binary
:
library_path
=
binary_paths
[
library
]
library_path
=
binary_paths
[
library
]
new_library_path
=
patch_new_path
(
library_path
,
new_libraries_path
)
new_library_path
=
patch_new_path
(
library_path
,
new_libraries_path
)
print
(
"{
0} -> {1}"
.
format
(
library
,
new_library_path
)
)
print
(
f
"
{
library
}
->
{
new_library_path
}
"
)
shutil
.
copyfile
(
library_path
,
new_library_path
)
shutil
.
copyfile
(
library_path
,
new_library_path
)
new_names
[
library
]
=
new_library_path
new_names
[
library
]
=
new_library_path
...
@@ -214,7 +212,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
...
@@ -214,7 +212,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
new_library_name
=
new_names
[
library
]
new_library_name
=
new_names
[
library
]
for
dep
in
library_dependencies
:
for
dep
in
library_dependencies
:
new_dep
=
osp
.
basename
(
new_names
[
dep
])
new_dep
=
osp
.
basename
(
new_names
[
dep
])
print
(
"{
0}: {1} -> {2}"
.
format
(
library
,
dep
,
new_dep
)
)
print
(
f
"
{
library
}
:
{
dep
}
->
{
new_dep
}
"
)
subprocess
.
check_output
(
subprocess
.
check_output
(
[
patchelf
,
"--replace-needed"
,
dep
,
new_dep
,
new_library_name
],
cwd
=
new_libraries_path
[
patchelf
,
"--replace-needed"
,
dep
,
new_dep
,
new_library_name
],
cwd
=
new_libraries_path
)
)
...
@@ -228,7 +226,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
...
@@ -228,7 +226,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
library_dependencies
=
binary_dependencies
[
binary
]
library_dependencies
=
binary_dependencies
[
binary
]
for
dep
in
library_dependencies
:
for
dep
in
library_dependencies
:
new_dep
=
osp
.
basename
(
new_names
[
dep
])
new_dep
=
osp
.
basename
(
new_names
[
dep
])
print
(
"{
0}: {1} -> {2}"
.
format
(
binary
,
dep
,
new_dep
)
)
print
(
f
"
{
binary
}
:
{
dep
}
->
{
new_dep
}
"
)
subprocess
.
check_output
([
patchelf
,
"--replace-needed"
,
dep
,
new_dep
,
binary
],
cwd
=
output_library
)
subprocess
.
check_output
([
patchelf
,
"--replace-needed"
,
dep
,
new_dep
,
binary
],
cwd
=
output_library
)
print
(
"Update library rpath"
)
print
(
"Update library rpath"
)
...
@@ -244,7 +242,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
...
@@ -244,7 +242,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
Given a shared library, find the transitive closure of its dependencies,
Given a shared library, find the transitive closure of its dependencies,
rename and copy them into the wheel.
rename and copy them into the wheel.
"""
"""
print
(
"Relocating {
0}"
.
format
(
binary
)
)
print
(
f
"Relocating
{
binary
}
"
)
binary_path
=
osp
.
join
(
output_library
,
binary
)
binary_path
=
osp
.
join
(
output_library
,
binary
)
library_dlls
=
find_dll_dependencies
(
dumpbin
,
binary_path
)
library_dlls
=
find_dll_dependencies
(
dumpbin
,
binary_path
)
...
@@ -255,18 +253,18 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
...
@@ -255,18 +253,18 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
while
binary_queue
!=
[]:
while
binary_queue
!=
[]:
library
,
parent
=
binary_queue
.
pop
(
0
)
library
,
parent
=
binary_queue
.
pop
(
0
)
if
library
in
WINDOWS_ALLOWLIST
or
library
.
startswith
(
"api-ms-win"
):
if
library
in
WINDOWS_ALLOWLIST
or
library
.
startswith
(
"api-ms-win"
):
print
(
"Omitting {
0}"
.
format
(
library
)
)
print
(
f
"Omitting
{
library
}
"
)
continue
continue
library_path
=
find_program
(
library
)
library_path
=
find_program
(
library
)
if
library_path
is
None
:
if
library_path
is
None
:
print
(
"{
0
} not found"
.
format
(
library
)
)
print
(
f
"
{
library
}
not found"
)
continue
continue
if
osp
.
basename
(
osp
.
dirname
(
library_path
))
==
"system32"
:
if
osp
.
basename
(
osp
.
dirname
(
library_path
))
==
"system32"
:
continue
continue
print
(
"{
0}: {1}"
.
format
(
library
,
library_path
)
)
print
(
f
"
{
library
}
:
{
library_path
}
"
)
parent_dependencies
=
binary_dependencies
.
get
(
parent
,
[])
parent_dependencies
=
binary_dependencies
.
get
(
parent
,
[])
parent_dependencies
.
append
(
library
)
parent_dependencies
.
append
(
library
)
binary_dependencies
[
parent
]
=
parent_dependencies
binary_dependencies
[
parent
]
=
parent_dependencies
...
@@ -284,7 +282,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
...
@@ -284,7 +282,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
if
library
!=
binary
:
if
library
!=
binary
:
library_path
=
binary_paths
[
library
]
library_path
=
binary_paths
[
library
]
new_library_path
=
osp
.
join
(
package_dir
,
library
)
new_library_path
=
osp
.
join
(
package_dir
,
library
)
print
(
"{
0} -> {1}"
.
format
(
library
,
new_library_path
)
)
print
(
f
"
{
library
}
->
{
new_library_path
}
"
)
shutil
.
copyfile
(
library_path
,
new_library_path
)
shutil
.
copyfile
(
library_path
,
new_library_path
)
...
@@ -300,16 +298,16 @@ def compress_wheel(output_dir, wheel, wheel_dir, wheel_name):
...
@@ -300,16 +298,16 @@ def compress_wheel(output_dir, wheel, wheel_dir, wheel_name):
full_file
=
osp
.
join
(
root
,
this_file
)
full_file
=
osp
.
join
(
root
,
this_file
)
rel_file
=
osp
.
relpath
(
full_file
,
output_dir
)
rel_file
=
osp
.
relpath
(
full_file
,
output_dir
)
if
full_file
==
record_file
:
if
full_file
==
record_file
:
f
.
write
(
"{
0},,
\n
"
.
format
(
rel_file
)
)
f
.
write
(
f
"
{
rel_file
}
,,
\n
"
)
else
:
else
:
digest
,
size
=
rehash
(
full_file
)
digest
,
size
=
rehash
(
full_file
)
f
.
write
(
"{
0},{1},{2}
\n
"
.
format
(
rel_file
,
digest
,
size
)
)
f
.
write
(
f
"
{
rel_file
}
,
{
digest
}
,
{
size
}
\n
"
)
print
(
"Compressing wheel"
)
print
(
"Compressing wheel"
)
base_wheel_name
=
osp
.
join
(
wheel_dir
,
wheel_name
)
base_wheel_name
=
osp
.
join
(
wheel_dir
,
wheel_name
)
shutil
.
make_archive
(
base_wheel_name
,
"zip"
,
output_dir
)
shutil
.
make_archive
(
base_wheel_name
,
"zip"
,
output_dir
)
os
.
remove
(
wheel
)
os
.
remove
(
wheel
)
shutil
.
move
(
"{
0}.zip"
.
format
(
base_wheel_name
)
,
wheel
)
shutil
.
move
(
f
"
{
base_wheel_name
}
.zip"
,
wheel
)
shutil
.
rmtree
(
output_dir
)
shutil
.
rmtree
(
output_dir
)
...
@@ -317,9 +315,7 @@ def patch_linux():
...
@@ -317,9 +315,7 @@ def patch_linux():
# Get patchelf location
# Get patchelf location
patchelf
=
find_program
(
"patchelf"
)
patchelf
=
find_program
(
"patchelf"
)
if
patchelf
is
None
:
if
patchelf
is
None
:
raise
FileNotFoundError
(
raise
FileNotFoundError
(
"Patchelf was not found in the system, please make sure that is available on the PATH."
)
"Patchelf was not found in the system, please"
" make sure that is available on the PATH."
)
# Find wheel
# Find wheel
print
(
"Finding wheels..."
)
print
(
"Finding wheels..."
)
...
@@ -338,7 +334,7 @@ def patch_linux():
...
@@ -338,7 +334,7 @@ def patch_linux():
print
(
"Unzipping wheel..."
)
print
(
"Unzipping wheel..."
)
wheel_file
=
osp
.
basename
(
wheel
)
wheel_file
=
osp
.
basename
(
wheel
)
wheel_dir
=
osp
.
dirname
(
wheel
)
wheel_dir
=
osp
.
dirname
(
wheel
)
print
(
"{
0}"
.
format
(
wheel_file
)
)
print
(
f
"
{
wheel_file
}
"
)
wheel_name
,
_
=
osp
.
splitext
(
wheel_file
)
wheel_name
,
_
=
osp
.
splitext
(
wheel_file
)
unzip_file
(
wheel
,
output_dir
)
unzip_file
(
wheel
,
output_dir
)
...
@@ -355,9 +351,7 @@ def patch_win():
...
@@ -355,9 +351,7 @@ def patch_win():
# Get dumpbin location
# Get dumpbin location
dumpbin
=
find_program
(
"dumpbin"
)
dumpbin
=
find_program
(
"dumpbin"
)
if
dumpbin
is
None
:
if
dumpbin
is
None
:
raise
FileNotFoundError
(
raise
FileNotFoundError
(
"Dumpbin was not found in the system, please make sure that is available on the PATH."
)
"Dumpbin was not found in the system, please"
" make sure that is available on the PATH."
)
# Find wheel
# Find wheel
print
(
"Finding wheels..."
)
print
(
"Finding wheels..."
)
...
@@ -376,7 +370,7 @@ def patch_win():
...
@@ -376,7 +370,7 @@ def patch_win():
print
(
"Unzipping wheel..."
)
print
(
"Unzipping wheel..."
)
wheel_file
=
osp
.
basename
(
wheel
)
wheel_file
=
osp
.
basename
(
wheel
)
wheel_dir
=
osp
.
dirname
(
wheel
)
wheel_dir
=
osp
.
dirname
(
wheel
)
print
(
"{
0}"
.
format
(
wheel_file
)
)
print
(
f
"
{
wheel_file
}
"
)
wheel_name
,
_
=
osp
.
splitext
(
wheel_file
)
wheel_name
,
_
=
osp
.
splitext
(
wheel_file
)
unzip_file
(
wheel
,
output_dir
)
unzip_file
(
wheel
,
output_dir
)
...
...
references/classification/train.py
View file @
d367a01a
...
@@ -26,7 +26,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
...
@@ -26,7 +26,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
metric_logger
.
add_meter
(
"lr"
,
utils
.
SmoothedValue
(
window_size
=
1
,
fmt
=
"{value}"
))
metric_logger
.
add_meter
(
"lr"
,
utils
.
SmoothedValue
(
window_size
=
1
,
fmt
=
"{value}"
))
metric_logger
.
add_meter
(
"img/s"
,
utils
.
SmoothedValue
(
window_size
=
10
,
fmt
=
"{value}"
))
metric_logger
.
add_meter
(
"img/s"
,
utils
.
SmoothedValue
(
window_size
=
10
,
fmt
=
"{value}"
))
header
=
"Epoch: [{
}]"
.
format
(
epoch
)
header
=
f
"Epoch: [
{
epoch
}
]"
for
i
,
(
image
,
target
)
in
enumerate
(
metric_logger
.
log_every
(
data_loader
,
args
.
print_freq
,
header
)):
for
i
,
(
image
,
target
)
in
enumerate
(
metric_logger
.
log_every
(
data_loader
,
args
.
print_freq
,
header
)):
start_time
=
time
.
time
()
start_time
=
time
.
time
()
image
,
target
=
image
.
to
(
device
),
target
.
to
(
device
)
image
,
target
=
image
.
to
(
device
),
target
.
to
(
device
)
...
@@ -121,7 +121,7 @@ def load_data(traindir, valdir, args):
...
@@ -121,7 +121,7 @@ def load_data(traindir, valdir, args):
cache_path
=
_get_cache_path
(
traindir
)
cache_path
=
_get_cache_path
(
traindir
)
if
args
.
cache_dataset
and
os
.
path
.
exists
(
cache_path
):
if
args
.
cache_dataset
and
os
.
path
.
exists
(
cache_path
):
# Attention, as the transforms are also cached!
# Attention, as the transforms are also cached!
print
(
"Loading dataset_train from {
}"
.
format
(
cache_path
)
)
print
(
f
"Loading dataset_train from
{
cache_path
}
"
)
dataset
,
_
=
torch
.
load
(
cache_path
)
dataset
,
_
=
torch
.
load
(
cache_path
)
else
:
else
:
auto_augment_policy
=
getattr
(
args
,
"auto_augment"
,
None
)
auto_augment_policy
=
getattr
(
args
,
"auto_augment"
,
None
)
...
@@ -136,7 +136,7 @@ def load_data(traindir, valdir, args):
...
@@ -136,7 +136,7 @@ def load_data(traindir, valdir, args):
),
),
)
)
if
args
.
cache_dataset
:
if
args
.
cache_dataset
:
print
(
"Saving dataset_train to {
}"
.
format
(
cache_path
)
)
print
(
f
"Saving dataset_train to
{
cache_path
}
"
)
utils
.
mkdir
(
os
.
path
.
dirname
(
cache_path
))
utils
.
mkdir
(
os
.
path
.
dirname
(
cache_path
))
utils
.
save_on_master
((
dataset
,
traindir
),
cache_path
)
utils
.
save_on_master
((
dataset
,
traindir
),
cache_path
)
print
(
"Took"
,
time
.
time
()
-
st
)
print
(
"Took"
,
time
.
time
()
-
st
)
...
@@ -145,7 +145,7 @@ def load_data(traindir, valdir, args):
...
@@ -145,7 +145,7 @@ def load_data(traindir, valdir, args):
cache_path
=
_get_cache_path
(
valdir
)
cache_path
=
_get_cache_path
(
valdir
)
if
args
.
cache_dataset
and
os
.
path
.
exists
(
cache_path
):
if
args
.
cache_dataset
and
os
.
path
.
exists
(
cache_path
):
# Attention, as the transforms are also cached!
# Attention, as the transforms are also cached!
print
(
"Loading dataset_test from {
}"
.
format
(
cache_path
)
)
print
(
f
"Loading dataset_test from
{
cache_path
}
"
)
dataset_test
,
_
=
torch
.
load
(
cache_path
)
dataset_test
,
_
=
torch
.
load
(
cache_path
)
else
:
else
:
if
not
args
.
weights
:
if
not
args
.
weights
:
...
@@ -162,7 +162,7 @@ def load_data(traindir, valdir, args):
...
@@ -162,7 +162,7 @@ def load_data(traindir, valdir, args):
preprocessing
,
preprocessing
,
)
)
if
args
.
cache_dataset
:
if
args
.
cache_dataset
:
print
(
"Saving dataset_test to {
}"
.
format
(
cache_path
)
)
print
(
f
"Saving dataset_test to
{
cache_path
}
"
)
utils
.
mkdir
(
os
.
path
.
dirname
(
cache_path
))
utils
.
mkdir
(
os
.
path
.
dirname
(
cache_path
))
utils
.
save_on_master
((
dataset_test
,
valdir
),
cache_path
)
utils
.
save_on_master
((
dataset_test
,
valdir
),
cache_path
)
...
@@ -270,8 +270,8 @@ def main(args):
...
@@ -270,8 +270,8 @@ def main(args):
main_lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
ExponentialLR
(
optimizer
,
gamma
=
args
.
lr_gamma
)
main_lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
ExponentialLR
(
optimizer
,
gamma
=
args
.
lr_gamma
)
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
"Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
f
"Invalid lr scheduler '
{
args
.
lr_scheduler
}
'. Only StepLR, CosineAnnealingLR and ExponentialLR "
"are supported."
.
format
(
args
.
lr_scheduler
)
"are supported."
)
)
if
args
.
lr_warmup_epochs
>
0
:
if
args
.
lr_warmup_epochs
>
0
:
...
@@ -285,7 +285,7 @@ def main(args):
...
@@ -285,7 +285,7 @@ def main(args):
)
)
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Invalid warmup lr method '
{
args
.
lr_warmup_method
}
'. Only linear and constant
"
"
are supported."
f
"Invalid warmup lr method '
{
args
.
lr_warmup_method
}
'. Only linear and constant are supported."
)
)
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
SequentialLR
(
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
SequentialLR
(
optimizer
,
schedulers
=
[
warmup_lr_scheduler
,
main_lr_scheduler
],
milestones
=
[
args
.
lr_warmup_epochs
]
optimizer
,
schedulers
=
[
warmup_lr_scheduler
,
main_lr_scheduler
],
milestones
=
[
args
.
lr_warmup_epochs
]
...
@@ -351,12 +351,12 @@ def main(args):
...
@@ -351,12 +351,12 @@ def main(args):
}
}
if
model_ema
:
if
model_ema
:
checkpoint
[
"model_ema"
]
=
model_ema
.
state_dict
()
checkpoint
[
"model_ema"
]
=
model_ema
.
state_dict
()
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"model_{}.pth"
.
format
(
epoch
)
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
f
"model_
{
epoch
}
.pth"
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint.pth"
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint.pth"
))
total_time
=
time
.
time
()
-
start_time
total_time
=
time
.
time
()
-
start_time
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
print
(
"Training time {
}"
.
format
(
total_time_str
)
)
print
(
f
"Training time
{
total_time_str
}
"
)
def
get_args_parser
(
add_help
=
True
):
def
get_args_parser
(
add_help
=
True
):
...
...
references/classification/train_quantization.py
View file @
d367a01a
...
@@ -20,7 +20,7 @@ def main(args):
...
@@ -20,7 +20,7 @@ def main(args):
print
(
args
)
print
(
args
)
if
args
.
post_training_quantize
and
args
.
distributed
:
if
args
.
post_training_quantize
and
args
.
distributed
:
raise
RuntimeError
(
"Post training quantization example should not be performed
"
"
on distributed mode"
)
raise
RuntimeError
(
"Post training quantization example should not be performed on distributed mode"
)
# Set backend engine to ensure that quantized model runs on the correct kernels
# Set backend engine to ensure that quantized model runs on the correct kernels
if
args
.
backend
not
in
torch
.
backends
.
quantized
.
supported_engines
:
if
args
.
backend
not
in
torch
.
backends
.
quantized
.
supported_engines
:
...
@@ -141,13 +141,13 @@ def main(args):
...
@@ -141,13 +141,13 @@ def main(args):
"epoch"
:
epoch
,
"epoch"
:
epoch
,
"args"
:
args
,
"args"
:
args
,
}
}
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"model_{}.pth"
.
format
(
epoch
)
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
f
"model_
{
epoch
}
.pth"
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint.pth"
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint.pth"
))
print
(
"Saving models after epoch "
,
epoch
)
print
(
"Saving models after epoch "
,
epoch
)
total_time
=
time
.
time
()
-
start_time
total_time
=
time
.
time
()
-
start_time
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
print
(
"Training time {
}"
.
format
(
total_time_str
)
)
print
(
f
"Training time
{
total_time_str
}
"
)
def
get_args_parser
(
add_help
=
True
):
def
get_args_parser
(
add_help
=
True
):
...
...
references/classification/transforms.py
View file @
d367a01a
...
@@ -39,13 +39,13 @@ class RandomMixup(torch.nn.Module):
...
@@ -39,13 +39,13 @@ class RandomMixup(torch.nn.Module):
Tensor: Randomly transformed batch.
Tensor: Randomly transformed batch.
"""
"""
if
batch
.
ndim
!=
4
:
if
batch
.
ndim
!=
4
:
raise
ValueError
(
"Batch ndim should be 4. Got {
}"
.
format
(
batch
.
ndim
)
)
raise
ValueError
(
f
"Batch ndim should be 4. Got
{
batch
.
ndim
}
"
)
el
if
target
.
ndim
!=
1
:
if
target
.
ndim
!=
1
:
raise
ValueError
(
"Target ndim should be 1. Got {
}"
.
format
(
target
.
ndim
)
)
raise
ValueError
(
f
"Target ndim should be 1. Got
{
target
.
ndim
}
"
)
el
if
not
batch
.
is_floating_point
():
if
not
batch
.
is_floating_point
():
raise
TypeError
(
"Batch dtype should be a float tensor. Got {
}."
.
format
(
batch
.
dtype
)
)
raise
TypeError
(
f
"Batch dtype should be a float tensor. Got
{
batch
.
dtype
}
."
)
el
if
target
.
dtype
!=
torch
.
int64
:
if
target
.
dtype
!=
torch
.
int64
:
raise
TypeError
(
"Target dtype should be torch.int64. Got {
}"
.
format
(
target
.
dtype
)
)
raise
TypeError
(
f
"Target dtype should be torch.int64. Got
{
target
.
dtype
}
"
)
if
not
self
.
inplace
:
if
not
self
.
inplace
:
batch
=
batch
.
clone
()
batch
=
batch
.
clone
()
...
@@ -115,13 +115,13 @@ class RandomCutmix(torch.nn.Module):
...
@@ -115,13 +115,13 @@ class RandomCutmix(torch.nn.Module):
Tensor: Randomly transformed batch.
Tensor: Randomly transformed batch.
"""
"""
if
batch
.
ndim
!=
4
:
if
batch
.
ndim
!=
4
:
raise
ValueError
(
"Batch ndim should be 4. Got {
}"
.
format
(
batch
.
ndim
)
)
raise
ValueError
(
f
"Batch ndim should be 4. Got
{
batch
.
ndim
}
"
)
el
if
target
.
ndim
!=
1
:
if
target
.
ndim
!=
1
:
raise
ValueError
(
"Target ndim should be 1. Got {
}"
.
format
(
target
.
ndim
)
)
raise
ValueError
(
f
"Target ndim should be 1. Got
{
target
.
ndim
}
"
)
el
if
not
batch
.
is_floating_point
():
if
not
batch
.
is_floating_point
():
raise
TypeError
(
"Batch dtype should be a float tensor. Got {
}."
.
format
(
batch
.
dtype
)
)
raise
TypeError
(
f
"Batch dtype should be a float tensor. Got
{
batch
.
dtype
}
."
)
el
if
target
.
dtype
!=
torch
.
int64
:
if
target
.
dtype
!=
torch
.
int64
:
raise
TypeError
(
"Target dtype should be torch.int64. Got {
}"
.
format
(
target
.
dtype
)
)
raise
TypeError
(
f
"Target dtype should be torch.int64. Got
{
target
.
dtype
}
"
)
if
not
self
.
inplace
:
if
not
self
.
inplace
:
batch
=
batch
.
clone
()
batch
=
batch
.
clone
()
...
...
references/classification/utils.py
View file @
d367a01a
...
@@ -10,7 +10,7 @@ import torch
...
@@ -10,7 +10,7 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
class
SmoothedValue
(
object
)
:
class
SmoothedValue
:
"""Track a series of values and provide access to smoothed values over a
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
window or the global series average.
"""
"""
...
@@ -65,7 +65,7 @@ class SmoothedValue(object):
...
@@ -65,7 +65,7 @@ class SmoothedValue(object):
)
)
class
MetricLogger
(
object
)
:
class
MetricLogger
:
def
__init__
(
self
,
delimiter
=
"
\t
"
):
def
__init__
(
self
,
delimiter
=
"
\t
"
):
self
.
meters
=
defaultdict
(
SmoothedValue
)
self
.
meters
=
defaultdict
(
SmoothedValue
)
self
.
delimiter
=
delimiter
self
.
delimiter
=
delimiter
...
@@ -82,12 +82,12 @@ class MetricLogger(object):
...
@@ -82,12 +82,12 @@ class MetricLogger(object):
return
self
.
meters
[
attr
]
return
self
.
meters
[
attr
]
if
attr
in
self
.
__dict__
:
if
attr
in
self
.
__dict__
:
return
self
.
__dict__
[
attr
]
return
self
.
__dict__
[
attr
]
raise
AttributeError
(
"'{}' object has no attribute '{
}'"
.
format
(
type
(
self
).
__name__
,
attr
)
)
raise
AttributeError
(
f
"'
{
type
(
self
).
__name__
}
' object has no attribute '
{
attr
}
'"
)
def
__str__
(
self
):
def
__str__
(
self
):
loss_str
=
[]
loss_str
=
[]
for
name
,
meter
in
self
.
meters
.
items
():
for
name
,
meter
in
self
.
meters
.
items
():
loss_str
.
append
(
"{}: {
}"
.
format
(
name
,
str
(
meter
)
)
)
loss_str
.
append
(
f
"
{
name
}
:
{
str
(
meter
)
}
"
)
return
self
.
delimiter
.
join
(
loss_str
)
return
self
.
delimiter
.
join
(
loss_str
)
def
synchronize_between_processes
(
self
):
def
synchronize_between_processes
(
self
):
...
@@ -152,7 +152,7 @@ class MetricLogger(object):
...
@@ -152,7 +152,7 @@ class MetricLogger(object):
end
=
time
.
time
()
end
=
time
.
time
()
total_time
=
time
.
time
()
-
start_time
total_time
=
time
.
time
()
-
start_time
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
print
(
"{} Total time: {
}"
.
format
(
header
,
total_time_str
)
)
print
(
f
"
{
header
}
Total time:
{
total_time_str
}
"
)
class
ExponentialMovingAverage
(
torch
.
optim
.
swa_utils
.
AveragedModel
):
class
ExponentialMovingAverage
(
torch
.
optim
.
swa_utils
.
AveragedModel
):
...
@@ -270,7 +270,7 @@ def init_distributed_mode(args):
...
@@ -270,7 +270,7 @@ def init_distributed_mode(args):
torch
.
cuda
.
set_device
(
args
.
gpu
)
torch
.
cuda
.
set_device
(
args
.
gpu
)
args
.
dist_backend
=
"nccl"
args
.
dist_backend
=
"nccl"
print
(
"| distributed init (rank {
}): {}"
.
format
(
args
.
rank
,
args
.
dist_url
)
,
flush
=
True
)
print
(
f
"| distributed init (rank
{
args
.
rank
}
):
{
args
.
dist_url
}
"
,
flush
=
True
)
torch
.
distributed
.
init_process_group
(
torch
.
distributed
.
init_process_group
(
backend
=
args
.
dist_backend
,
init_method
=
args
.
dist_url
,
world_size
=
args
.
world_size
,
rank
=
args
.
rank
backend
=
args
.
dist_backend
,
init_method
=
args
.
dist_url
,
world_size
=
args
.
world_size
,
rank
=
args
.
rank
)
)
...
@@ -307,8 +307,7 @@ def average_checkpoints(inputs):
...
@@ -307,8 +307,7 @@ def average_checkpoints(inputs):
params_keys
=
model_params_keys
params_keys
=
model_params_keys
elif
params_keys
!=
model_params_keys
:
elif
params_keys
!=
model_params_keys
:
raise
KeyError
(
raise
KeyError
(
"For checkpoint {}, expected list of params: {}, "
f
"For checkpoint
{
f
}
, expected list of params:
{
params_keys
}
, but found:
{
model_params_keys
}
"
"but found: {}"
.
format
(
f
,
params_keys
,
model_params_keys
)
)
)
for
k
in
params_keys
:
for
k
in
params_keys
:
p
=
model_params
[
k
]
p
=
model_params
[
k
]
...
...
references/detection/coco_utils.py
View file @
d367a01a
...
@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
...
@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
from
pycocotools.coco
import
COCO
from
pycocotools.coco
import
COCO
class
FilterAndRemapCocoCategories
(
object
)
:
class
FilterAndRemapCocoCategories
:
def
__init__
(
self
,
categories
,
remap
=
True
):
def
__init__
(
self
,
categories
,
remap
=
True
):
self
.
categories
=
categories
self
.
categories
=
categories
self
.
remap
=
remap
self
.
remap
=
remap
...
@@ -44,7 +44,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
...
@@ -44,7 +44,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return
masks
return
masks
class
ConvertCocoPolysToMask
(
object
)
:
class
ConvertCocoPolysToMask
:
def
__call__
(
self
,
image
,
target
):
def
__call__
(
self
,
image
,
target
):
w
,
h
=
image
.
size
w
,
h
=
image
.
size
...
@@ -205,11 +205,11 @@ def get_coco_api_from_dataset(dataset):
...
@@ -205,11 +205,11 @@ def get_coco_api_from_dataset(dataset):
class
CocoDetection
(
torchvision
.
datasets
.
CocoDetection
):
class
CocoDetection
(
torchvision
.
datasets
.
CocoDetection
):
def
__init__
(
self
,
img_folder
,
ann_file
,
transforms
):
def
__init__
(
self
,
img_folder
,
ann_file
,
transforms
):
super
(
CocoDetection
,
self
).
__init__
(
img_folder
,
ann_file
)
super
().
__init__
(
img_folder
,
ann_file
)
self
.
_transforms
=
transforms
self
.
_transforms
=
transforms
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
img
,
target
=
super
(
CocoDetection
,
self
).
__getitem__
(
idx
)
img
,
target
=
super
().
__getitem__
(
idx
)
image_id
=
self
.
ids
[
idx
]
image_id
=
self
.
ids
[
idx
]
target
=
dict
(
image_id
=
image_id
,
annotations
=
target
)
target
=
dict
(
image_id
=
image_id
,
annotations
=
target
)
if
self
.
_transforms
is
not
None
:
if
self
.
_transforms
is
not
None
:
...
...
references/detection/engine.py
View file @
d367a01a
...
@@ -13,7 +13,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
...
@@ -13,7 +13,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
model
.
train
()
model
.
train
()
metric_logger
=
utils
.
MetricLogger
(
delimiter
=
" "
)
metric_logger
=
utils
.
MetricLogger
(
delimiter
=
" "
)
metric_logger
.
add_meter
(
"lr"
,
utils
.
SmoothedValue
(
window_size
=
1
,
fmt
=
"{value:.6f}"
))
metric_logger
.
add_meter
(
"lr"
,
utils
.
SmoothedValue
(
window_size
=
1
,
fmt
=
"{value:.6f}"
))
header
=
"Epoch: [{
}]"
.
format
(
epoch
)
header
=
f
"Epoch: [
{
epoch
}
]"
lr_scheduler
=
None
lr_scheduler
=
None
if
epoch
==
0
:
if
epoch
==
0
:
...
@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
...
@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
loss_value
=
losses_reduced
.
item
()
loss_value
=
losses_reduced
.
item
()
if
not
math
.
isfinite
(
loss_value
):
if
not
math
.
isfinite
(
loss_value
):
print
(
"Loss is {}, stopping training"
.
format
(
loss_value
)
)
print
(
f
"Loss is
{
loss_value
}
, stopping training"
)
print
(
loss_dict_reduced
)
print
(
loss_dict_reduced
)
sys
.
exit
(
1
)
sys
.
exit
(
1
)
...
...
references/detection/group_by_aspect_ratio.py
View file @
d367a01a
...
@@ -36,9 +36,7 @@ class GroupedBatchSampler(BatchSampler):
...
@@ -36,9 +36,7 @@ class GroupedBatchSampler(BatchSampler):
def
__init__
(
self
,
sampler
,
group_ids
,
batch_size
):
def
__init__
(
self
,
sampler
,
group_ids
,
batch_size
):
if
not
isinstance
(
sampler
,
Sampler
):
if
not
isinstance
(
sampler
,
Sampler
):
raise
ValueError
(
raise
ValueError
(
f
"sampler should be an instance of torch.utils.data.Sampler, but got sampler=
{
sampler
}
"
)
"sampler should be an instance of "
"torch.utils.data.Sampler, but got sampler={}"
.
format
(
sampler
)
)
self
.
sampler
=
sampler
self
.
sampler
=
sampler
self
.
group_ids
=
group_ids
self
.
group_ids
=
group_ids
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
...
@@ -193,6 +191,6 @@ def create_aspect_ratio_groups(dataset, k=0):
...
@@ -193,6 +191,6 @@ def create_aspect_ratio_groups(dataset, k=0):
# count number of elements per group
# count number of elements per group
counts
=
np
.
unique
(
groups
,
return_counts
=
True
)[
1
]
counts
=
np
.
unique
(
groups
,
return_counts
=
True
)[
1
]
fbins
=
[
0
]
+
bins
+
[
np
.
inf
]
fbins
=
[
0
]
+
bins
+
[
np
.
inf
]
print
(
"Using {} as bins for aspect ratio quantization"
.
format
(
fbins
)
)
print
(
f
"Using
{
fbins
}
as bins for aspect ratio quantization"
)
print
(
"Count of instances per bin: {
}"
.
format
(
counts
)
)
print
(
f
"Count of instances per bin:
{
counts
}
"
)
return
groups
return
groups
references/detection/train.py
View file @
d367a01a
...
@@ -65,7 +65,7 @@ def get_args_parser(add_help=True):
...
@@ -65,7 +65,7 @@ def get_args_parser(add_help=True):
"--lr"
,
"--lr"
,
default
=
0.02
,
default
=
0.02
,
type
=
float
,
type
=
float
,
help
=
"initial learning rate, 0.02 is the default value for training
"
"
on 8 gpus and 2 images_per_gpu"
,
help
=
"initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu"
,
)
)
parser
.
add_argument
(
"--momentum"
,
default
=
0.9
,
type
=
float
,
metavar
=
"M"
,
help
=
"momentum"
)
parser
.
add_argument
(
"--momentum"
,
default
=
0.9
,
type
=
float
,
metavar
=
"M"
,
help
=
"momentum"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -197,8 +197,7 @@ def main(args):
...
@@ -197,8 +197,7 @@ def main(args):
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
CosineAnnealingLR
(
optimizer
,
T_max
=
args
.
epochs
)
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
CosineAnnealingLR
(
optimizer
,
T_max
=
args
.
epochs
)
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
"Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
f
"Invalid lr scheduler '
{
args
.
lr_scheduler
}
'. Only MultiStepLR and CosineAnnealingLR are supported."
"are supported."
.
format
(
args
.
lr_scheduler
)
)
)
if
args
.
resume
:
if
args
.
resume
:
...
@@ -227,7 +226,7 @@ def main(args):
...
@@ -227,7 +226,7 @@ def main(args):
"args"
:
args
,
"args"
:
args
,
"epoch"
:
epoch
,
"epoch"
:
epoch
,
}
}
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"model_{}.pth"
.
format
(
epoch
)
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
f
"model_
{
epoch
}
.pth"
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint.pth"
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint.pth"
))
# evaluate after every epoch
# evaluate after every epoch
...
@@ -235,7 +234,7 @@ def main(args):
...
@@ -235,7 +234,7 @@ def main(args):
total_time
=
time
.
time
()
-
start_time
total_time
=
time
.
time
()
-
start_time
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
print
(
"Training time {
}"
.
format
(
total_time_str
)
)
print
(
f
"Training time
{
total_time_str
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
references/detection/transforms.py
View file @
d367a01a
...
@@ -17,7 +17,7 @@ def _flip_coco_person_keypoints(kps, width):
...
@@ -17,7 +17,7 @@ def _flip_coco_person_keypoints(kps, width):
return
flipped_data
return
flipped_data
class
Compose
(
object
)
:
class
Compose
:
def
__init__
(
self
,
transforms
):
def
__init__
(
self
,
transforms
):
self
.
transforms
=
transforms
self
.
transforms
=
transforms
...
@@ -103,7 +103,7 @@ class RandomIoUCrop(nn.Module):
...
@@ -103,7 +103,7 @@ class RandomIoUCrop(nn.Module):
if
isinstance
(
image
,
torch
.
Tensor
):
if
isinstance
(
image
,
torch
.
Tensor
):
if
image
.
ndimension
()
not
in
{
2
,
3
}:
if
image
.
ndimension
()
not
in
{
2
,
3
}:
raise
ValueError
(
"image should be 2/3 dimensional. Got {
}
dimension
s."
.
format
(
image
.
n
dimension
())
)
raise
ValueError
(
f
"image should be 2/3 dimensional. Got
{
image
.
n
dimension
()
}
dimension
s."
)
elif
image
.
ndimension
()
==
2
:
elif
image
.
ndimension
()
==
2
:
image
=
image
.
unsqueeze
(
0
)
image
=
image
.
unsqueeze
(
0
)
...
@@ -171,7 +171,7 @@ class RandomZoomOut(nn.Module):
...
@@ -171,7 +171,7 @@ class RandomZoomOut(nn.Module):
self
.
fill
=
fill
self
.
fill
=
fill
self
.
side_range
=
side_range
self
.
side_range
=
side_range
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
if
side_range
[
0
]
<
1.0
or
side_range
[
0
]
>
side_range
[
1
]:
raise
ValueError
(
"Invalid canvas side range provided {
}."
.
format
(
side_range
)
)
raise
ValueError
(
f
"Invalid canvas side range provided
{
side_range
}
."
)
self
.
p
=
p
self
.
p
=
p
@
torch
.
jit
.
unused
@
torch
.
jit
.
unused
...
@@ -185,7 +185,7 @@ class RandomZoomOut(nn.Module):
...
@@ -185,7 +185,7 @@ class RandomZoomOut(nn.Module):
)
->
Tuple
[
Tensor
,
Optional
[
Dict
[
str
,
Tensor
]]]:
)
->
Tuple
[
Tensor
,
Optional
[
Dict
[
str
,
Tensor
]]]:
if
isinstance
(
image
,
torch
.
Tensor
):
if
isinstance
(
image
,
torch
.
Tensor
):
if
image
.
ndimension
()
not
in
{
2
,
3
}:
if
image
.
ndimension
()
not
in
{
2
,
3
}:
raise
ValueError
(
"image should be 2/3 dimensional. Got {
}
dimension
s."
.
format
(
image
.
n
dimension
())
)
raise
ValueError
(
f
"image should be 2/3 dimensional. Got
{
image
.
n
dimension
()
}
dimension
s."
)
elif
image
.
ndimension
()
==
2
:
elif
image
.
ndimension
()
==
2
:
image
=
image
.
unsqueeze
(
0
)
image
=
image
.
unsqueeze
(
0
)
...
@@ -244,7 +244,7 @@ class RandomPhotometricDistort(nn.Module):
...
@@ -244,7 +244,7 @@ class RandomPhotometricDistort(nn.Module):
)
->
Tuple
[
Tensor
,
Optional
[
Dict
[
str
,
Tensor
]]]:
)
->
Tuple
[
Tensor
,
Optional
[
Dict
[
str
,
Tensor
]]]:
if
isinstance
(
image
,
torch
.
Tensor
):
if
isinstance
(
image
,
torch
.
Tensor
):
if
image
.
ndimension
()
not
in
{
2
,
3
}:
if
image
.
ndimension
()
not
in
{
2
,
3
}:
raise
ValueError
(
"image should be 2/3 dimensional. Got {
}
dimension
s."
.
format
(
image
.
n
dimension
())
)
raise
ValueError
(
f
"image should be 2/3 dimensional. Got
{
image
.
n
dimension
()
}
dimension
s."
)
elif
image
.
ndimension
()
==
2
:
elif
image
.
ndimension
()
==
2
:
image
=
image
.
unsqueeze
(
0
)
image
=
image
.
unsqueeze
(
0
)
...
...
references/detection/utils.py
View file @
d367a01a
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
class
SmoothedValue
(
object
)
:
class
SmoothedValue
:
"""Track a series of values and provide access to smoothed values over a
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
window or the global series average.
"""
"""
...
@@ -110,7 +110,7 @@ def reduce_dict(input_dict, average=True):
...
@@ -110,7 +110,7 @@ def reduce_dict(input_dict, average=True):
return
reduced_dict
return
reduced_dict
class
MetricLogger
(
object
)
:
class
MetricLogger
:
def
__init__
(
self
,
delimiter
=
"
\t
"
):
def
__init__
(
self
,
delimiter
=
"
\t
"
):
self
.
meters
=
defaultdict
(
SmoothedValue
)
self
.
meters
=
defaultdict
(
SmoothedValue
)
self
.
delimiter
=
delimiter
self
.
delimiter
=
delimiter
...
@@ -127,12 +127,12 @@ class MetricLogger(object):
...
@@ -127,12 +127,12 @@ class MetricLogger(object):
return
self
.
meters
[
attr
]
return
self
.
meters
[
attr
]
if
attr
in
self
.
__dict__
:
if
attr
in
self
.
__dict__
:
return
self
.
__dict__
[
attr
]
return
self
.
__dict__
[
attr
]
raise
AttributeError
(
"'{}' object has no attribute '{
}'"
.
format
(
type
(
self
).
__name__
,
attr
)
)
raise
AttributeError
(
f
"'
{
type
(
self
).
__name__
}
' object has no attribute '
{
attr
}
'"
)
def
__str__
(
self
):
def
__str__
(
self
):
loss_str
=
[]
loss_str
=
[]
for
name
,
meter
in
self
.
meters
.
items
():
for
name
,
meter
in
self
.
meters
.
items
():
loss_str
.
append
(
"{}: {
}"
.
format
(
name
,
str
(
meter
)
)
)
loss_str
.
append
(
f
"
{
name
}
:
{
str
(
meter
)
}
"
)
return
self
.
delimiter
.
join
(
loss_str
)
return
self
.
delimiter
.
join
(
loss_str
)
def
synchronize_between_processes
(
self
):
def
synchronize_between_processes
(
self
):
...
@@ -197,7 +197,7 @@ class MetricLogger(object):
...
@@ -197,7 +197,7 @@ class MetricLogger(object):
end
=
time
.
time
()
end
=
time
.
time
()
total_time
=
time
.
time
()
-
start_time
total_time
=
time
.
time
()
-
start_time
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
print
(
"{} Total time: {
} ({:.4f} s / it)"
.
format
(
header
,
total_time_str
,
total_time
/
len
(
iterable
)
)
)
print
(
f
"
{
header
}
Total time:
{
total_time_str
}
(
{
total_time
/
len
(
iterable
)
:.
4
f
}
s / it)"
)
def
collate_fn
(
batch
):
def
collate_fn
(
batch
):
...
@@ -274,7 +274,7 @@ def init_distributed_mode(args):
...
@@ -274,7 +274,7 @@ def init_distributed_mode(args):
torch
.
cuda
.
set_device
(
args
.
gpu
)
torch
.
cuda
.
set_device
(
args
.
gpu
)
args
.
dist_backend
=
"nccl"
args
.
dist_backend
=
"nccl"
print
(
"| distributed init (rank {
}): {}"
.
format
(
args
.
rank
,
args
.
dist_url
)
,
flush
=
True
)
print
(
f
"| distributed init (rank
{
args
.
rank
}
):
{
args
.
dist_url
}
"
,
flush
=
True
)
torch
.
distributed
.
init_process_group
(
torch
.
distributed
.
init_process_group
(
backend
=
args
.
dist_backend
,
init_method
=
args
.
dist_url
,
world_size
=
args
.
world_size
,
rank
=
args
.
rank
backend
=
args
.
dist_backend
,
init_method
=
args
.
dist_url
,
world_size
=
args
.
world_size
,
rank
=
args
.
rank
)
)
...
...
references/segmentation/coco_utils.py
View file @
d367a01a
...
@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
...
@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
from
transforms
import
Compose
from
transforms
import
Compose
class
FilterAndRemapCocoCategories
(
object
)
:
class
FilterAndRemapCocoCategories
:
def
__init__
(
self
,
categories
,
remap
=
True
):
def
__init__
(
self
,
categories
,
remap
=
True
):
self
.
categories
=
categories
self
.
categories
=
categories
self
.
remap
=
remap
self
.
remap
=
remap
...
@@ -41,7 +41,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
...
@@ -41,7 +41,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return
masks
return
masks
class
ConvertCocoPolysToMask
(
object
)
:
class
ConvertCocoPolysToMask
:
def
__call__
(
self
,
image
,
anno
):
def
__call__
(
self
,
image
,
anno
):
w
,
h
=
image
.
size
w
,
h
=
image
.
size
segmentations
=
[
obj
[
"segmentation"
]
for
obj
in
anno
]
segmentations
=
[
obj
[
"segmentation"
]
for
obj
in
anno
]
...
...
references/segmentation/train.py
View file @
d367a01a
...
@@ -66,7 +66,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
...
@@ -66,7 +66,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
model
.
train
()
model
.
train
()
metric_logger
=
utils
.
MetricLogger
(
delimiter
=
" "
)
metric_logger
=
utils
.
MetricLogger
(
delimiter
=
" "
)
metric_logger
.
add_meter
(
"lr"
,
utils
.
SmoothedValue
(
window_size
=
1
,
fmt
=
"{value}"
))
metric_logger
.
add_meter
(
"lr"
,
utils
.
SmoothedValue
(
window_size
=
1
,
fmt
=
"{value}"
))
header
=
"Epoch: [{
}]"
.
format
(
epoch
)
header
=
f
"Epoch: [
{
epoch
}
]"
for
image
,
target
in
metric_logger
.
log_every
(
data_loader
,
print_freq
,
header
):
for
image
,
target
in
metric_logger
.
log_every
(
data_loader
,
print_freq
,
header
):
image
,
target
=
image
.
to
(
device
),
target
.
to
(
device
)
image
,
target
=
image
.
to
(
device
),
target
.
to
(
device
)
output
=
model
(
image
)
output
=
model
(
image
)
...
@@ -152,8 +152,7 @@ def main(args):
...
@@ -152,8 +152,7 @@ def main(args):
)
)
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
"Invalid warmup lr method '{}'. Only linear and constant "
f
"Invalid warmup lr method '
{
args
.
lr_warmup_method
}
'. Only linear and constant are supported."
"are supported."
.
format
(
args
.
lr_warmup_method
)
)
)
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
SequentialLR
(
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
SequentialLR
(
optimizer
,
schedulers
=
[
warmup_lr_scheduler
,
main_lr_scheduler
],
milestones
=
[
warmup_iters
]
optimizer
,
schedulers
=
[
warmup_lr_scheduler
,
main_lr_scheduler
],
milestones
=
[
warmup_iters
]
...
@@ -188,12 +187,12 @@ def main(args):
...
@@ -188,12 +187,12 @@ def main(args):
"epoch"
:
epoch
,
"epoch"
:
epoch
,
"args"
:
args
,
"args"
:
args
,
}
}
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"model_{}.pth"
.
format
(
epoch
)
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
f
"model_
{
epoch
}
.pth"
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint.pth"
))
utils
.
save_on_master
(
checkpoint
,
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint.pth"
))
total_time
=
time
.
time
()
-
start_time
total_time
=
time
.
time
()
-
start_time
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
print
(
"Training time {
}"
.
format
(
total_time_str
)
)
print
(
f
"Training time
{
total_time_str
}
"
)
def
get_args_parser
(
add_help
=
True
):
def
get_args_parser
(
add_help
=
True
):
...
...
references/segmentation/transforms.py
View file @
d367a01a
...
@@ -16,7 +16,7 @@ def pad_if_smaller(img, size, fill=0):
...
@@ -16,7 +16,7 @@ def pad_if_smaller(img, size, fill=0):
return
img
return
img
class
Compose
(
object
)
:
class
Compose
:
def
__init__
(
self
,
transforms
):
def
__init__
(
self
,
transforms
):
self
.
transforms
=
transforms
self
.
transforms
=
transforms
...
@@ -26,7 +26,7 @@ class Compose(object):
...
@@ -26,7 +26,7 @@ class Compose(object):
return
image
,
target
return
image
,
target
class
RandomResize
(
object
)
:
class
RandomResize
:
def
__init__
(
self
,
min_size
,
max_size
=
None
):
def
__init__
(
self
,
min_size
,
max_size
=
None
):
self
.
min_size
=
min_size
self
.
min_size
=
min_size
if
max_size
is
None
:
if
max_size
is
None
:
...
@@ -40,7 +40,7 @@ class RandomResize(object):
...
@@ -40,7 +40,7 @@ class RandomResize(object):
return
image
,
target
return
image
,
target
class
RandomHorizontalFlip
(
object
)
:
class
RandomHorizontalFlip
:
def
__init__
(
self
,
flip_prob
):
def
__init__
(
self
,
flip_prob
):
self
.
flip_prob
=
flip_prob
self
.
flip_prob
=
flip_prob
...
@@ -51,7 +51,7 @@ class RandomHorizontalFlip(object):
...
@@ -51,7 +51,7 @@ class RandomHorizontalFlip(object):
return
image
,
target
return
image
,
target
class
RandomCrop
(
object
)
:
class
RandomCrop
:
def
__init__
(
self
,
size
):
def
__init__
(
self
,
size
):
self
.
size
=
size
self
.
size
=
size
...
@@ -64,7 +64,7 @@ class RandomCrop(object):
...
@@ -64,7 +64,7 @@ class RandomCrop(object):
return
image
,
target
return
image
,
target
class
CenterCrop
(
object
)
:
class
CenterCrop
:
def
__init__
(
self
,
size
):
def
__init__
(
self
,
size
):
self
.
size
=
size
self
.
size
=
size
...
@@ -90,7 +90,7 @@ class ConvertImageDtype:
...
@@ -90,7 +90,7 @@ class ConvertImageDtype:
return
image
,
target
return
image
,
target
class
Normalize
(
object
)
:
class
Normalize
:
def
__init__
(
self
,
mean
,
std
):
def
__init__
(
self
,
mean
,
std
):
self
.
mean
=
mean
self
.
mean
=
mean
self
.
std
=
std
self
.
std
=
std
...
...
references/segmentation/utils.py
View file @
d367a01a
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
class
SmoothedValue
(
object
)
:
class
SmoothedValue
:
"""Track a series of values and provide access to smoothed values over a
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
window or the global series average.
"""
"""
...
@@ -67,7 +67,7 @@ class SmoothedValue(object):
...
@@ -67,7 +67,7 @@ class SmoothedValue(object):
)
)
class
ConfusionMatrix
(
object
)
:
class
ConfusionMatrix
:
def
__init__
(
self
,
num_classes
):
def
__init__
(
self
,
num_classes
):
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
mat
=
None
self
.
mat
=
None
...
@@ -101,15 +101,15 @@ class ConfusionMatrix(object):
...
@@ -101,15 +101,15 @@ class ConfusionMatrix(object):
def
__str__
(
self
):
def
__str__
(
self
):
acc_global
,
acc
,
iu
=
self
.
compute
()
acc_global
,
acc
,
iu
=
self
.
compute
()
return
(
"global correct: {:.1f}
\n
"
"
average row correct: {}
\n
"
"
IoU: {}
\n
"
"
mean IoU: {:.1f}"
).
format
(
return
(
"global correct: {:.1f}
\n
average row correct: {}
\n
IoU: {}
\n
mean IoU: {:.1f}"
).
format
(
acc_global
.
item
()
*
100
,
acc_global
.
item
()
*
100
,
[
"{:.1f}"
.
format
(
i
)
for
i
in
(
acc
*
100
).
tolist
()],
[
f
"
{
i
:.
1
f
}
"
for
i
in
(
acc
*
100
).
tolist
()],
[
"{:.1f}"
.
format
(
i
)
for
i
in
(
iu
*
100
).
tolist
()],
[
f
"
{
i
:.
1
f
}
"
for
i
in
(
iu
*
100
).
tolist
()],
iu
.
mean
().
item
()
*
100
,
iu
.
mean
().
item
()
*
100
,
)
)
class
MetricLogger
(
object
)
:
class
MetricLogger
:
def
__init__
(
self
,
delimiter
=
"
\t
"
):
def
__init__
(
self
,
delimiter
=
"
\t
"
):
self
.
meters
=
defaultdict
(
SmoothedValue
)
self
.
meters
=
defaultdict
(
SmoothedValue
)
self
.
delimiter
=
delimiter
self
.
delimiter
=
delimiter
...
@@ -126,12 +126,12 @@ class MetricLogger(object):
...
@@ -126,12 +126,12 @@ class MetricLogger(object):
return
self
.
meters
[
attr
]
return
self
.
meters
[
attr
]
if
attr
in
self
.
__dict__
:
if
attr
in
self
.
__dict__
:
return
self
.
__dict__
[
attr
]
return
self
.
__dict__
[
attr
]
raise
AttributeError
(
"'{}' object has no attribute '{
}'"
.
format
(
type
(
self
).
__name__
,
attr
)
)
raise
AttributeError
(
f
"'
{
type
(
self
).
__name__
}
' object has no attribute '
{
attr
}
'"
)
def
__str__
(
self
):
def
__str__
(
self
):
loss_str
=
[]
loss_str
=
[]
for
name
,
meter
in
self
.
meters
.
items
():
for
name
,
meter
in
self
.
meters
.
items
():
loss_str
.
append
(
"{}: {
}"
.
format
(
name
,
str
(
meter
)
)
)
loss_str
.
append
(
f
"
{
name
}
:
{
str
(
meter
)
}
"
)
return
self
.
delimiter
.
join
(
loss_str
)
return
self
.
delimiter
.
join
(
loss_str
)
def
synchronize_between_processes
(
self
):
def
synchronize_between_processes
(
self
):
...
@@ -196,7 +196,7 @@ class MetricLogger(object):
...
@@ -196,7 +196,7 @@ class MetricLogger(object):
end
=
time
.
time
()
end
=
time
.
time
()
total_time
=
time
.
time
()
-
start_time
total_time
=
time
.
time
()
-
start_time
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
total_time_str
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
total_time
)))
print
(
"{} Total time: {
}"
.
format
(
header
,
total_time_str
)
)
print
(
f
"
{
header
}
Total time:
{
total_time_str
}
"
)
def
cat_list
(
images
,
fill_value
=
0
):
def
cat_list
(
images
,
fill_value
=
0
):
...
@@ -287,7 +287,7 @@ def init_distributed_mode(args):
...
@@ -287,7 +287,7 @@ def init_distributed_mode(args):
torch
.
cuda
.
set_device
(
args
.
gpu
)
torch
.
cuda
.
set_device
(
args
.
gpu
)
args
.
dist_backend
=
"nccl"
args
.
dist_backend
=
"nccl"
print
(
"| distributed init (rank {
}): {}"
.
format
(
args
.
rank
,
args
.
dist_url
)
,
flush
=
True
)
print
(
f
"| distributed init (rank
{
args
.
rank
}
):
{
args
.
dist_url
}
"
,
flush
=
True
)
torch
.
distributed
.
init_process_group
(
torch
.
distributed
.
init_process_group
(
backend
=
args
.
dist_backend
,
init_method
=
args
.
dist_url
,
world_size
=
args
.
world_size
,
rank
=
args
.
rank
backend
=
args
.
dist_backend
,
init_method
=
args
.
dist_url
,
world_size
=
args
.
world_size
,
rank
=
args
.
rank
)
)
...
...
Prev
1
2
3
4
5
…
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment