Unverified commit 5f0edb97, authored by Philip Meier and committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
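
For contributors who want to apply the new formatter locally, a minimal sketch (the pinned versions come from this PR; the target paths are the repository's usual top-level directories, assumed here for illustration):

    pip install ufmt "black==21.9b0" "usort==0.6.4"
    ufmt check torchvision test references    # report files that would be reformatted
    ufmt diff torchvision                     # preview the changes as a diff
    ufmt format torchvision test references   # rewrite files in place
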
......@@ -201,7 +201,7 @@ jobs:
pip install --user --progress-bar off types-requests
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off git+https://github.com/pytorch/data.git
pip install --user --progress-bar off --editable .
pip install --user --progress-bar off --no-build-isolation --editable .
mypy --config-file mypy.ini
docstring_parameters_sync:
......@@ -235,7 +235,7 @@ jobs:
command: |
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
# need to install torchvision dependencies due to transitive imports
pip install --user --progress-bar off .
pip install --user --progress-bar off --no-build-isolation .
pip install pytest
python test/test_hub.py
......@@ -248,7 +248,7 @@ jobs:
command: |
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
# need to install torchvision dependencies due to transitive imports
pip install --user --progress-bar off .
pip install --user --progress-bar off --no-build-isolation .
pip install --user onnx
pip install --user onnxruntime
pip install --user pytest
......
......@@ -14,10 +14,11 @@ See this comment for design rationale:
https://github.com/pytorch/vision/pull/1321#issuecomment-531033978
"""
import os.path
import jinja2
from jinja2 import select_autoescape
import yaml
import os.path
from jinja2 import select_autoescape
PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
......@@ -25,57 +26,66 @@ PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
RC_PATTERN = r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/"
def build_workflows(prefix='', filter_branch=None, upload=False, indentation=6, windows_latest_only=False):
def build_workflows(prefix="", filter_branch=None, upload=False, indentation=6, windows_latest_only=False):
w = []
for btype in ["wheel", "conda"]:
for os_type in ["linux", "macos", "win"]:
python_versions = PYTHON_VERSIONS
cu_versions_dict = {"linux": ["cpu", "cu102", "cu111", "cu113", "rocm4.1", "rocm4.2"],
"win": ["cpu", "cu102", "cu111", "cu113"],
"macos": ["cpu"]}
cu_versions_dict = {
"linux": ["cpu", "cu102", "cu111", "cu113", "rocm4.1", "rocm4.2"],
"win": ["cpu", "cu102", "cu111", "cu113"],
"macos": ["cpu"],
}
cu_versions = cu_versions_dict[os_type]
for python_version in python_versions:
for cu_version in cu_versions:
# ROCm conda packages not yet supported
if cu_version.startswith('rocm') and btype == "conda":
if cu_version.startswith("rocm") and btype == "conda":
continue
for unicode in [False]:
fb = filter_branch
if windows_latest_only and os_type == "win" and filter_branch is None and \
(python_version != python_versions[-1] or
(cu_version not in [cu_versions[0], cu_versions[-1]])):
if (
windows_latest_only
and os_type == "win"
and filter_branch is None
and (
python_version != python_versions[-1]
or (cu_version not in [cu_versions[0], cu_versions[-1]])
)
):
fb = "main"
if not fb and (os_type == 'linux' and
cu_version == 'cpu' and
btype == 'wheel' and
python_version == '3.7'):
if not fb and (
os_type == "linux" and cu_version == "cpu" and btype == "wheel" and python_version == "3.7"
):
# the fields must match the build_docs "requires" dependency
fb = "/.*/"
w += workflow_pair(
btype, os_type, python_version, cu_version,
unicode, prefix, upload, filter_branch=fb)
btype, os_type, python_version, cu_version, unicode, prefix, upload, filter_branch=fb
)
if not filter_branch:
# Build on every pull request, but upload only on nightly and tags
w += build_doc_job('/.*/')
w += upload_doc_job('nightly')
w += build_doc_job("/.*/")
w += upload_doc_job("nightly")
return indent(indentation, w)
def workflow_pair(btype, os_type, python_version, cu_version, unicode, prefix='', upload=False, *, filter_branch=None):
def workflow_pair(btype, os_type, python_version, cu_version, unicode, prefix="", upload=False, *, filter_branch=None):
w = []
unicode_suffix = "u" if unicode else ""
base_workflow_name = f"{prefix}binary_{os_type}_{btype}_py{python_version}{unicode_suffix}_{cu_version}"
w.append(generate_base_workflow(
base_workflow_name, python_version, cu_version,
unicode, os_type, btype, filter_branch=filter_branch))
w.append(
generate_base_workflow(
base_workflow_name, python_version, cu_version, unicode, os_type, btype, filter_branch=filter_branch
)
)
if upload:
w.append(generate_upload_workflow(base_workflow_name, os_type, btype, cu_version, filter_branch=filter_branch))
if filter_branch == 'nightly' and os_type in ['linux', 'win']:
pydistro = 'pip' if btype == 'wheel' else 'conda'
if filter_branch == "nightly" and os_type in ["linux", "win"]:
pydistro = "pip" if btype == "wheel" else "conda"
w.append(generate_smoketest_workflow(pydistro, base_workflow_name, filter_branch, python_version, os_type))
return w
......@@ -85,12 +95,13 @@ def build_doc_job(filter_branch):
job = {
"name": "build_docs",
"python_version": "3.7",
"requires": ["binary_linux_wheel_py3.7_cpu", ],
"requires": [
"binary_linux_wheel_py3.7_cpu",
],
}
if filter_branch:
job["filters"] = gen_filter_branch_tree(filter_branch,
tags_list=RC_PATTERN)
job["filters"] = gen_filter_branch_tree(filter_branch, tags_list=RC_PATTERN)
return [{"build_docs": job}]
......@@ -99,12 +110,13 @@ def upload_doc_job(filter_branch):
"name": "upload_docs",
"context": "org-member",
"python_version": "3.7",
"requires": ["build_docs", ],
"requires": [
"build_docs",
],
}
if filter_branch:
job["filters"] = gen_filter_branch_tree(filter_branch,
tags_list=RC_PATTERN)
job["filters"] = gen_filter_branch_tree(filter_branch, tags_list=RC_PATTERN)
return [{"upload_docs": job}]
......@@ -122,24 +134,25 @@ manylinux_images = {
def get_manylinux_image(cu_version):
if cu_version == "cpu":
return "pytorch/manylinux-cuda102"
elif cu_version.startswith('cu'):
cu_suffix = cu_version[len('cu'):]
elif cu_version.startswith("cu"):
cu_suffix = cu_version[len("cu") :]
return f"pytorch/manylinux-cuda{cu_suffix}"
elif cu_version.startswith('rocm'):
rocm_suffix = cu_version[len('rocm'):]
elif cu_version.startswith("rocm"):
rocm_suffix = cu_version[len("rocm") :]
return f"pytorch/manylinux-rocm:{rocm_suffix}"
def get_conda_image(cu_version):
if cu_version == "cpu":
return "pytorch/conda-builder:cpu"
elif cu_version.startswith('cu'):
cu_suffix = cu_version[len('cu'):]
elif cu_version.startswith("cu"):
cu_suffix = cu_version[len("cu") :]
return f"pytorch/conda-builder:cuda{cu_suffix}"
def generate_base_workflow(base_workflow_name, python_version, cu_version,
unicode, os_type, btype, *, filter_branch=None):
def generate_base_workflow(
base_workflow_name, python_version, cu_version, unicode, os_type, btype, *, filter_branch=None
):
d = {
"name": base_workflow_name,
......@@ -148,7 +161,7 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version,
}
if os_type != "win" and unicode:
d["unicode_abi"] = '1'
d["unicode_abi"] = "1"
if os_type != "win":
d["wheel_docker_image"] = get_manylinux_image(cu_version)
......@@ -158,14 +171,12 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version,
if filter_branch is not None:
d["filters"] = {
"branches": {
"only": filter_branch
},
"branches": {"only": filter_branch},
"tags": {
# Using a raw string here to avoid having to escape
# anything
"only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/"
}
},
}
w = f"binary_{os_type}_{btype}"
......@@ -186,19 +197,17 @@ def generate_upload_workflow(base_workflow_name, os_type, btype, cu_version, *,
"requires": [base_workflow_name],
}
if btype == 'wheel':
d["subfolder"] = "" if os_type == 'macos' else cu_version + "/"
if btype == "wheel":
d["subfolder"] = "" if os_type == "macos" else cu_version + "/"
if filter_branch is not None:
d["filters"] = {
"branches": {
"only": filter_branch
},
"branches": {"only": filter_branch},
"tags": {
# Using a raw string here to avoid having to escape
# anything
"only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/"
}
},
}
return {f"binary_{btype}_upload": d}
......@@ -223,8 +232,7 @@ def generate_smoketest_workflow(pydistro, base_workflow_name, filter_branch, pyt
def indent(indentation, data_list):
return ("\n" + " " * indentation).join(
yaml.dump(data_list, default_flow_style=False).splitlines())
return ("\n" + " " * indentation).join(yaml.dump(data_list, default_flow_style=False).splitlines())
def unittest_workflows(indentation=6):
......@@ -239,12 +247,12 @@ def unittest_workflows(indentation=6):
"python_version": python_version,
}
if device_type == 'gpu':
if device_type == "gpu":
if python_version != "3.8":
job['filters'] = gen_filter_branch_tree('main', 'nightly')
job['cu_version'] = 'cu102'
job["filters"] = gen_filter_branch_tree("main", "nightly")
job["cu_version"] = "cu102"
else:
job['cu_version'] = 'cpu'
job["cu_version"] = "cpu"
jobs.append({f"unittest_{os_type}_{device_type}": job})
......@@ -253,20 +261,17 @@ def unittest_workflows(indentation=6):
def cmake_workflows(indentation=6):
jobs = []
python_version = '3.8'
for os_type in ['linux', 'windows', 'macos']:
python_version = "3.8"
for os_type in ["linux", "windows", "macos"]:
# Skip OSX CUDA
device_types = ['cpu', 'gpu'] if os_type != 'macos' else ['cpu']
device_types = ["cpu", "gpu"] if os_type != "macos" else ["cpu"]
for device in device_types:
job = {
'name': f'cmake_{os_type}_{device}',
'python_version': python_version
}
job = {"name": f"cmake_{os_type}_{device}", "python_version": python_version}
job['cu_version'] = 'cu102' if device == 'gpu' else 'cpu'
if device == 'gpu' and os_type == 'linux':
job['wheel_docker_image'] = 'pytorch/manylinux-cuda102'
jobs.append({f'cmake_{os_type}_{device}': job})
job["cu_version"] = "cu102" if device == "gpu" else "cpu"
if device == "gpu" and os_type == "linux":
job["wheel_docker_image"] = "pytorch/manylinux-cuda102"
jobs.append({f"cmake_{os_type}_{device}": job})
return indent(indentation, jobs)
......@@ -275,27 +280,27 @@ def ios_workflows(indentation=6, nightly=False):
build_job_names = []
name_prefix = "nightly_" if nightly else ""
env_prefix = "nightly-" if nightly else ""
for arch, platform in [('x86_64', 'SIMULATOR'), ('arm64', 'OS')]:
name = f'{name_prefix}binary_libtorchvision_ops_ios_12.0.0_{arch}'
for arch, platform in [("x86_64", "SIMULATOR"), ("arm64", "OS")]:
name = f"{name_prefix}binary_libtorchvision_ops_ios_12.0.0_{arch}"
build_job_names.append(name)
build_job = {
'build_environment': f'{env_prefix}binary-libtorchvision_ops-ios-12.0.0-{arch}',
'ios_arch': arch,
'ios_platform': platform,
'name': name,
"build_environment": f"{env_prefix}binary-libtorchvision_ops-ios-12.0.0-{arch}",
"ios_arch": arch,
"ios_platform": platform,
"name": name,
}
if nightly:
build_job['filters'] = gen_filter_branch_tree('nightly')
jobs.append({'binary_ios_build': build_job})
build_job["filters"] = gen_filter_branch_tree("nightly")
jobs.append({"binary_ios_build": build_job})
if nightly:
upload_job = {
'build_environment': f'{env_prefix}binary-libtorchvision_ops-ios-12.0.0-upload',
'context': 'org-member',
'filters': gen_filter_branch_tree('nightly'),
'requires': build_job_names,
"build_environment": f"{env_prefix}binary-libtorchvision_ops-ios-12.0.0-upload",
"context": "org-member",
"filters": gen_filter_branch_tree("nightly"),
"requires": build_job_names,
}
jobs.append({'binary_ios_upload': upload_job})
jobs.append({"binary_ios_upload": upload_job})
return indent(indentation, jobs)
......@@ -305,23 +310,23 @@ def android_workflows(indentation=6, nightly=False):
name_prefix = "nightly_" if nightly else ""
env_prefix = "nightly-" if nightly else ""
name = f'{name_prefix}binary_libtorchvision_ops_android'
name = f"{name_prefix}binary_libtorchvision_ops_android"
build_job_names.append(name)
build_job = {
'build_environment': f'{env_prefix}binary-libtorchvision_ops-android',
'name': name,
"build_environment": f"{env_prefix}binary-libtorchvision_ops-android",
"name": name,
}
if nightly:
upload_job = {
'build_environment': f'{env_prefix}binary-libtorchvision_ops-android-upload',
'context': 'org-member',
'filters': gen_filter_branch_tree('nightly'),
'name': f'{name_prefix}binary_libtorchvision_ops_android_upload'
"build_environment": f"{env_prefix}binary-libtorchvision_ops-android-upload",
"context": "org-member",
"filters": gen_filter_branch_tree("nightly"),
"name": f"{name_prefix}binary_libtorchvision_ops_android_upload",
}
jobs.append({'binary_android_upload': upload_job})
jobs.append({"binary_android_upload": upload_job})
else:
jobs.append({'binary_android_build': build_job})
jobs.append({"binary_android_build": build_job})
return indent(indentation, jobs)
......@@ -330,15 +335,17 @@ if __name__ == "__main__":
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(d),
lstrip_blocks=True,
autoescape=select_autoescape(enabled_extensions=('html', 'xml')),
autoescape=select_autoescape(enabled_extensions=("html", "xml")),
keep_trailing_newline=True,
)
with open(os.path.join(d, 'config.yml'), 'w') as f:
f.write(env.get_template('config.yml.in').render(
build_workflows=build_workflows,
unittest_workflows=unittest_workflows,
cmake_workflows=cmake_workflows,
ios_workflows=ios_workflows,
android_workflows=android_workflows,
))
with open(os.path.join(d, "config.yml"), "w") as f:
f.write(
env.get_template("config.yml.in").render(
build_workflows=build_workflows,
unittest_workflows=unittest_workflows,
cmake_workflows=cmake_workflows,
ios_workflows=ios_workflows,
android_workflows=android_workflows,
)
)
......@@ -42,7 +42,6 @@ import signal
import subprocess
import sys
import traceback
from functools import partial
try:
......@@ -51,7 +50,7 @@ except ImportError:
DEVNULL = open(os.devnull, "wb")
DEFAULT_EXTENSIONS = 'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu'
DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
class ExitStatus:
......@@ -75,14 +74,8 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
# os.walk() supports trimming down the dnames list
# by modifying it in-place,
# to avoid unnecessary directory listings.
dnames[:] = [
x for x in dnames
if
not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)
]
fpaths = [
x for x in fpaths if not fnmatch.fnmatch(x, pattern)
]
dnames[:] = [x for x in dnames if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)]
fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)]
for f in fpaths:
ext = os.path.splitext(f)[1][1:]
if ext in extensions:
......@@ -95,11 +88,9 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
def make_diff(file, original, reformatted):
return list(
difflib.unified_diff(
original,
reformatted,
fromfile='{}\t(original)'.format(file),
tofile='{}\t(reformatted)'.format(file),
n=3))
original, reformatted, fromfile="{}\t(original)".format(file), tofile="{}\t(reformatted)".format(file), n=3
)
)
class DiffError(Exception):
......@@ -122,13 +113,12 @@ def run_clang_format_diff_wrapper(args, file):
except DiffError:
raise
except Exception as e:
raise UnexpectedError('{}: {}: {}'.format(file, e.__class__.__name__,
e), e)
raise UnexpectedError("{}: {}: {}".format(file, e.__class__.__name__, e), e)
def run_clang_format_diff(args, file):
try:
with io.open(file, 'r', encoding='utf-8') as f:
with io.open(file, "r", encoding="utf-8") as f:
original = f.readlines()
except IOError as exc:
raise DiffError(str(exc))
......@@ -153,17 +143,10 @@ def run_clang_format_diff(args, file):
try:
proc = subprocess.Popen(
invocation,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
encoding='utf-8')
except OSError as exc:
raise DiffError(
"Command '{}' failed to start: {}".format(
subprocess.list2cmdline(invocation), exc
)
invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8"
)
except OSError as exc:
raise DiffError("Command '{}' failed to start: {}".format(subprocess.list2cmdline(invocation), exc))
proc_stdout = proc.stdout
proc_stderr = proc.stderr
......@@ -182,30 +165,30 @@ def run_clang_format_diff(args, file):
def bold_red(s):
return '\x1b[1m\x1b[31m' + s + '\x1b[0m'
return "\x1b[1m\x1b[31m" + s + "\x1b[0m"
def colorize(diff_lines):
def bold(s):
return '\x1b[1m' + s + '\x1b[0m'
return "\x1b[1m" + s + "\x1b[0m"
def cyan(s):
return '\x1b[36m' + s + '\x1b[0m'
return "\x1b[36m" + s + "\x1b[0m"
def green(s):
return '\x1b[32m' + s + '\x1b[0m'
return "\x1b[32m" + s + "\x1b[0m"
def red(s):
return '\x1b[31m' + s + '\x1b[0m'
return "\x1b[31m" + s + "\x1b[0m"
for line in diff_lines:
if line[:4] in ['--- ', '+++ ']:
if line[:4] in ["--- ", "+++ "]:
yield bold(line)
elif line.startswith('@@ '):
elif line.startswith("@@ "):
yield cyan(line)
elif line.startswith('+'):
elif line.startswith("+"):
yield green(line)
elif line.startswith('-'):
elif line.startswith("-"):
yield red(line)
else:
yield line
......@@ -218,7 +201,7 @@ def print_diff(diff_lines, use_color):
def print_trouble(prog, message, use_colors):
error_text = 'error:'
error_text = "error:"
if use_colors:
error_text = bold_red(error_text)
print("{}: {} {}".format(prog, error_text, message), file=sys.stderr)
......@@ -227,45 +210,37 @@ def print_trouble(prog, message, use_colors):
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--clang-format-executable',
metavar='EXECUTABLE',
help='path to the clang-format executable',
default='clang-format')
parser.add_argument(
'--extensions',
help='comma separated list of file extensions (default: {})'.format(
DEFAULT_EXTENSIONS),
default=DEFAULT_EXTENSIONS)
"--clang-format-executable",
metavar="EXECUTABLE",
help="path to the clang-format executable",
default="clang-format",
)
parser.add_argument(
'-r',
'--recursive',
action='store_true',
help='run recursively over directories')
parser.add_argument('files', metavar='file', nargs='+')
"--extensions",
help="comma separated list of file extensions (default: {})".format(DEFAULT_EXTENSIONS),
default=DEFAULT_EXTENSIONS,
)
parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories")
parser.add_argument("files", metavar="file", nargs="+")
parser.add_argument("-q", "--quiet", action="store_true")
parser.add_argument(
'-q',
'--quiet',
action='store_true')
parser.add_argument(
'-j',
metavar='N',
"-j",
metavar="N",
type=int,
default=0,
help='run N clang-format jobs in parallel'
' (default number of cpus + 1)')
help="run N clang-format jobs in parallel" " (default number of cpus + 1)",
)
parser.add_argument(
'--color',
default='auto',
choices=['auto', 'always', 'never'],
help='show colored diff (default: auto)')
"--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)"
)
parser.add_argument(
'-e',
'--exclude',
metavar='PATTERN',
action='append',
"-e",
"--exclude",
metavar="PATTERN",
action="append",
default=[],
help='exclude paths matching the given glob-like pattern(s)'
' from recursive search')
help="exclude paths matching the given glob-like pattern(s)" " from recursive search",
)
args = parser.parse_args()
......@@ -282,10 +257,10 @@ def main():
colored_stdout = False
colored_stderr = False
if args.color == 'always':
if args.color == "always":
colored_stdout = True
colored_stderr = True
elif args.color == 'auto':
elif args.color == "auto":
colored_stdout = sys.stdout.isatty()
colored_stderr = sys.stderr.isatty()
......@@ -298,19 +273,15 @@ def main():
except OSError as e:
print_trouble(
parser.prog,
"Command '{}' failed to start: {}".format(
subprocess.list2cmdline(version_invocation), e
),
"Command '{}' failed to start: {}".format(subprocess.list2cmdline(version_invocation), e),
use_colors=colored_stderr,
)
return ExitStatus.TROUBLE
retcode = ExitStatus.SUCCESS
files = list_files(
args.files,
recursive=args.recursive,
exclude=args.exclude,
extensions=args.extensions.split(','))
args.files, recursive=args.recursive, exclude=args.exclude, extensions=args.extensions.split(",")
)
if not files:
return
......@@ -327,8 +298,7 @@ def main():
pool = None
else:
pool = multiprocessing.Pool(njobs)
it = pool.imap_unordered(
partial(run_clang_format_diff_wrapper, args), files)
it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files)
while True:
try:
outs, errs = next(it)
......@@ -359,5 +329,5 @@ def main():
return retcode
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())
......@@ -20,8 +20,8 @@ jobs:
with:
python-version: 3.6
- name: Upgrade pip
run: python -m pip install --upgrade pip
- name: Upgrade system packages
run: python -m pip install --upgrade pip setuptools wheel
- name: Checkout repository
uses: actions/checkout@v2
......@@ -30,7 +30,7 @@ jobs:
run: pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- name: Install torchvision
run: pip install -e .
run: pip install --no-build-isolation --editable .
- name: Install all optional dataset requirements
run: pip install scipy pandas pycocotools lmdb requests
......
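
The change from pip install -e . to pip install --no-build-isolation --editable . is deliberate: without build isolation, pip builds the C++ extensions against the torch nightly already installed in the job rather than pulling a fresh torch into an isolated build environment. A local equivalent might look like this sketch:

    # install the nightly first, then build torchvision against it
    pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
    pip install --no-build-isolation --editable .
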
repos:
- repo: https://github.com/omnilib/ufmt
rev: v1.3.0
hooks:
- id: ufmt
additional_dependencies:
- black == 21.9b0
- usort == 0.6.4
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
hooks:
......
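
With the ufmt hook registered above, the usual pre-commit workflow applies (a sketch, assuming pre-commit is not yet installed):

    pip install pre-commit
    pre-commit install                # run ufmt and flake8 on every commit
    pre-commit run ufmt --all-files   # one-off formatting pass over the whole repo
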
......@@ -5,11 +5,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
print(torch.__version__)
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
pretrained=True,
box_score_thresh=0.7,
rpn_post_nms_top_n_test=100,
rpn_score_thresh=0.4,
rpn_pre_nms_top_n_test=150)
pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150
)
model.eval()
script_model = torch.jit.script(model)
......
......@@ -21,8 +21,8 @@
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import torchvision
import pytorch_sphinx_theme
import torchvision
# -- General configuration ------------------------------------------------
......@@ -33,24 +33,24 @@ import pytorch_sphinx_theme
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.duration',
'sphinx_gallery.gen_gallery',
'sphinx_copybutton',
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.duration",
"sphinx_gallery.gen_gallery",
"sphinx_copybutton",
]
sphinx_gallery_conf = {
'examples_dirs': '../../gallery/', # path to your example scripts
'gallery_dirs': 'auto_examples', # path to where to save gallery generated output
'backreferences_dir': 'gen_modules/backreferences',
'doc_module': ('torchvision',),
"examples_dirs": "../../gallery/", # path to your example scripts
"gallery_dirs": "auto_examples", # path to where to save gallery generated output
"backreferences_dir": "gen_modules/backreferences",
"doc_module": ("torchvision",),
}
napoleon_use_ivar = True
......@@ -59,22 +59,22 @@ napoleon_google_docstring = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = {
'.rst': 'restructuredtext',
".rst": "restructuredtext",
}
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# General information about the project.
project = 'Torchvision'
copyright = '2017-present, Torch Contributors'
author = 'Torch Contributors'
project = "Torchvision"
copyright = "2017-present, Torch Contributors"
author = "Torch Contributors"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
......@@ -82,10 +82,10 @@ author = 'Torch Contributors'
#
# The short X.Y version.
# TODO: change to [:2] at v1.0
version = 'main (' + torchvision.__version__ + ' )'
version = "main (" + torchvision.__version__ + " )"
# The full version, including alpha/beta/rc tags.
# TODO: verify this works as expected
release = 'main'
release = "main"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
......@@ -100,7 +100,7 @@ language = None
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
......@@ -111,7 +111,7 @@ todo_include_todos = True
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'pytorch_sphinx_theme'
html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme
......@@ -119,30 +119,30 @@ html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# documentation.
#
html_theme_options = {
'collapse_navigation': False,
'display_version': True,
'logo_only': True,
'pytorch_project': 'docs',
'navigation_with_keys': True,
'analytics_id': 'UA-117752657-2',
"collapse_navigation": False,
"display_version": True,
"logo_only": True,
"pytorch_project": "docs",
"navigation_with_keys": True,
"analytics_id": "UA-117752657-2",
}
html_logo = '_static/img/pytorch-logo-dark.svg'
html_logo = "_static/img/pytorch-logo-dark.svg"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]
# TODO: remove this once https://github.com/pytorch/pytorch_sphinx_theme/issues/125 is fixed
html_css_files = [
'css/custom_torchvision.css',
"css/custom_torchvision.css",
]
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc'
htmlhelp_basename = "PyTorchdoc"
# -- Options for LaTeX output ---------------------------------------------
......@@ -150,15 +150,12 @@ latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
......@@ -169,8 +166,7 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'pytorch.tex', 'torchvision Documentation',
'Torch Contributors', 'manual'),
(master_doc, "pytorch.tex", "torchvision Documentation", "Torch Contributors", "manual"),
]
......@@ -178,10 +174,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'torchvision', 'torchvision Documentation',
[author], 1)
]
man_pages = [(master_doc, "torchvision", "torchvision Documentation", [author], 1)]
# -- Options for Texinfo output -------------------------------------------
......@@ -190,27 +183,33 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'torchvision', 'torchvision Documentation',
author, 'torchvision', 'One line description of project.',
'Miscellaneous'),
(
master_doc,
"torchvision",
"torchvision Documentation",
author,
"torchvision",
"One line description of project.",
"Miscellaneous",
),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/', None),
'torch': ('https://pytorch.org/docs/stable/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'PIL': ('https://pillow.readthedocs.io/en/stable/', None),
'matplotlib': ('https://matplotlib.org/stable/', None),
"python": ("https://docs.python.org/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
}
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043
from docutils import nodes
from sphinx.util.docfields import TypedField
from sphinx import addnodes
from sphinx.util.docfields import TypedField
def patched_make_field(self, types, domain, items, **kw):
......@@ -220,40 +219,39 @@ def patched_make_field(self, types, domain, items, **kw):
# type: (list, unicode, tuple) -> nodes.field # noqa: F821
def handle_item(fieldarg, content):
par = nodes.paragraph()
par += addnodes.literal_strong('', fieldarg) # Patch: this line added
par += addnodes.literal_strong("", fieldarg) # Patch: this line added
# par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
# addnodes.literal_strong))
if fieldarg in types:
par += nodes.Text(' (')
par += nodes.Text(" (")
# NOTE: using .pop() here to prevent a single type node to be
# inserted twice into the doctree, which leads to
# inconsistencies later when references are resolved
fieldtype = types.pop(fieldarg)
if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
typename = u''.join(n.astext() for n in fieldtype)
typename = typename.replace('int', 'python:int')
typename = typename.replace('long', 'python:long')
typename = typename.replace('float', 'python:float')
typename = typename.replace('type', 'python:type')
par.extend(self.make_xrefs(self.typerolename, domain, typename,
addnodes.literal_emphasis, **kw))
typename = "".join(n.astext() for n in fieldtype)
typename = typename.replace("int", "python:int")
typename = typename.replace("long", "python:long")
typename = typename.replace("float", "python:float")
typename = typename.replace("type", "python:type")
par.extend(self.make_xrefs(self.typerolename, domain, typename, addnodes.literal_emphasis, **kw))
else:
par += fieldtype
par += nodes.Text(')')
par += nodes.Text(' -- ')
par += nodes.Text(")")
par += nodes.Text(" -- ")
par += content
return par
fieldname = nodes.field_name('', self.label)
fieldname = nodes.field_name("", self.label)
if len(items) == 1 and self.can_collapse:
fieldarg, content = items[0]
bodynode = handle_item(fieldarg, content)
else:
bodynode = self.list_type()
for fieldarg, content in items:
bodynode += nodes.list_item('', handle_item(fieldarg, content))
fieldbody = nodes.field_body('', bodynode)
return nodes.field('', fieldname, fieldbody)
bodynode += nodes.list_item("", handle_item(fieldarg, content))
fieldbody = nodes.field_body("", bodynode)
return nodes.field("", fieldname, fieldbody)
TypedField.make_field = patched_make_field
......@@ -286,4 +284,4 @@ def inject_minigalleries(app, what, name, obj, options, lines):
def setup(app):
app.connect('autodoc-process-docstring', inject_minigalleries)
app.connect("autodoc-process-docstring", inject_minigalleries)
# Optional list of dependencies required by the package
dependencies = ['torch']
dependencies = ["torch"]
# classification
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.efficientnet import (
efficientnet_b0,
efficientnet_b1,
efficientnet_b2,
efficientnet_b3,
efficientnet_b4,
efficientnet_b5,
efficientnet_b6,
efficientnet_b7,
)
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.inception import inception_v3
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3
from torchvision.models.mobilenetv2 import mobilenet_v2
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
mnasnet1_3
from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \
efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
from torchvision.models.regnet import regnet_y_400mf, regnet_y_800mf, \
regnet_y_1_6gf, regnet_y_3_2gf, regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, \
regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, regnet_x_3_2gf, regnet_x_8gf, \
regnet_x_16gf, regnet_x_32gf
from torchvision.models.regnet import (
regnet_y_400mf,
regnet_y_800mf,
regnet_y_1_6gf,
regnet_y_3_2gf,
regnet_y_8gf,
regnet_y_16gf,
regnet_y_32gf,
regnet_x_400mf,
regnet_x_800mf,
regnet_x_1_6gf,
regnet_x_3_2gf,
regnet_x_8gf,
regnet_x_16gf,
regnet_x_32gf,
)
from torchvision.models.resnet import (
resnet18,
resnet34,
resnet50,
resnet101,
resnet152,
resnext50_32x4d,
resnext101_32x8d,
wide_resnet50_2,
wide_resnet101_2,
)
# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large
from torchvision.models.segmentation import (
fcn_resnet50,
fcn_resnet101,
deeplabv3_resnet50,
deeplabv3_resnet101,
deeplabv3_mobilenet_v3_large,
lraspp_mobilenet_v3_large,
)
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
......@@ -2,46 +2,63 @@
"""Helper script to package wheels and relocate binaries."""
# Standard library imports
import os
import io
import sys
import glob
import shutil
import zipfile
import hashlib
import io
# Standard library imports
import os
import os.path as osp
import platform
import shutil
import subprocess
import os.path as osp
import sys
import zipfile
from base64 import urlsafe_b64encode
# Third party imports
if sys.platform == 'linux':
if sys.platform == "linux":
from auditwheel.lddtree import lddtree
from wheel.bdist_wheel import get_abi_tag
ALLOWLIST = {
'libgcc_s.so.1', 'libstdc++.so.6', 'libm.so.6',
'libdl.so.2', 'librt.so.1', 'libc.so.6',
'libnsl.so.1', 'libutil.so.1', 'libpthread.so.0',
'libresolv.so.2', 'libX11.so.6', 'libXext.so.6',
'libXrender.so.1', 'libICE.so.6', 'libSM.so.6',
'libGL.so.1', 'libgobject-2.0.so.0', 'libgthread-2.0.so.0',
'libglib-2.0.so.0', 'ld-linux-x86-64.so.2', 'ld-2.17.so'
"libgcc_s.so.1",
"libstdc++.so.6",
"libm.so.6",
"libdl.so.2",
"librt.so.1",
"libc.so.6",
"libnsl.so.1",
"libutil.so.1",
"libpthread.so.0",
"libresolv.so.2",
"libX11.so.6",
"libXext.so.6",
"libXrender.so.1",
"libICE.so.6",
"libSM.so.6",
"libGL.so.1",
"libgobject-2.0.so.0",
"libgthread-2.0.so.0",
"libglib-2.0.so.0",
"ld-linux-x86-64.so.2",
"ld-2.17.so",
}
WINDOWS_ALLOWLIST = {
'MSVCP140.dll', 'KERNEL32.dll',
'VCRUNTIME140_1.dll', 'VCRUNTIME140.dll',
'api-ms-win-crt-heap-l1-1-0.dll',
'api-ms-win-crt-runtime-l1-1-0.dll',
'api-ms-win-crt-stdio-l1-1-0.dll',
'api-ms-win-crt-filesystem-l1-1-0.dll',
'api-ms-win-crt-string-l1-1-0.dll',
'api-ms-win-crt-environment-l1-1-0.dll',
'api-ms-win-crt-math-l1-1-0.dll',
'api-ms-win-crt-convert-l1-1-0.dll'
"MSVCP140.dll",
"KERNEL32.dll",
"VCRUNTIME140_1.dll",
"VCRUNTIME140.dll",
"api-ms-win-crt-heap-l1-1-0.dll",
"api-ms-win-crt-runtime-l1-1-0.dll",
"api-ms-win-crt-stdio-l1-1-0.dll",
"api-ms-win-crt-filesystem-l1-1-0.dll",
"api-ms-win-crt-string-l1-1-0.dll",
"api-ms-win-crt-environment-l1-1-0.dll",
"api-ms-win-crt-math-l1-1-0.dll",
"api-ms-win-crt-convert-l1-1-0.dll",
}
......@@ -64,20 +81,18 @@ def rehash(path, blocksize=1 << 20):
"""Return (hash, length) for path using hashlib.sha256()"""
h = hashlib.sha256()
length = 0
with open(path, 'rb') as f:
with open(path, "rb") as f:
for block in read_chunks(f, size=blocksize):
length += len(block)
h.update(block)
digest = 'sha256=' + urlsafe_b64encode(
h.digest()
).decode('latin1').rstrip('=')
digest = "sha256=" + urlsafe_b64encode(h.digest()).decode("latin1").rstrip("=")
# unicode/str python2 issues
return (digest, str(length)) # type: ignore
def unzip_file(file, dest):
"""Decompress zip `file` into directory `dest`."""
with zipfile.ZipFile(file, 'r') as zip_ref:
with zipfile.ZipFile(file, "r") as zip_ref:
zip_ref.extractall(dest)
......@@ -88,8 +103,7 @@ def is_program_installed(basename):
On macOS systems, a .app is considered installed if
it exists.
"""
if (sys.platform == 'darwin' and basename.endswith('.app') and
osp.exists(basename)):
if sys.platform == "darwin" and basename.endswith(".app") and osp.exists(basename):
return basename
for path in os.environ["PATH"].split(os.pathsep):
......@@ -105,9 +119,9 @@ def find_program(basename):
(return None if not found)
"""
names = [basename]
if os.name == 'nt':
if os.name == "nt":
# Windows platforms
extensions = ('.exe', '.bat', '.cmd', '.dll')
extensions = (".exe", ".bat", ".cmd", ".dll")
if not basename.endswith(extensions):
names = [basename + ext for ext in extensions] + [basename]
for name in names:
......@@ -118,19 +132,18 @@ def find_program(basename):
def patch_new_path(library_path, new_dir):
library = osp.basename(library_path)
name, *rest = library.split('.')
rest = '.'.join(rest)
hash_id = hashlib.sha256(library_path.encode('utf-8')).hexdigest()[:8]
new_name = '.'.join([name, hash_id, rest])
name, *rest = library.split(".")
rest = ".".join(rest)
hash_id = hashlib.sha256(library_path.encode("utf-8")).hexdigest()[:8]
new_name = ".".join([name, hash_id, rest])
return osp.join(new_dir, new_name)
def find_dll_dependencies(dumpbin, binary):
out = subprocess.run([dumpbin, "/dependents", binary],
stdout=subprocess.PIPE)
out = out.stdout.strip().decode('utf-8')
start_index = out.find('dependencies:') + len('dependencies:')
end_index = out.find('Summary')
out = subprocess.run([dumpbin, "/dependents", binary], stdout=subprocess.PIPE)
out = out.stdout.strip().decode("utf-8")
start_index = out.find("dependencies:") + len("dependencies:")
end_index = out.find("Summary")
dlls = out[start_index:end_index].strip()
dlls = dlls.split(os.linesep)
dlls = [dll.strip() for dll in dlls]
......@@ -145,13 +158,13 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
rename and copy them into the wheel while updating their respective rpaths.
"""
print('Relocating {0}'.format(binary))
print("Relocating {0}".format(binary))
binary_path = osp.join(output_library, binary)
ld_tree = lddtree(binary_path)
tree_libs = ld_tree['libs']
tree_libs = ld_tree["libs"]
binary_queue = [(n, binary) for n in ld_tree['needed']]
binary_queue = [(n, binary) for n in ld_tree["needed"]]
binary_paths = {binary: binary_path}
binary_dependencies = {}
......@@ -160,13 +173,13 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
library_info = tree_libs[library]
print(library)
if library_info['path'] is None:
print('Omitting {0}'.format(library))
if library_info["path"] is None:
print("Omitting {0}".format(library))
continue
if library in ALLOWLIST:
# Omit glibc/gcc/system libraries
print('Omitting {0}'.format(library))
print("Omitting {0}".format(library))
continue
parent_dependencies = binary_dependencies.get(parent, [])
......@@ -176,11 +189,11 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
if library in binary_paths:
continue
binary_paths[library] = library_info['path']
binary_queue += [(n, library) for n in library_info['needed']]
binary_paths[library] = library_info["path"]
binary_queue += [(n, library) for n in library_info["needed"]]
print('Copying dependencies to wheel directory')
new_libraries_path = osp.join(output_dir, 'torchvision.libs')
print("Copying dependencies to wheel directory")
new_libraries_path = osp.join(output_dir, "torchvision.libs")
os.makedirs(new_libraries_path)
new_names = {binary: binary_path}
......@@ -189,11 +202,11 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
if library != binary:
library_path = binary_paths[library]
new_library_path = patch_new_path(library_path, new_libraries_path)
print('{0} -> {1}'.format(library, new_library_path))
print("{0} -> {1}".format(library, new_library_path))
shutil.copyfile(library_path, new_library_path)
new_names[library] = new_library_path
print('Updating dependency names by new files')
print("Updating dependency names by new files")
for library in binary_paths:
if library != binary:
if library not in binary_dependencies:
......@@ -202,59 +215,26 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
new_library_name = new_names[library]
for dep in library_dependencies:
new_dep = osp.basename(new_names[dep])
print('{0}: {1} -> {2}'.format(library, dep, new_dep))
print("{0}: {1} -> {2}".format(library, dep, new_dep))
subprocess.check_output(
[
patchelf,
'--replace-needed',
dep,
new_dep,
new_library_name
],
cwd=new_libraries_path)
print('Updating library rpath')
subprocess.check_output(
[
patchelf,
'--set-rpath',
"$ORIGIN",
new_library_name
],
cwd=new_libraries_path)
subprocess.check_output(
[
patchelf,
'--print-rpath',
new_library_name
],
cwd=new_libraries_path)
[patchelf, "--replace-needed", dep, new_dep, new_library_name], cwd=new_libraries_path
)
print("Updating library rpath")
subprocess.check_output([patchelf, "--set-rpath", "$ORIGIN", new_library_name], cwd=new_libraries_path)
subprocess.check_output([patchelf, "--print-rpath", new_library_name], cwd=new_libraries_path)
print("Update library dependencies")
library_dependencies = binary_dependencies[binary]
for dep in library_dependencies:
new_dep = osp.basename(new_names[dep])
print('{0}: {1} -> {2}'.format(binary, dep, new_dep))
subprocess.check_output(
[
patchelf,
'--replace-needed',
dep,
new_dep,
binary
],
cwd=output_library)
print('Update library rpath')
print("{0}: {1} -> {2}".format(binary, dep, new_dep))
subprocess.check_output([patchelf, "--replace-needed", dep, new_dep, binary], cwd=output_library)
print("Update library rpath")
subprocess.check_output(
[
patchelf,
'--set-rpath',
"$ORIGIN:$ORIGIN/../torchvision.libs",
binary_path
],
cwd=output_library
[patchelf, "--set-rpath", "$ORIGIN:$ORIGIN/../torchvision.libs", binary_path], cwd=output_library
)
......@@ -265,7 +245,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
Given a shared library, find the transitive closure of its dependencies,
rename and copy them into the wheel.
"""
print('Relocating {0}'.format(binary))
print("Relocating {0}".format(binary))
binary_path = osp.join(output_library, binary)
library_dlls = find_dll_dependencies(dumpbin, binary_path)
......@@ -275,19 +255,19 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
while binary_queue != []:
library, parent = binary_queue.pop(0)
if library in WINDOWS_ALLOWLIST or library.startswith('api-ms-win'):
print('Omitting {0}'.format(library))
if library in WINDOWS_ALLOWLIST or library.startswith("api-ms-win"):
print("Omitting {0}".format(library))
continue
library_path = find_program(library)
if library_path is None:
print('{0} not found'.format(library))
print("{0} not found".format(library))
continue
if osp.basename(osp.dirname(library_path)) == 'system32':
if osp.basename(osp.dirname(library_path)) == "system32":
continue
print('{0}: {1}'.format(library, library_path))
print("{0}: {1}".format(library, library_path))
parent_dependencies = binary_dependencies.get(parent, [])
parent_dependencies.append(library)
binary_dependencies[parent] = parent_dependencies
......@@ -299,55 +279,56 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
downstream_dlls = find_dll_dependencies(dumpbin, library_path)
binary_queue += [(n, library) for n in downstream_dlls]
print('Copying dependencies to wheel directory')
package_dir = osp.join(output_dir, 'torchvision')
print("Copying dependencies to wheel directory")
package_dir = osp.join(output_dir, "torchvision")
for library in binary_paths:
if library != binary:
library_path = binary_paths[library]
new_library_path = osp.join(package_dir, library)
print('{0} -> {1}'.format(library, new_library_path))
print("{0} -> {1}".format(library, new_library_path))
shutil.copyfile(library_path, new_library_path)
def compress_wheel(output_dir, wheel, wheel_dir, wheel_name):
"""Create RECORD file and compress wheel distribution."""
print('Update RECORD file in wheel')
dist_info = glob.glob(osp.join(output_dir, '*.dist-info'))[0]
record_file = osp.join(dist_info, 'RECORD')
print("Update RECORD file in wheel")
dist_info = glob.glob(osp.join(output_dir, "*.dist-info"))[0]
record_file = osp.join(dist_info, "RECORD")
with open(record_file, 'w') as f:
with open(record_file, "w") as f:
for root, _, files in os.walk(output_dir):
for this_file in files:
full_file = osp.join(root, this_file)
rel_file = osp.relpath(full_file, output_dir)
if full_file == record_file:
f.write('{0},,\n'.format(rel_file))
f.write("{0},,\n".format(rel_file))
else:
digest, size = rehash(full_file)
f.write('{0},{1},{2}\n'.format(rel_file, digest, size))
f.write("{0},{1},{2}\n".format(rel_file, digest, size))
print('Compressing wheel')
print("Compressing wheel")
base_wheel_name = osp.join(wheel_dir, wheel_name)
shutil.make_archive(base_wheel_name, 'zip', output_dir)
shutil.make_archive(base_wheel_name, "zip", output_dir)
os.remove(wheel)
shutil.move('{0}.zip'.format(base_wheel_name), wheel)
shutil.move("{0}.zip".format(base_wheel_name), wheel)
shutil.rmtree(output_dir)
def patch_linux():
# Get patchelf location
patchelf = find_program('patchelf')
patchelf = find_program("patchelf")
if patchelf is None:
raise FileNotFoundError('Patchelf was not found in the system, please'
' make sure that is available on the PATH.')
raise FileNotFoundError(
"Patchelf was not found in the system, please" " make sure that is available on the PATH."
)
# Find wheel
print('Finding wheels...')
wheels = glob.glob(osp.join(PACKAGE_ROOT, 'dist', '*.whl'))
output_dir = osp.join(PACKAGE_ROOT, 'dist', '.wheel-process')
print("Finding wheels...")
wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl"))
output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process")
image_binary = 'image.so'
video_binary = 'video_reader.so'
image_binary = "image.so"
video_binary = "video_reader.so"
torchvision_binaries = [image_binary, video_binary]
for wheel in wheels:
if osp.exists(output_dir):
......@@ -355,37 +336,37 @@ def patch_linux():
os.makedirs(output_dir)
print('Unzipping wheel...')
print("Unzipping wheel...")
wheel_file = osp.basename(wheel)
wheel_dir = osp.dirname(wheel)
print('{0}'.format(wheel_file))
print("{0}".format(wheel_file))
wheel_name, _ = osp.splitext(wheel_file)
unzip_file(wheel, output_dir)
print('Finding ELF dependencies...')
output_library = osp.join(output_dir, 'torchvision')
print("Finding ELF dependencies...")
output_library = osp.join(output_dir, "torchvision")
for binary in torchvision_binaries:
if osp.exists(osp.join(output_library, binary)):
relocate_elf_library(
patchelf, output_dir, output_library, binary)
relocate_elf_library(patchelf, output_dir, output_library, binary)
compress_wheel(output_dir, wheel, wheel_dir, wheel_name)
def patch_win():
# Get dumpbin location
dumpbin = find_program('dumpbin')
dumpbin = find_program("dumpbin")
if dumpbin is None:
raise FileNotFoundError('Dumpbin was not found in the system, please'
' make sure that is available on the PATH.')
raise FileNotFoundError(
"Dumpbin was not found in the system, please" " make sure that is available on the PATH."
)
# Find wheel
print('Finding wheels...')
wheels = glob.glob(osp.join(PACKAGE_ROOT, 'dist', '*.whl'))
output_dir = osp.join(PACKAGE_ROOT, 'dist', '.wheel-process')
print("Finding wheels...")
wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl"))
output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process")
image_binary = 'image.pyd'
video_binary = 'video_reader.pyd'
image_binary = "image.pyd"
video_binary = "video_reader.pyd"
torchvision_binaries = [image_binary, video_binary]
for wheel in wheels:
if osp.exists(output_dir):
......@@ -393,25 +374,24 @@ def patch_win():
os.makedirs(output_dir)
print('Unzipping wheel...')
print("Unzipping wheel...")
wheel_file = osp.basename(wheel)
wheel_dir = osp.dirname(wheel)
print('{0}'.format(wheel_file))
print("{0}".format(wheel_file))
wheel_name, _ = osp.splitext(wheel_file)
unzip_file(wheel, output_dir)
print('Finding DLL/PE dependencies...')
output_library = osp.join(output_dir, 'torchvision')
print("Finding DLL/PE dependencies...")
output_library = osp.join(output_dir, "torchvision")
for binary in torchvision_binaries:
if osp.exists(osp.join(output_library, binary)):
relocate_dll_library(
dumpbin, output_dir, output_library, binary)
relocate_dll_library(dumpbin, output_dir, output_library, binary)
compress_wheel(output_dir, wheel, wheel_dir, wheel_name)
if __name__ == '__main__':
if sys.platform == 'linux':
if __name__ == "__main__":
if sys.platform == "linux":
patch_linux()
elif sys.platform == 'win32':
elif sys.platform == "win32":
patch_win()
[tool.usort]
first_party_detection = false
[tool.black]
line-length = 120
target-version = ["py36"]
[tool.ufmt]
excludes = [
"gallery",
]
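
black and usort both read their settings from pyproject.toml, so command-line runs and the pre-commit hook stay in sync; e.g., run from the repository root (a sketch):

    ufmt check .    # picks up line-length 120, the py36 target,
                    # disabled first-party detection, and the gallery exclusion
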
......@@ -4,8 +4,15 @@ from torchvision.transforms.functional import InterpolationMode
class ClassificationPresetTrain:
def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), hflip_prob=0.5,
auto_augment_policy=None, random_erase_prob=0.0):
def __init__(
self,
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
hflip_prob=0.5,
auto_augment_policy=None,
random_erase_prob=0.0,
):
trans = [transforms.RandomResizedCrop(crop_size)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
......@@ -17,11 +24,13 @@ class ClassificationPresetTrain:
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.extend([
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
])
trans.extend(
[
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
]
)
if random_erase_prob > 0:
trans.append(transforms.RandomErasing(p=random_erase_prob))
......@@ -32,16 +41,24 @@ class ClassificationPresetTrain:
class ClassificationPresetEval:
def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR):
self.transforms = transforms.Compose([
transforms.Resize(resize_size, interpolation=interpolation),
transforms.CenterCrop(crop_size),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
])
def __init__(
self,
crop_size,
resize_size=256,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
):
self.transforms = transforms.Compose(
[
transforms.Resize(resize_size, interpolation=interpolation),
transforms.CenterCrop(crop_size),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
]
)
def __call__(self, img):
return self.transforms(img)
......
import copy
import datetime
import os
import time
import copy
import torch
import torch.quantization
import torch.utils.data
from torch import nn
import torchvision
import torch.quantization
import utils
from torch import nn
from train import train_one_epoch, evaluate, load_data
......@@ -20,8 +20,7 @@ def main(args):
print(args)
if args.post_training_quantize and args.distributed:
raise RuntimeError("Post training quantization example should not be performed "
"on distributed mode")
raise RuntimeError("Post training quantization example should not be performed " "on distributed mode")
# Set backend engine to ensure that quantized model runs on the correct kernels
if args.backend not in torch.backends.quantized.supported_engines:
......@@ -33,17 +32,17 @@ def main(args):
# Data loading code
print("Loading data")
train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
train_dir = os.path.join(args.data_path, "train")
val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers, pin_memory=True)
dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=args.eval_batch_size,
sampler=test_sampler, num_workers=args.workers, pin_memory=True)
dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
)
print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
......@@ -59,12 +58,10 @@ def main(args):
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay)
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
step_size=args.lr_step_size,
gamma=args.lr_gamma)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
criterion = nn.CrossEntropyLoss()
model_without_ddp = model
......@@ -73,21 +70,19 @@ def main(args):
model_without_ddp = model.module
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
checkpoint = torch.load(args.resume, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
if args.post_training_quantize:
# perform calibration on a subset of the training dataset
# for that, create a subset of the training dataset
ds = torch.utils.data.Subset(
dataset,
indices=list(range(args.batch_size * args.num_calibration_batches)))
ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
data_loader_calibration = torch.utils.data.DataLoader(
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
pin_memory=True)
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
)
model.eval()
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig(args.backend)
......@@ -97,10 +92,9 @@ def main(args):
evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
torch.quantization.convert(model, inplace=True)
if args.output_dir:
print('Saving quantized model')
print("Saving quantized model")
if utils.is_main_process():
torch.save(model.state_dict(), os.path.join(args.output_dir,
'quantized_post_train_model.pth'))
torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
print("Evaluating post-training quantized model")
evaluate(model, criterion, data_loader_test, device=device)
return
......@@ -115,107 +109,103 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
print('Starting training for epoch', epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
args.print_freq)
print("Starting training for epoch", epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
lr_scheduler.step()
with torch.no_grad():
if epoch >= args.num_observer_update_epochs:
print('Disabling observer for subseq epochs, epoch = ', epoch)
print("Disabling observer for subseq epochs, epoch = ", epoch)
model.apply(torch.quantization.disable_observer)
if epoch >= args.num_batch_norm_update_epochs:
print('Freezing BN for subseq epochs, epoch = ', epoch)
print("Freezing BN for subseq epochs, epoch = ", epoch)
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
print('Evaluate QAT model')
print("Evaluate QAT model")
evaluate(model, criterion, data_loader_test, device=device)
quantized_eval_model = copy.deepcopy(model_without_ddp)
quantized_eval_model.eval()
quantized_eval_model.to(torch.device('cpu'))
quantized_eval_model.to(torch.device("cpu"))
torch.quantization.convert(quantized_eval_model, inplace=True)
print('Evaluate Quantized model')
evaluate(quantized_eval_model, criterion, data_loader_test,
device=torch.device('cpu'))
print("Evaluate Quantized model")
evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
model.train()
if args.output_dir:
checkpoint = {
'model': model_without_ddp.state_dict(),
'eval_model': quantized_eval_model.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))
print('Saving models after epoch ', epoch)
"model": model_without_ddp.state_dict(),
"eval_model": quantized_eval_model.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"epoch": epoch,
"args": args,
}
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
print("Saving models after epoch ", epoch)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
print("Training time {}".format(total_time_str))
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(description='PyTorch Quantized Classification Training', add_help=add_help)
parser.add_argument('--data-path',
default='/datasets01/imagenet_full_size/061417/',
help='dataset')
parser.add_argument('--model',
default='mobilenet_v2',
help='model')
parser.add_argument('--backend',
default='qnnpack',
help='fbgemm or qnnpack')
parser.add_argument('--device',
default='cuda',
help='device')
parser.add_argument('-b', '--batch-size', default=32, type=int,
help='batch size for calibration/training')
parser.add_argument('--eval-batch-size', default=128, type=int,
help='batch size for evaluation')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--num-observer-update-epochs',
default=4, type=int, metavar='N',
help='number of total epochs to update observers')
parser.add_argument('--num-batch-norm-update-epochs', default=3,
type=int, metavar='N',
help='number of total epochs to update batch norm stats')
parser.add_argument('--num-calibration-batches',
default=32, type=int, metavar='N',
help='number of batches of training set for \
observer calibration ')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--lr',
default=0.0001, type=float,
help='initial learning rate')
parser.add_argument('--momentum',
default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--lr-step-size', default=30, type=int,
help='decrease lr every step-size epochs')
parser.add_argument('--lr-gamma', default=0.1, type=float,
help='decrease lr by a factor of lr-gamma')
parser.add_argument('--print-freq', default=10, type=int,
help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset")
parser.add_argument("--model", default="mobilenet_v2", help="model")
parser.add_argument("--backend", default="qnnpack", help="fbgemm or qnnpack")
parser.add_argument("--device", default="cuda", help="device")
parser.add_argument("-b", "--batch-size", default=32, type=int, help="batch size for calibration/training")
parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
parser.add_argument(
"--num-observer-update-epochs",
default=4,
type=int,
metavar="N",
help="number of total epochs to update observers",
)
parser.add_argument(
"--num-batch-norm-update-epochs",
default=3,
type=int,
metavar="N",
help="number of total epochs to update batch norm stats",
)
parser.add_argument(
"--num-calibration-batches",
default=32,
type=int,
metavar="N",
help="number of batches of training set for \
observer calibration ",
)
parser.add_argument(
"-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
)
parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
"--wd",
"--weight-decay",
default=1e-4,
type=float,
metavar="W",
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
parser.add_argument("--output-dir", default=".", help="path where to save")
parser.add_argument("--resume", default="", help="resume from checkpoint")
parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
parser.add_argument(
"--cache-dataset",
dest="cache_dataset",
......@@ -243,11 +233,8 @@ def get_args_parser(add_help=True):
)
# distributed training parameters
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url',
default='env://',
help='url used to set up distributed training')
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
return parser
......
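For orientation, a hedged, self-contained sketch of the eager-mode post-training-quantization recipe the script above follows (fuse, attach a qconfig, prepare with observers, calibrate, convert); the random calibration inputs are stand-ins for real batches.

import torch
import torchvision

model = torchvision.models.quantization.mobilenet_v2(quantize=False)
model.eval()
model.fuse_model()  # merge conv/bn/relu modules ahead of quantization
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)  # insert observers
with torch.no_grad():
    for _ in range(2):  # stand-in calibration batches
        model(torch.randn(1, 3, 224, 224))
torch.quantization.convert(model, inplace=True)  # swap modules to int8 kernels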
import math
import torch
from typing import Tuple
import torch
from torch import Tensor
from torchvision.transforms import functional as F
......@@ -19,9 +19,7 @@ class RandomMixup(torch.nn.Module):
inplace (bool): boolean to make this transform inplace. Default set to False.
"""
def __init__(self, num_classes: int,
p: float = 0.5, alpha: float = 1.0,
inplace: bool = False) -> None:
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."
......@@ -45,7 +43,7 @@ class RandomMixup(torch.nn.Module):
elif target.ndim != 1:
raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
elif not batch.is_floating_point():
raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
elif target.dtype != torch.int64:
raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
......@@ -74,12 +72,12 @@ class RandomMixup(torch.nn.Module):
return batch, target
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_classes={num_classes}'
s += ', p={p}'
s += ', alpha={alpha}'
s += ', inplace={inplace}'
s += ')'
s = self.__class__.__name__ + "("
s += "num_classes={num_classes}"
s += ", p={p}"
s += ", alpha={alpha}"
s += ", inplace={inplace}"
s += ")"
return s.format(**self.__dict__)
......@@ -97,9 +95,7 @@ class RandomCutmix(torch.nn.Module):
inplace (bool): boolean to make this transform inplace. Default set to False.
"""
def __init__(self, num_classes: int,
p: float = 0.5, alpha: float = 1.0,
inplace: bool = False) -> None:
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."
......@@ -123,7 +119,7 @@ class RandomCutmix(torch.nn.Module):
elif target.ndim != 1:
raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
elif not batch.is_floating_point():
raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
elif target.dtype != torch.int64:
raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
......@@ -166,10 +162,10 @@ class RandomCutmix(torch.nn.Module):
return batch, target
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_classes={num_classes}'
s += ', p={p}'
s += ', alpha={alpha}'
s += ', inplace={inplace}'
s += ')'
s = self.__class__.__name__ + "("
s += "num_classes={num_classes}"
s += ", p={p}"
s += ", alpha={alpha}"
s += ", inplace={inplace}"
s += ")"
return s.format(**self.__dict__)
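A hedged sketch of the core blending step RandomMixup performs above, on toy tensors; torch.distributions.Beta stands in for the transform's internal sampling.

import torch

alpha = 1.0
batch = torch.rand(4, 3, 8, 8)  # toy float batch
target = torch.nn.functional.one_hot(torch.arange(4), num_classes=4).float()

lam = torch.distributions.Beta(alpha, alpha).sample().item()  # mixing coefficient
mixed_batch = batch.mul(lam).add_(batch.roll(1, 0).mul(1.0 - lam))  # pair sample i with i-1
mixed_target = target.mul(lam).add_(target.roll(1, 0).mul(1.0 - lam))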
from collections import defaultdict, deque, OrderedDict
import copy
import datetime
import errno
import hashlib
import os
import time
from collections import defaultdict, deque, OrderedDict
import torch
import torch.distributed as dist
import errno
import os
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
......@@ -34,7 +34,7 @@ class SmoothedValue(object):
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
......@@ -65,11 +65,8 @@ class SmoothedValue(object):
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
class MetricLogger(object):
......@@ -89,15 +86,12 @@ class MetricLogger(object):
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
......@@ -110,31 +104,28 @@ class MetricLogger(object):
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
......@@ -144,21 +135,28 @@ class MetricLogger(object):
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {}'.format(header, total_time_str))
print("{} Total time: {}".format(header, total_time_str))
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
......@@ -167,9 +165,9 @@ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
is used to compute the EMA.
"""
def __init__(self, model, decay, device='cpu'):
ema_avg = (lambda avg_model_param, model_param, num_averaged:
decay * avg_model_param + (1 - decay) * model_param)
def __init__(self, model, decay, device="cpu"):
ema_avg = lambda avg_model_param, model_param, num_averaged: decay * avg_model_param + (1 - decay) * model_param
super().__init__(model, device, ema_avg)
def update_parameters(self, model):
......@@ -179,8 +177,7 @@ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
if self.n_averaged == 0:
p_swa.detach().copy_(p_model_)
else:
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
self.n_averaged.to(device)))
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device)))
self.n_averaged += 1
......@@ -216,10 +213,11 @@ def setup_for_distributed(is_master):
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
......@@ -256,28 +254,28 @@ def save_on_master(*args, **kwargs):
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print('Not using distributed mode')
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
setup_for_distributed(args.rank == 0)
......@@ -300,9 +298,7 @@ def average_checkpoints(inputs):
with open(fpath, "rb") as f:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
),
map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
)
# Copies over the settings from the first checkpoint
if new_state is None:
......@@ -336,7 +332,7 @@ def average_checkpoints(inputs):
return new_state
def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True):
def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
"""
This method can be used to prepare weights files for new models. It receives as
input a model architecture and a checkpoint from the training script and produces
......@@ -382,7 +378,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True):
# Deep copy to avoid side-effects on the model object.
model = copy.deepcopy(model)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Load the weights to the model to validate that everything works
# and remove unnecessary weights (such as auxiliaries, etc)
......
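As a hedged illustration of the update rule ExponentialMovingAverage applies above (ema <- decay * ema + (1 - decay) * param), here is a manual per-tensor version on a toy module; the decay value is an assumption for illustration.

import copy
import torch

model = torch.nn.Linear(4, 2)
ema_model = copy.deepcopy(model)
decay = 0.999  # assumed value for illustration

with torch.no_grad():
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.copy_(decay * p_ema + (1.0 - decay) * p)  # per-tensor EMA step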
......@@ -5,10 +5,9 @@ from contextlib import redirect_stdout
import numpy as np
import pycocotools.mask as mask_util
import torch
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
class CocoEvaluator:
......@@ -104,8 +103,7 @@ class CocoEvaluator:
labels = prediction["labels"].tolist()
rles = [
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
for mask in masks
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
......@@ -141,7 +139,7 @@ class CocoEvaluator:
{
"image_id": original_id,
"category_id": labels[k],
'keypoints': keypoint,
"keypoints": keypoint,
"score": scores[k],
}
for k, keypoint in enumerate(keypoints)
......
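A hedged sketch of the mask_util.encode call in the CocoEvaluator diff above: pycocotools expects Fortran-ordered uint8 masks of shape (H, W, N), and the RLE counts are decoded to str so the results are JSON-serializable.

import numpy as np
import pycocotools.mask as mask_util

mask = np.zeros((32, 32), dtype=np.uint8)
mask[8:24, 8:24] = 1  # toy binary mask
rle = mask_util.encode(np.asfortranarray(mask[:, :, np.newaxis]))[0]
rle["counts"] = rle["counts"].decode("utf-8")  # bytes -> str for JSON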
import copy
import os
from PIL import Image
import torch
import torch.utils.data
import torchvision
import transforms as T
from PIL import Image
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
import transforms as T
class FilterAndRemapCocoCategories(object):
def __init__(self, categories, remap=True):
......@@ -56,7 +54,7 @@ class ConvertCocoPolysToMask(object):
anno = target["annotations"]
anno = [obj for obj in anno if obj['iscrowd'] == 0]
anno = [obj for obj in anno if obj["iscrowd"] == 0]
boxes = [obj["bbox"] for obj in anno]
# guard against no boxes via resizing
......@@ -147,7 +145,7 @@ def convert_to_coco_api(ds):
coco_ds = COCO()
# annotation IDs need to start at 1, not 0, see torchvision issue #1530
ann_id = 1
dataset = {'images': [], 'categories': [], 'annotations': []}
dataset = {"images": [], "categories": [], "annotations": []}
categories = set()
for img_idx in range(len(ds)):
# find better way to get target
......@@ -155,41 +153,41 @@ def convert_to_coco_api(ds):
img, targets = ds[img_idx]
image_id = targets["image_id"].item()
img_dict = {}
img_dict['id'] = image_id
img_dict['height'] = img.shape[-2]
img_dict['width'] = img.shape[-1]
dataset['images'].append(img_dict)
img_dict["id"] = image_id
img_dict["height"] = img.shape[-2]
img_dict["width"] = img.shape[-1]
dataset["images"].append(img_dict)
bboxes = targets["boxes"]
bboxes[:, 2:] -= bboxes[:, :2]
bboxes = bboxes.tolist()
labels = targets['labels'].tolist()
areas = targets['area'].tolist()
iscrowd = targets['iscrowd'].tolist()
if 'masks' in targets:
masks = targets['masks']
labels = targets["labels"].tolist()
areas = targets["area"].tolist()
iscrowd = targets["iscrowd"].tolist()
if "masks" in targets:
masks = targets["masks"]
# make masks Fortran contiguous for coco_mask
masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
if 'keypoints' in targets:
keypoints = targets['keypoints']
if "keypoints" in targets:
keypoints = targets["keypoints"]
keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
num_objs = len(bboxes)
for i in range(num_objs):
ann = {}
ann['image_id'] = image_id
ann['bbox'] = bboxes[i]
ann['category_id'] = labels[i]
ann["image_id"] = image_id
ann["bbox"] = bboxes[i]
ann["category_id"] = labels[i]
categories.add(labels[i])
ann['area'] = areas[i]
ann['iscrowd'] = iscrowd[i]
ann['id'] = ann_id
if 'masks' in targets:
ann["area"] = areas[i]
ann["iscrowd"] = iscrowd[i]
ann["id"] = ann_id
if "masks" in targets:
ann["segmentation"] = coco_mask.encode(masks[i].numpy())
if 'keypoints' in targets:
ann['keypoints'] = keypoints[i]
ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
dataset['annotations'].append(ann)
if "keypoints" in targets:
ann["keypoints"] = keypoints[i]
ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
dataset["annotations"].append(ann)
ann_id += 1
dataset['categories'] = [{'id': i} for i in sorted(categories)]
dataset["categories"] = [{"id": i} for i in sorted(categories)]
coco_ds.dataset = dataset
coco_ds.createIndex()
return coco_ds
......@@ -220,7 +218,7 @@ class CocoDetection(torchvision.datasets.CocoDetection):
return img, target
def get_coco(root, image_set, transforms, mode='instances'):
def get_coco(root, image_set, transforms, mode="instances"):
anno_file_template = "{}_{}2017.json"
PATHS = {
"train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
......
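A minimal illustration of the in-place box conversion inside convert_to_coco_api above: torchvision targets store (x1, y1, x2, y2) boxes, while COCO annotations use (x, y, width, height), hence the subtraction.

import torch

boxes_xyxy = torch.tensor([[10.0, 20.0, 50.0, 80.0]])
boxes_xywh = boxes_xyxy.clone()
boxes_xywh[:, 2:] -= boxes_xywh[:, :2]  # now tensor([[10., 20., 40., 60.]])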
import math
import sys
import time
import torch
import torch
import torchvision.models.detection.mask_rcnn
from coco_utils import get_coco_api_from_dataset
from coco_eval import CocoEvaluator
import utils
from coco_eval import CocoEvaluator
from coco_utils import get_coco_api_from_dataset
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
header = "Epoch: [{}]".format(epoch)
lr_scheduler = None
if epoch == 0:
warmup_factor = 1. / 1000
warmup_factor = 1.0 / 1000
warmup_iters = min(1000, len(data_loader) - 1)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=warmup_factor,
total_iters=warmup_iters)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=warmup_factor, total_iters=warmup_iters
)
for images, targets in metric_logger.log_every(data_loader, print_freq, header):
images = list(image.to(device) for image in images)
......@@ -76,7 +76,7 @@ def evaluate(model, data_loader, device):
cpu_device = torch.device("cpu")
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
header = "Test:"
coco = get_coco_api_from_dataset(data_loader.dataset)
iou_types = _get_iou_types(model)
......
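A hedged sketch of the first-epoch warmup built in train_one_epoch above: with start_factor 1/1000, LinearLR ramps the learning rate from lr/1000 up to the base lr over total_iters steps (5 here for brevity, rather than the min(1000, len(data_loader) - 1) used in the script).

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0 / 1000, total_iters=5)
for _ in range(5):
    optimizer.step()
    lr_scheduler.step()  # lr climbs linearly toward 0.1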