Unverified Commit 5f0edb97 authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
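
For anyone reproducing the formatting locally, a minimal sketch of the workflow this commit sets up; the invocations below are assumptions based on the versions pinned in the diff, not commands taken from the commit itself:

    pip install "ufmt == 1.3.0" "black == 21.9b0" "usort == 0.6.4"
    ufmt check torchvision test     # report files that would be reformatted
    ufmt format torchvision test    # rewrite files in place (usort first, then black)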
@@ -201,7 +201,7 @@ jobs:
       pip install --user --progress-bar off types-requests
       pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
       pip install --user --progress-bar off git+https://github.com/pytorch/data.git
-      pip install --user --progress-bar off --editable .
+      pip install --user --progress-bar off --no-build-isolation --editable .
       mypy --config-file mypy.ini
   docstring_parameters_sync:
@@ -235,7 +235,7 @@ jobs:
       command: |
         pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
         # need to install torchvision dependencies due to transitive imports
-        pip install --user --progress-bar off .
+        pip install --user --progress-bar off --no-build-isolation .
         pip install pytest
         python test/test_hub.py
@@ -248,7 +248,7 @@ jobs:
       command: |
         pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
         # need to install torchvision dependencies due to transitive imports
-        pip install --user --progress-bar off .
+        pip install --user --progress-bar off --no-build-isolation .
         pip install --user onnx
         pip install --user onnxruntime
         pip install --user pytest
......
@@ -201,7 +201,7 @@ jobs:
       pip install --user --progress-bar off types-requests
       pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
       pip install --user --progress-bar off git+https://github.com/pytorch/data.git
-      pip install --user --progress-bar off --editable .
+      pip install --user --progress-bar off --no-build-isolation --editable .
       mypy --config-file mypy.ini
   docstring_parameters_sync:
@@ -235,7 +235,7 @@ jobs:
       command: |
         pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
         # need to install torchvision dependencies due to transitive imports
-        pip install --user --progress-bar off .
+        pip install --user --progress-bar off --no-build-isolation .
         pip install pytest
         python test/test_hub.py
@@ -248,7 +248,7 @@ jobs:
       command: |
         pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
         # need to install torchvision dependencies due to transitive imports
-        pip install --user --progress-bar off .
+        pip install --user --progress-bar off --no-build-isolation .
         pip install --user onnx
         pip install --user onnxruntime
         pip install --user pytest
......
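
A recurring change in the two CircleCI configs above is swapping plain `pip install` for `pip install --no-build-isolation`. The flag tells pip to build torchvision against the packages already present in the environment (here, the pre-installed torch nightly) instead of resolving a fresh, isolated build environment. A minimal local equivalent, assembled from the commands in the diff:

    pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
    pip install --no-build-isolation --editable .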
@@ -14,10 +14,11 @@ See this comment for design rationale:
 https://github.com/pytorch/vision/pull/1321#issuecomment-531033978
 """
+import os.path
+
 import jinja2
-from jinja2 import select_autoescape
 import yaml
-import os.path
+from jinja2 import select_autoescape

 PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
@@ -25,57 +26,66 @@ PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
 RC_PATTERN = r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/"


-def build_workflows(prefix='', filter_branch=None, upload=False, indentation=6, windows_latest_only=False):
+def build_workflows(prefix="", filter_branch=None, upload=False, indentation=6, windows_latest_only=False):
     w = []
     for btype in ["wheel", "conda"]:
         for os_type in ["linux", "macos", "win"]:
             python_versions = PYTHON_VERSIONS
-            cu_versions_dict = {"linux": ["cpu", "cu102", "cu111", "cu113", "rocm4.1", "rocm4.2"],
-                                "win": ["cpu", "cu102", "cu111", "cu113"],
-                                "macos": ["cpu"]}
+            cu_versions_dict = {
+                "linux": ["cpu", "cu102", "cu111", "cu113", "rocm4.1", "rocm4.2"],
+                "win": ["cpu", "cu102", "cu111", "cu113"],
+                "macos": ["cpu"],
+            }
             cu_versions = cu_versions_dict[os_type]
             for python_version in python_versions:
                 for cu_version in cu_versions:
                     # ROCm conda packages not yet supported
-                    if cu_version.startswith('rocm') and btype == "conda":
+                    if cu_version.startswith("rocm") and btype == "conda":
                         continue
                     for unicode in [False]:
                         fb = filter_branch
-                        if windows_latest_only and os_type == "win" and filter_branch is None and \
-                            (python_version != python_versions[-1] or
-                             (cu_version not in [cu_versions[0], cu_versions[-1]])):
+                        if (
+                            windows_latest_only
+                            and os_type == "win"
+                            and filter_branch is None
+                            and (
+                                python_version != python_versions[-1]
+                                or (cu_version not in [cu_versions[0], cu_versions[-1]])
+                            )
+                        ):
                             fb = "main"
-                        if not fb and (os_type == 'linux' and
-                                       cu_version == 'cpu' and
-                                       btype == 'wheel' and
-                                       python_version == '3.7'):
+                        if not fb and (
+                            os_type == "linux" and cu_version == "cpu" and btype == "wheel" and python_version == "3.7"
+                        ):
                             # the fields must match the build_docs "requires" dependency
                             fb = "/.*/"
                         w += workflow_pair(
-                            btype, os_type, python_version, cu_version,
-                            unicode, prefix, upload, filter_branch=fb)
+                            btype, os_type, python_version, cu_version, unicode, prefix, upload, filter_branch=fb
+                        )

     if not filter_branch:
         # Build on every pull request, but upload only on nightly and tags
-        w += build_doc_job('/.*/')
-        w += upload_doc_job('nightly')
+        w += build_doc_job("/.*/")
+        w += upload_doc_job("nightly")
     return indent(indentation, w)


-def workflow_pair(btype, os_type, python_version, cu_version, unicode, prefix='', upload=False, *, filter_branch=None):
+def workflow_pair(btype, os_type, python_version, cu_version, unicode, prefix="", upload=False, *, filter_branch=None):
     w = []
     unicode_suffix = "u" if unicode else ""
     base_workflow_name = f"{prefix}binary_{os_type}_{btype}_py{python_version}{unicode_suffix}_{cu_version}"

-    w.append(generate_base_workflow(
-        base_workflow_name, python_version, cu_version,
-        unicode, os_type, btype, filter_branch=filter_branch))
+    w.append(
+        generate_base_workflow(
+            base_workflow_name, python_version, cu_version, unicode, os_type, btype, filter_branch=filter_branch
+        )
+    )

     if upload:
         w.append(generate_upload_workflow(base_workflow_name, os_type, btype, cu_version, filter_branch=filter_branch))
-        if filter_branch == 'nightly' and os_type in ['linux', 'win']:
-            pydistro = 'pip' if btype == 'wheel' else 'conda'
+        if filter_branch == "nightly" and os_type in ["linux", "win"]:
+            pydistro = "pip" if btype == "wheel" else "conda"
             w.append(generate_smoketest_workflow(pydistro, base_workflow_name, filter_branch, python_version, os_type))

     return w
@@ -85,12 +95,13 @@ def build_doc_job(filter_branch):
     job = {
         "name": "build_docs",
         "python_version": "3.7",
-        "requires": ["binary_linux_wheel_py3.7_cpu", ],
+        "requires": [
+            "binary_linux_wheel_py3.7_cpu",
+        ],
     }

     if filter_branch:
-        job["filters"] = gen_filter_branch_tree(filter_branch,
-                                                tags_list=RC_PATTERN)
+        job["filters"] = gen_filter_branch_tree(filter_branch, tags_list=RC_PATTERN)
     return [{"build_docs": job}]
@@ -99,12 +110,13 @@ def upload_doc_job(filter_branch):
         "name": "upload_docs",
         "context": "org-member",
         "python_version": "3.7",
-        "requires": ["build_docs", ],
+        "requires": [
+            "build_docs",
+        ],
     }

     if filter_branch:
-        job["filters"] = gen_filter_branch_tree(filter_branch,
-                                                tags_list=RC_PATTERN)
+        job["filters"] = gen_filter_branch_tree(filter_branch, tags_list=RC_PATTERN)
     return [{"upload_docs": job}]
@@ -122,24 +134,25 @@ manylinux_images = {
 def get_manylinux_image(cu_version):
     if cu_version == "cpu":
         return "pytorch/manylinux-cuda102"
-    elif cu_version.startswith('cu'):
-        cu_suffix = cu_version[len('cu'):]
+    elif cu_version.startswith("cu"):
+        cu_suffix = cu_version[len("cu") :]
         return f"pytorch/manylinux-cuda{cu_suffix}"
-    elif cu_version.startswith('rocm'):
-        rocm_suffix = cu_version[len('rocm'):]
+    elif cu_version.startswith("rocm"):
+        rocm_suffix = cu_version[len("rocm") :]
         return f"pytorch/manylinux-rocm:{rocm_suffix}"


 def get_conda_image(cu_version):
     if cu_version == "cpu":
         return "pytorch/conda-builder:cpu"
-    elif cu_version.startswith('cu'):
-        cu_suffix = cu_version[len('cu'):]
+    elif cu_version.startswith("cu"):
+        cu_suffix = cu_version[len("cu") :]
         return f"pytorch/conda-builder:cuda{cu_suffix}"


-def generate_base_workflow(base_workflow_name, python_version, cu_version,
-                           unicode, os_type, btype, *, filter_branch=None):
+def generate_base_workflow(
+    base_workflow_name, python_version, cu_version, unicode, os_type, btype, *, filter_branch=None
+):
     d = {
         "name": base_workflow_name,
@@ -148,7 +161,7 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version,
     }

     if os_type != "win" and unicode:
-        d["unicode_abi"] = '1'
+        d["unicode_abi"] = "1"

     if os_type != "win":
         d["wheel_docker_image"] = get_manylinux_image(cu_version)
@@ -158,14 +171,12 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version,
     if filter_branch is not None:
         d["filters"] = {
-            "branches": {
-                "only": filter_branch
-            },
+            "branches": {"only": filter_branch},
             "tags": {
                 # Using a raw string here to avoid having to escape
                 # anything
                 "only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/"
-            }
+            },
         }

     w = f"binary_{os_type}_{btype}"
@@ -186,19 +197,17 @@ def generate_upload_workflow(base_workflow_name, os_type, btype, cu_version, *,
         "requires": [base_workflow_name],
     }

-    if btype == 'wheel':
-        d["subfolder"] = "" if os_type == 'macos' else cu_version + "/"
+    if btype == "wheel":
+        d["subfolder"] = "" if os_type == "macos" else cu_version + "/"

     if filter_branch is not None:
         d["filters"] = {
-            "branches": {
-                "only": filter_branch
-            },
+            "branches": {"only": filter_branch},
             "tags": {
                 # Using a raw string here to avoid having to escape
                 # anything
                 "only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/"
-            }
+            },
         }

     return {f"binary_{btype}_upload": d}
@@ -223,8 +232,7 @@ def generate_smoketest_workflow(pydistro, base_workflow_name, filter_branch, pyt

 def indent(indentation, data_list):
-    return ("\n" + " " * indentation).join(
-        yaml.dump(data_list, default_flow_style=False).splitlines())
+    return ("\n" + " " * indentation).join(yaml.dump(data_list, default_flow_style=False).splitlines())


 def unittest_workflows(indentation=6):
@@ -239,12 +247,12 @@ def unittest_workflows(indentation=6):
                 "python_version": python_version,
             }

-            if device_type == 'gpu':
+            if device_type == "gpu":
                 if python_version != "3.8":
-                    job['filters'] = gen_filter_branch_tree('main', 'nightly')
-                job['cu_version'] = 'cu102'
+                    job["filters"] = gen_filter_branch_tree("main", "nightly")
+                job["cu_version"] = "cu102"
             else:
-                job['cu_version'] = 'cpu'
+                job["cu_version"] = "cpu"

             jobs.append({f"unittest_{os_type}_{device_type}": job})
@@ -253,20 +261,17 @@ def unittest_workflows(indentation=6):

 def cmake_workflows(indentation=6):
     jobs = []
-    python_version = '3.8'
-    for os_type in ['linux', 'windows', 'macos']:
+    python_version = "3.8"
+    for os_type in ["linux", "windows", "macos"]:
         # Skip OSX CUDA
-        device_types = ['cpu', 'gpu'] if os_type != 'macos' else ['cpu']
+        device_types = ["cpu", "gpu"] if os_type != "macos" else ["cpu"]
         for device in device_types:
-            job = {
-                'name': f'cmake_{os_type}_{device}',
-                'python_version': python_version
-            }
+            job = {"name": f"cmake_{os_type}_{device}", "python_version": python_version}

-            job['cu_version'] = 'cu102' if device == 'gpu' else 'cpu'
-            if device == 'gpu' and os_type == 'linux':
-                job['wheel_docker_image'] = 'pytorch/manylinux-cuda102'
-            jobs.append({f'cmake_{os_type}_{device}': job})
+            job["cu_version"] = "cu102" if device == "gpu" else "cpu"
+            if device == "gpu" and os_type == "linux":
+                job["wheel_docker_image"] = "pytorch/manylinux-cuda102"
+            jobs.append({f"cmake_{os_type}_{device}": job})
     return indent(indentation, jobs)
@@ -275,27 +280,27 @@ def ios_workflows(indentation=6, nightly=False):
     build_job_names = []
     name_prefix = "nightly_" if nightly else ""
     env_prefix = "nightly-" if nightly else ""
-    for arch, platform in [('x86_64', 'SIMULATOR'), ('arm64', 'OS')]:
-        name = f'{name_prefix}binary_libtorchvision_ops_ios_12.0.0_{arch}'
+    for arch, platform in [("x86_64", "SIMULATOR"), ("arm64", "OS")]:
+        name = f"{name_prefix}binary_libtorchvision_ops_ios_12.0.0_{arch}"
         build_job_names.append(name)
         build_job = {
-            'build_environment': f'{env_prefix}binary-libtorchvision_ops-ios-12.0.0-{arch}',
-            'ios_arch': arch,
-            'ios_platform': platform,
-            'name': name,
+            "build_environment": f"{env_prefix}binary-libtorchvision_ops-ios-12.0.0-{arch}",
+            "ios_arch": arch,
+            "ios_platform": platform,
+            "name": name,
         }
         if nightly:
-            build_job['filters'] = gen_filter_branch_tree('nightly')
-        jobs.append({'binary_ios_build': build_job})
+            build_job["filters"] = gen_filter_branch_tree("nightly")
+        jobs.append({"binary_ios_build": build_job})

     if nightly:
         upload_job = {
-            'build_environment': f'{env_prefix}binary-libtorchvision_ops-ios-12.0.0-upload',
-            'context': 'org-member',
-            'filters': gen_filter_branch_tree('nightly'),
-            'requires': build_job_names,
+            "build_environment": f"{env_prefix}binary-libtorchvision_ops-ios-12.0.0-upload",
+            "context": "org-member",
+            "filters": gen_filter_branch_tree("nightly"),
+            "requires": build_job_names,
         }
-        jobs.append({'binary_ios_upload': upload_job})
+        jobs.append({"binary_ios_upload": upload_job})
     return indent(indentation, jobs)
@@ -305,23 +310,23 @@ def android_workflows(indentation=6, nightly=False):
     name_prefix = "nightly_" if nightly else ""
     env_prefix = "nightly-" if nightly else ""

-    name = f'{name_prefix}binary_libtorchvision_ops_android'
+    name = f"{name_prefix}binary_libtorchvision_ops_android"
     build_job_names.append(name)
     build_job = {
-        'build_environment': f'{env_prefix}binary-libtorchvision_ops-android',
-        'name': name,
+        "build_environment": f"{env_prefix}binary-libtorchvision_ops-android",
+        "name": name,
     }

     if nightly:
         upload_job = {
-            'build_environment': f'{env_prefix}binary-libtorchvision_ops-android-upload',
-            'context': 'org-member',
-            'filters': gen_filter_branch_tree('nightly'),
-            'name': f'{name_prefix}binary_libtorchvision_ops_android_upload'
+            "build_environment": f"{env_prefix}binary-libtorchvision_ops-android-upload",
+            "context": "org-member",
+            "filters": gen_filter_branch_tree("nightly"),
+            "name": f"{name_prefix}binary_libtorchvision_ops_android_upload",
         }
-        jobs.append({'binary_android_upload': upload_job})
+        jobs.append({"binary_android_upload": upload_job})
     else:
-        jobs.append({'binary_android_build': build_job})
+        jobs.append({"binary_android_build": build_job})
     return indent(indentation, jobs)
@@ -330,15 +335,17 @@ if __name__ == "__main__":
     env = jinja2.Environment(
         loader=jinja2.FileSystemLoader(d),
         lstrip_blocks=True,
-        autoescape=select_autoescape(enabled_extensions=('html', 'xml')),
+        autoescape=select_autoescape(enabled_extensions=("html", "xml")),
         keep_trailing_newline=True,
     )

-    with open(os.path.join(d, 'config.yml'), 'w') as f:
-        f.write(env.get_template('config.yml.in').render(
-            build_workflows=build_workflows,
-            unittest_workflows=unittest_workflows,
-            cmake_workflows=cmake_workflows,
-            ios_workflows=ios_workflows,
-            android_workflows=android_workflows,
-        ))
+    with open(os.path.join(d, "config.yml"), "w") as f:
+        f.write(
+            env.get_template("config.yml.in").render(
+                build_workflows=build_workflows,
+                unittest_workflows=unittest_workflows,
+                cmake_workflows=cmake_workflows,
+                ios_workflows=ios_workflows,
+                android_workflows=android_workflows,
+            )
+        )
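
The regenerate.py hunks above are pure formatting: usort reorders the imports, black reflows the code. The wrapped lines run well past black's default 88 characters, which suggests a raised line-length setting; a sketch of the equivalent one-off invocation (the 120-character figure is inferred from the wrapping in the diff, not stated in this commit):

    black --line-length 120 .circleci/regenerate.py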
@@ -42,7 +42,6 @@ import signal
 import subprocess
 import sys
 import traceback
-
 from functools import partial

 try:
@@ -51,7 +50,7 @@ except ImportError:
     DEVNULL = open(os.devnull, "wb")


-DEFAULT_EXTENSIONS = 'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu'
+DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"


 class ExitStatus:
@@ -75,14 +74,8 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
                     # os.walk() supports trimming down the dnames list
                     # by modifying it in-place,
                     # to avoid unnecessary directory listings.
-                    dnames[:] = [
-                        x for x in dnames
-                        if
-                        not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)
-                    ]
-                    fpaths = [
-                        x for x in fpaths if not fnmatch.fnmatch(x, pattern)
-                    ]
+                    dnames[:] = [x for x in dnames if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)]
+                    fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)]
                 for f in fpaths:
                     ext = os.path.splitext(f)[1][1:]
                     if ext in extensions:
@@ -95,11 +88,9 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
 def make_diff(file, original, reformatted):
     return list(
         difflib.unified_diff(
-            original,
-            reformatted,
-            fromfile='{}\t(original)'.format(file),
-            tofile='{}\t(reformatted)'.format(file),
-            n=3))
+            original, reformatted, fromfile="{}\t(original)".format(file), tofile="{}\t(reformatted)".format(file), n=3
+        )
+    )


 class DiffError(Exception):
@@ -122,13 +113,12 @@ def run_clang_format_diff_wrapper(args, file):
     except DiffError:
         raise
     except Exception as e:
-        raise UnexpectedError('{}: {}: {}'.format(file, e.__class__.__name__,
-                                                  e), e)
+        raise UnexpectedError("{}: {}: {}".format(file, e.__class__.__name__, e), e)


 def run_clang_format_diff(args, file):
     try:
-        with io.open(file, 'r', encoding='utf-8') as f:
+        with io.open(file, "r", encoding="utf-8") as f:
             original = f.readlines()
     except IOError as exc:
         raise DiffError(str(exc))
@@ -153,17 +143,10 @@ def run_clang_format_diff(args, file):
     try:
         proc = subprocess.Popen(
-            invocation,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            universal_newlines=True,
-            encoding='utf-8')
+            invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8"
+        )
     except OSError as exc:
-        raise DiffError(
-            "Command '{}' failed to start: {}".format(
-                subprocess.list2cmdline(invocation), exc
-            )
-        )
+        raise DiffError("Command '{}' failed to start: {}".format(subprocess.list2cmdline(invocation), exc))
     proc_stdout = proc.stdout
     proc_stderr = proc.stderr
@@ -182,30 +165,30 @@ def run_clang_format_diff(args, file):

 def bold_red(s):
-    return '\x1b[1m\x1b[31m' + s + '\x1b[0m'
+    return "\x1b[1m\x1b[31m" + s + "\x1b[0m"


 def colorize(diff_lines):
     def bold(s):
-        return '\x1b[1m' + s + '\x1b[0m'
+        return "\x1b[1m" + s + "\x1b[0m"

     def cyan(s):
-        return '\x1b[36m' + s + '\x1b[0m'
+        return "\x1b[36m" + s + "\x1b[0m"

     def green(s):
-        return '\x1b[32m' + s + '\x1b[0m'
+        return "\x1b[32m" + s + "\x1b[0m"

     def red(s):
-        return '\x1b[31m' + s + '\x1b[0m'
+        return "\x1b[31m" + s + "\x1b[0m"

     for line in diff_lines:
-        if line[:4] in ['--- ', '+++ ']:
+        if line[:4] in ["--- ", "+++ "]:
             yield bold(line)
-        elif line.startswith('@@ '):
+        elif line.startswith("@@ "):
             yield cyan(line)
-        elif line.startswith('+'):
+        elif line.startswith("+"):
             yield green(line)
-        elif line.startswith('-'):
+        elif line.startswith("-"):
             yield red(line)
         else:
             yield line
@@ -218,7 +201,7 @@ def print_diff(diff_lines, use_color):

 def print_trouble(prog, message, use_colors):
-    error_text = 'error:'
+    error_text = "error:"
     if use_colors:
         error_text = bold_red(error_text)
     print("{}: {} {}".format(prog, error_text, message), file=sys.stderr)
@@ -227,45 +210,37 @@ def print_trouble(prog, message, use_colors):
 def main():
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument(
-        '--clang-format-executable',
-        metavar='EXECUTABLE',
-        help='path to the clang-format executable',
-        default='clang-format')
+        "--clang-format-executable",
+        metavar="EXECUTABLE",
+        help="path to the clang-format executable",
+        default="clang-format",
+    )
     parser.add_argument(
-        '--extensions',
-        help='comma separated list of file extensions (default: {})'.format(
-            DEFAULT_EXTENSIONS),
-        default=DEFAULT_EXTENSIONS)
-    parser.add_argument(
-        '-r',
-        '--recursive',
-        action='store_true',
-        help='run recursively over directories')
-    parser.add_argument('files', metavar='file', nargs='+')
-    parser.add_argument(
-        '-q',
-        '--quiet',
-        action='store_true')
-    parser.add_argument(
-        '-j',
-        metavar='N',
+        "--extensions",
+        help="comma separated list of file extensions (default: {})".format(DEFAULT_EXTENSIONS),
+        default=DEFAULT_EXTENSIONS,
+    )
+    parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories")
+    parser.add_argument("files", metavar="file", nargs="+")
+    parser.add_argument("-q", "--quiet", action="store_true")
+    parser.add_argument(
+        "-j",
+        metavar="N",
         type=int,
         default=0,
-        help='run N clang-format jobs in parallel'
-        ' (default number of cpus + 1)')
+        help="run N clang-format jobs in parallel" " (default number of cpus + 1)",
+    )
     parser.add_argument(
-        '--color',
-        default='auto',
-        choices=['auto', 'always', 'never'],
-        help='show colored diff (default: auto)')
+        "--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)"
+    )
     parser.add_argument(
-        '-e',
-        '--exclude',
-        metavar='PATTERN',
-        action='append',
+        "-e",
+        "--exclude",
+        metavar="PATTERN",
+        action="append",
         default=[],
-        help='exclude paths matching the given glob-like pattern(s)'
-        ' from recursive search')
+        help="exclude paths matching the given glob-like pattern(s)" " from recursive search",
+    )

     args = parser.parse_args()
@@ -282,10 +257,10 @@ def main():
     colored_stdout = False
     colored_stderr = False
-    if args.color == 'always':
+    if args.color == "always":
         colored_stdout = True
         colored_stderr = True
-    elif args.color == 'auto':
+    elif args.color == "auto":
         colored_stdout = sys.stdout.isatty()
         colored_stderr = sys.stderr.isatty()
@@ -298,19 +273,15 @@ def main():
     except OSError as e:
         print_trouble(
             parser.prog,
-            "Command '{}' failed to start: {}".format(
-                subprocess.list2cmdline(version_invocation), e
-            ),
+            "Command '{}' failed to start: {}".format(subprocess.list2cmdline(version_invocation), e),
             use_colors=colored_stderr,
         )
         return ExitStatus.TROUBLE

     retcode = ExitStatus.SUCCESS
     files = list_files(
-        args.files,
-        recursive=args.recursive,
-        exclude=args.exclude,
-        extensions=args.extensions.split(','))
+        args.files, recursive=args.recursive, exclude=args.exclude, extensions=args.extensions.split(",")
+    )

     if not files:
         return
@@ -327,8 +298,7 @@ def main():
         pool = None
     else:
         pool = multiprocessing.Pool(njobs)
-        it = pool.imap_unordered(
-            partial(run_clang_format_diff_wrapper, args), files)
+        it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files)
     while True:
         try:
             outs, errs = next(it)
@@ -359,5 +329,5 @@ def main():
     return retcode


-if __name__ == '__main__':
+if __name__ == "__main__":
     sys.exit(main())
@@ -20,8 +20,8 @@ jobs:
         with:
           python-version: 3.6
-      - name: Upgrade pip
-        run: python -m pip install --upgrade pip
+      - name: Upgrade system packages
+        run: python -m pip install --upgrade pip setuptools wheel
       - name: Checkout repository
         uses: actions/checkout@v2
@@ -30,7 +30,7 @@ jobs:
         run: pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
       - name: Install torchvision
-        run: pip install -e .
+        run: pip install --no-build-isolation --editable .
       - name: Install all optional dataset requirements
         run: pip install scipy pandas pycocotools lmdb requests
......
 repos:
+  - repo: https://github.com/omnilib/ufmt
+    rev: v1.3.0
+    hooks:
+      - id: ufmt
+        additional_dependencies:
+          - black == 21.9b0
+          - usort == 0.6.4
   - repo: https://gitlab.com/pycqa/flake8
     rev: 3.9.2
     hooks:
......
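
With the ufmt hook registered above, formatting runs automatically on every commit. A minimal sketch of driving it by hand with stock pre-commit commands (standard usage, assumed rather than shown in this diff):

    pip install pre-commit
    pre-commit install                 # wire the hooks into .git/hooks once
    pre-commit run ufmt --all-files    # format the whole tree in one pass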
@@ -5,11 +5,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
 print(torch.__version__)

 model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
-    pretrained=True,
-    box_score_thresh=0.7,
-    rpn_post_nms_top_n_test=100,
-    rpn_score_thresh=0.4,
-    rpn_pre_nms_top_n_test=150)
+    pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150
+)
 model.eval()

 script_model = torch.jit.script(model)
......
@@ -21,8 +21,8 @@
 # import sys
 # sys.path.insert(0, os.path.abspath('.'))

-import torchvision
 import pytorch_sphinx_theme
+import torchvision

 # -- General configuration ------------------------------------------------
@@ -33,24 +33,24 @@ import pytorch_sphinx_theme
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.doctest',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.todo',
-    'sphinx.ext.mathjax',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.duration',
-    'sphinx_gallery.gen_gallery',
-    'sphinx_copybutton',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.doctest",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.todo",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.duration",
+    "sphinx_gallery.gen_gallery",
+    "sphinx_copybutton",
 ]

 sphinx_gallery_conf = {
-    'examples_dirs': '../../gallery/',  # path to your example scripts
-    'gallery_dirs': 'auto_examples',  # path to where to save gallery generated output
-    'backreferences_dir': 'gen_modules/backreferences',
-    'doc_module': ('torchvision',),
+    "examples_dirs": "../../gallery/",  # path to your example scripts
+    "gallery_dirs": "auto_examples",  # path to where to save gallery generated output
+    "backreferences_dir": "gen_modules/backreferences",
+    "doc_module": ("torchvision",),
 }

 napoleon_use_ivar = True
@@ -59,22 +59,22 @@ napoleon_google_docstring = True

 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]

 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
 source_suffix = {
-    '.rst': 'restructuredtext',
+    ".rst": "restructuredtext",
 }

 # The master toctree document.
-master_doc = 'index'
+master_doc = "index"

 # General information about the project.
-project = 'Torchvision'
-copyright = '2017-present, Torch Contributors'
-author = 'Torch Contributors'
+project = "Torchvision"
+copyright = "2017-present, Torch Contributors"
+author = "Torch Contributors"

 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
@@ -82,10 +82,10 @@ author = 'Torch Contributors'
 #
 # The short X.Y version.
 # TODO: change to [:2] at v1.0
-version = 'main (' + torchvision.__version__ + ' )'
+version = "main (" + torchvision.__version__ + " )"
 # The full version, including alpha/beta/rc tags.
 # TODO: verify this works as expected
-release = 'main'
+release = "main"

 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
@@ -100,7 +100,7 @@ language = None
 exclude_patterns = []

 # The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"

 # If true, `todo` and `todoList` produce output, else they produce nothing.
 todo_include_todos = True
@@ -111,7 +111,7 @@ todo_include_todos = True
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'pytorch_sphinx_theme'
+html_theme = "pytorch_sphinx_theme"
 html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]

 # Theme options are theme-specific and customize the look and feel of a theme
@@ -119,30 +119,30 @@ html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
 # documentation.
 #
 html_theme_options = {
-    'collapse_navigation': False,
-    'display_version': True,
-    'logo_only': True,
-    'pytorch_project': 'docs',
-    'navigation_with_keys': True,
-    'analytics_id': 'UA-117752657-2',
+    "collapse_navigation": False,
+    "display_version": True,
+    "logo_only": True,
+    "pytorch_project": "docs",
+    "navigation_with_keys": True,
+    "analytics_id": "UA-117752657-2",
 }

-html_logo = '_static/img/pytorch-logo-dark.svg'
+html_logo = "_static/img/pytorch-logo-dark.svg"

 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]

 # TODO: remove this once https://github.com/pytorch/pytorch_sphinx_theme/issues/125 is fixed
 html_css_files = [
-    'css/custom_torchvision.css',
+    "css/custom_torchvision.css",
 ]

 # -- Options for HTMLHelp output ------------------------------------------

 # Output file base name for HTML help builder.
-htmlhelp_basename = 'PyTorchdoc'
+htmlhelp_basename = "PyTorchdoc"

 # -- Options for LaTeX output ---------------------------------------------
@@ -150,15 +150,12 @@ latex_elements = {
     # The paper size ('letterpaper' or 'a4paper').
     #
     # 'papersize': 'letterpaper',
-
     # The font size ('10pt', '11pt' or '12pt').
     #
     # 'pointsize': '10pt',
-
     # Additional stuff for the LaTeX preamble.
     #
     # 'preamble': '',
-
     # Latex figure (float) alignment
     #
     # 'figure_align': 'htbp',
@@ -169,8 +166,7 @@ latex_elements = {
 #  (source start file, target name, title,
 #   author, documentclass [howto, manual, or own class]).
 latex_documents = [
-    (master_doc, 'pytorch.tex', 'torchvision Documentation',
-     'Torch Contributors', 'manual'),
+    (master_doc, "pytorch.tex", "torchvision Documentation", "Torch Contributors", "manual"),
 ]
@@ -178,10 +174,7 @@ latex_documents = [

 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [
-    (master_doc, 'torchvision', 'torchvision Documentation',
-     [author], 1)
-]
+man_pages = [(master_doc, "torchvision", "torchvision Documentation", [author], 1)]

 # -- Options for Texinfo output -------------------------------------------
@@ -190,27 +183,33 @@ man_pages = [
 #  (source start file, target name, title, author,
 #   dir menu entry, description, category)
 texinfo_documents = [
-    (master_doc, 'torchvision', 'torchvision Documentation',
-     author, 'torchvision', 'One line description of project.',
-     'Miscellaneous'),
+    (
+        master_doc,
+        "torchvision",
+        "torchvision Documentation",
+        author,
+        "torchvision",
+        "One line description of project.",
+        "Miscellaneous",
+    ),
 ]


 # Example configuration for intersphinx: refer to the Python standard library.
 intersphinx_mapping = {
-    'python': ('https://docs.python.org/', None),
-    'torch': ('https://pytorch.org/docs/stable/', None),
-    'numpy': ('http://docs.scipy.org/doc/numpy/', None),
-    'PIL': ('https://pillow.readthedocs.io/en/stable/', None),
-    'matplotlib': ('https://matplotlib.org/stable/', None),
+    "python": ("https://docs.python.org/", None),
+    "torch": ("https://pytorch.org/docs/stable/", None),
+    "numpy": ("http://docs.scipy.org/doc/numpy/", None),
+    "PIL": ("https://pillow.readthedocs.io/en/stable/", None),
+    "matplotlib": ("https://matplotlib.org/stable/", None),
 }

 # -- A patch that prevents Sphinx from cross-referencing ivar tags -------
 # See http://stackoverflow.com/a/41184353/3343043

 from docutils import nodes
-from sphinx.util.docfields import TypedField
 from sphinx import addnodes
+from sphinx.util.docfields import TypedField


 def patched_make_field(self, types, domain, items, **kw):
@@ -220,40 +219,39 @@ def patched_make_field(self, types, domain, items, **kw):
     # type: (list, unicode, tuple) -> nodes.field  # noqa: F821
     def handle_item(fieldarg, content):
         par = nodes.paragraph()
-        par += addnodes.literal_strong('', fieldarg)  # Patch: this line added
+        par += addnodes.literal_strong("", fieldarg)  # Patch: this line added
         # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
         #                            addnodes.literal_strong))
         if fieldarg in types:
-            par += nodes.Text(' (')
+            par += nodes.Text(" (")
             # NOTE: using .pop() here to prevent a single type node to be
             # inserted twice into the doctree, which leads to
             # inconsistencies later when references are resolved
             fieldtype = types.pop(fieldarg)
             if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
-                typename = u''.join(n.astext() for n in fieldtype)
-                typename = typename.replace('int', 'python:int')
-                typename = typename.replace('long', 'python:long')
-                typename = typename.replace('float', 'python:float')
-                typename = typename.replace('type', 'python:type')
-                par.extend(self.make_xrefs(self.typerolename, domain, typename,
-                                           addnodes.literal_emphasis, **kw))
+                typename = "".join(n.astext() for n in fieldtype)
+                typename = typename.replace("int", "python:int")
+                typename = typename.replace("long", "python:long")
+                typename = typename.replace("float", "python:float")
+                typename = typename.replace("type", "python:type")
+                par.extend(self.make_xrefs(self.typerolename, domain, typename, addnodes.literal_emphasis, **kw))
             else:
                 par += fieldtype
-            par += nodes.Text(')')
-        par += nodes.Text(' -- ')
+            par += nodes.Text(")")
+        par += nodes.Text(" -- ")
         par += content
         return par

-    fieldname = nodes.field_name('', self.label)
+    fieldname = nodes.field_name("", self.label)
     if len(items) == 1 and self.can_collapse:
         fieldarg, content = items[0]
         bodynode = handle_item(fieldarg, content)
     else:
         bodynode = self.list_type()
         for fieldarg, content in items:
-            bodynode += nodes.list_item('', handle_item(fieldarg, content))
-    fieldbody = nodes.field_body('', bodynode)
-    return nodes.field('', fieldname, fieldbody)
+            bodynode += nodes.list_item("", handle_item(fieldarg, content))
+    fieldbody = nodes.field_body("", bodynode)
+    return nodes.field("", fieldname, fieldbody)


 TypedField.make_field = patched_make_field
@@ -286,4 +284,4 @@ def inject_minigalleries(app, what, name, obj, options, lines):

 def setup(app):
-    app.connect('autodoc-process-docstring', inject_minigalleries)
+    app.connect("autodoc-process-docstring", inject_minigalleries)

 # Optional list of dependencies required by the package
-dependencies = ['torch']
+dependencies = ["torch"]

 # classification
 from torchvision.models.alexnet import alexnet
 from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
-from torchvision.models.inception import inception_v3
-from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
-    resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
-from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
-from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
+from torchvision.models.efficientnet import (
+    efficientnet_b0,
+    efficientnet_b1,
+    efficientnet_b2,
+    efficientnet_b3,
+    efficientnet_b4,
+    efficientnet_b5,
+    efficientnet_b6,
+    efficientnet_b7,
+)
 from torchvision.models.googlenet import googlenet
-from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
+from torchvision.models.inception import inception_v3
+from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3
 from torchvision.models.mobilenetv2 import mobilenet_v2
 from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
-from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
-    mnasnet1_3
-from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \
-    efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
-from torchvision.models.regnet import regnet_y_400mf, regnet_y_800mf, \
-    regnet_y_1_6gf, regnet_y_3_2gf, regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, \
-    regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, regnet_x_3_2gf, regnet_x_8gf, \
-    regnet_x_16gf, regnet_x_32gf
+from torchvision.models.regnet import (
+    regnet_y_400mf,
+    regnet_y_800mf,
+    regnet_y_1_6gf,
+    regnet_y_3_2gf,
+    regnet_y_8gf,
+    regnet_y_16gf,
+    regnet_y_32gf,
+    regnet_x_400mf,
+    regnet_x_800mf,
+    regnet_x_1_6gf,
+    regnet_x_3_2gf,
+    regnet_x_8gf,
+    regnet_x_16gf,
+    regnet_x_32gf,
+)
+from torchvision.models.resnet import (
+    resnet18,
+    resnet34,
+    resnet50,
+    resnet101,
+    resnet152,
+    resnext50_32x4d,
+    resnext101_32x8d,
+    wide_resnet50_2,
+    wide_resnet101_2,
+)

 # segmentation
-from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
-    deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large
+from torchvision.models.segmentation import (
+    fcn_resnet50,
+    fcn_resnet101,
+    deeplabv3_resnet50,
+    deeplabv3_resnet101,
+    deeplabv3_mobilenet_v3_large,
+    lraspp_mobilenet_v3_large,
+)
+from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
+from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
+from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
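
The hubconf.py churn above is usort at work: imports are grouped by category, alphabetized by module, and multi-name imports are wrapped black-style, one name per line. usort can also be run on its own; a minimal sketch using its standard subcommands (the file path is just an example):

    usort diff hubconf.py      # show the import reordering as a diff
    usort format hubconf.py    # rewrite the file in place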
@@ -5,11 +5,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
 print(torch.__version__)

 model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
-    pretrained=True,
-    box_score_thresh=0.7,
-    rpn_post_nms_top_n_test=100,
-    rpn_score_thresh=0.4,
-    rpn_pre_nms_top_n_test=150)
+    pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150
+)
 model.eval()

 script_model = torch.jit.script(model)
......
...@@ -2,46 +2,63 @@ ...@@ -2,46 +2,63 @@
"""Helper script to package wheels and relocate binaries.""" """Helper script to package wheels and relocate binaries."""
# Standard library imports
import os
import io
import sys
import glob import glob
import shutil
import zipfile
import hashlib import hashlib
import io
# Standard library imports
import os
import os.path as osp
import platform import platform
import shutil
import subprocess import subprocess
import os.path as osp import sys
import zipfile
from base64 import urlsafe_b64encode from base64 import urlsafe_b64encode
# Third party imports # Third party imports
if sys.platform == 'linux': if sys.platform == "linux":
from auditwheel.lddtree import lddtree from auditwheel.lddtree import lddtree
from wheel.bdist_wheel import get_abi_tag from wheel.bdist_wheel import get_abi_tag
ALLOWLIST = { ALLOWLIST = {
'libgcc_s.so.1', 'libstdc++.so.6', 'libm.so.6', "libgcc_s.so.1",
'libdl.so.2', 'librt.so.1', 'libc.so.6', "libstdc++.so.6",
'libnsl.so.1', 'libutil.so.1', 'libpthread.so.0', "libm.so.6",
'libresolv.so.2', 'libX11.so.6', 'libXext.so.6', "libdl.so.2",
'libXrender.so.1', 'libICE.so.6', 'libSM.so.6', "librt.so.1",
'libGL.so.1', 'libgobject-2.0.so.0', 'libgthread-2.0.so.0', "libc.so.6",
'libglib-2.0.so.0', 'ld-linux-x86-64.so.2', 'ld-2.17.so' "libnsl.so.1",
"libutil.so.1",
"libpthread.so.0",
"libresolv.so.2",
"libX11.so.6",
"libXext.so.6",
"libXrender.so.1",
"libICE.so.6",
"libSM.so.6",
"libGL.so.1",
"libgobject-2.0.so.0",
"libgthread-2.0.so.0",
"libglib-2.0.so.0",
"ld-linux-x86-64.so.2",
"ld-2.17.so",
} }
WINDOWS_ALLOWLIST = { WINDOWS_ALLOWLIST = {
'MSVCP140.dll', 'KERNEL32.dll', "MSVCP140.dll",
'VCRUNTIME140_1.dll', 'VCRUNTIME140.dll', "KERNEL32.dll",
'api-ms-win-crt-heap-l1-1-0.dll', "VCRUNTIME140_1.dll",
'api-ms-win-crt-runtime-l1-1-0.dll', "VCRUNTIME140.dll",
'api-ms-win-crt-stdio-l1-1-0.dll', "api-ms-win-crt-heap-l1-1-0.dll",
'api-ms-win-crt-filesystem-l1-1-0.dll', "api-ms-win-crt-runtime-l1-1-0.dll",
'api-ms-win-crt-string-l1-1-0.dll', "api-ms-win-crt-stdio-l1-1-0.dll",
'api-ms-win-crt-environment-l1-1-0.dll', "api-ms-win-crt-filesystem-l1-1-0.dll",
'api-ms-win-crt-math-l1-1-0.dll', "api-ms-win-crt-string-l1-1-0.dll",
'api-ms-win-crt-convert-l1-1-0.dll' "api-ms-win-crt-environment-l1-1-0.dll",
"api-ms-win-crt-math-l1-1-0.dll",
"api-ms-win-crt-convert-l1-1-0.dll",
} }
...@@ -64,20 +81,18 @@ def rehash(path, blocksize=1 << 20): ...@@ -64,20 +81,18 @@ def rehash(path, blocksize=1 << 20):
"""Return (hash, length) for path using hashlib.sha256()""" """Return (hash, length) for path using hashlib.sha256()"""
h = hashlib.sha256() h = hashlib.sha256()
length = 0 length = 0
with open(path, 'rb') as f: with open(path, "rb") as f:
for block in read_chunks(f, size=blocksize): for block in read_chunks(f, size=blocksize):
length += len(block) length += len(block)
h.update(block) h.update(block)
digest = 'sha256=' + urlsafe_b64encode( digest = "sha256=" + urlsafe_b64encode(h.digest()).decode("latin1").rstrip("=")
h.digest()
).decode('latin1').rstrip('=')
# unicode/str python2 issues # unicode/str python2 issues
return (digest, str(length)) # type: ignore return (digest, str(length)) # type: ignore
def unzip_file(file, dest): def unzip_file(file, dest):
"""Decompress zip `file` into directory `dest`.""" """Decompress zip `file` into directory `dest`."""
with zipfile.ZipFile(file, 'r') as zip_ref: with zipfile.ZipFile(file, "r") as zip_ref:
zip_ref.extractall(dest) zip_ref.extractall(dest)
...@@ -88,8 +103,7 @@ def is_program_installed(basename): ...@@ -88,8 +103,7 @@ def is_program_installed(basename):
On macOS systems, a .app is considered installed if On macOS systems, a .app is considered installed if
it exists. it exists.
""" """
if (sys.platform == 'darwin' and basename.endswith('.app') and if sys.platform == "darwin" and basename.endswith(".app") and osp.exists(basename):
osp.exists(basename)):
return basename return basename
for path in os.environ["PATH"].split(os.pathsep): for path in os.environ["PATH"].split(os.pathsep):
...@@ -105,9 +119,9 @@ def find_program(basename): ...@@ -105,9 +119,9 @@ def find_program(basename):
(return None if not found) (return None if not found)
""" """
names = [basename] names = [basename]
if os.name == 'nt': if os.name == "nt":
# Windows platforms # Windows platforms
extensions = ('.exe', '.bat', '.cmd', '.dll') extensions = (".exe", ".bat", ".cmd", ".dll")
if not basename.endswith(extensions): if not basename.endswith(extensions):
names = [basename + ext for ext in extensions] + [basename] names = [basename + ext for ext in extensions] + [basename]
for name in names: for name in names:
...@@ -118,19 +132,18 @@ def find_program(basename): ...@@ -118,19 +132,18 @@ def find_program(basename):
def patch_new_path(library_path, new_dir): def patch_new_path(library_path, new_dir):
library = osp.basename(library_path) library = osp.basename(library_path)
name, *rest = library.split('.') name, *rest = library.split(".")
rest = '.'.join(rest) rest = ".".join(rest)
hash_id = hashlib.sha256(library_path.encode('utf-8')).hexdigest()[:8] hash_id = hashlib.sha256(library_path.encode("utf-8")).hexdigest()[:8]
new_name = '.'.join([name, hash_id, rest]) new_name = ".".join([name, hash_id, rest])
return osp.join(new_dir, new_name) return osp.join(new_dir, new_name)
def find_dll_dependencies(dumpbin, binary):
    out = subprocess.run([dumpbin, "/dependents", binary], stdout=subprocess.PIPE)
    out = out.stdout.strip().decode("utf-8")
    start_index = out.find("dependencies:") + len("dependencies:")
    end_index = out.find("Summary")
    dlls = out[start_index:end_index].strip()
    dlls = dlls.split(os.linesep)
    dlls = [dll.strip() for dll in dlls]
@@ -145,13 +158,13 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
    rename and copy them into the wheel while updating their respective rpaths.
    """
    print("Relocating {0}".format(binary))
    binary_path = osp.join(output_library, binary)
    ld_tree = lddtree(binary_path)
    tree_libs = ld_tree["libs"]

    binary_queue = [(n, binary) for n in ld_tree["needed"]]
    binary_paths = {binary: binary_path}
    binary_dependencies = {}
@@ -160,13 +173,13 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
        library_info = tree_libs[library]
        print(library)

        if library_info["path"] is None:
            print("Omitting {0}".format(library))
            continue

        if library in ALLOWLIST:
            # Omit glibc/gcc/system libraries
            print("Omitting {0}".format(library))
            continue

        parent_dependencies = binary_dependencies.get(parent, [])
@@ -176,11 +189,11 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
        if library in binary_paths:
            continue

        binary_paths[library] = library_info["path"]
        binary_queue += [(n, library) for n in library_info["needed"]]

    print("Copying dependencies to wheel directory")
    new_libraries_path = osp.join(output_dir, "torchvision.libs")
    os.makedirs(new_libraries_path)

    new_names = {binary: binary_path}
@@ -189,11 +202,11 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
        if library != binary:
            library_path = binary_paths[library]
            new_library_path = patch_new_path(library_path, new_libraries_path)
            print("{0} -> {1}".format(library, new_library_path))
            shutil.copyfile(library_path, new_library_path)
            new_names[library] = new_library_path

    print("Updating dependency names by new files")
    for library in binary_paths:
        if library != binary:
            if library not in binary_dependencies:
@@ -202,59 +215,26 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
            new_library_name = new_names[library]
            for dep in library_dependencies:
                new_dep = osp.basename(new_names[dep])
                print("{0}: {1} -> {2}".format(library, dep, new_dep))
                subprocess.check_output(
                    [patchelf, "--replace-needed", dep, new_dep, new_library_name], cwd=new_libraries_path
                )

            print("Updating library rpath")
            subprocess.check_output([patchelf, "--set-rpath", "$ORIGIN", new_library_name], cwd=new_libraries_path)

            subprocess.check_output([patchelf, "--print-rpath", new_library_name], cwd=new_libraries_path)

    print("Update library dependencies")
    library_dependencies = binary_dependencies[binary]
    for dep in library_dependencies:
        new_dep = osp.basename(new_names[dep])
        print("{0}: {1} -> {2}".format(binary, dep, new_dep))
        subprocess.check_output([patchelf, "--replace-needed", dep, new_dep, binary], cwd=output_library)

    print("Update library rpath")
    subprocess.check_output(
        [patchelf, "--set-rpath", "$ORIGIN:$ORIGIN/../torchvision.libs", binary_path], cwd=output_library
    )
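For context, a hedged sketch of the wheel layout this produces (library name and hash are illustrative). The extension module's rpath points at the vendored copies, and each vendored copy resolves its own siblings via $ORIGIN:

# torchvision/
#     image.so                  rpath = $ORIGIN:$ORIGIN/../torchvision.libs
# torchvision.libs/
#     libpng16.1a2b3c4d.so.16   rpath = $ORIGIN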
@@ -265,7 +245,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
    Given a shared library, find the transitive closure of its dependencies,
    rename and copy them into the wheel.
    """
    print("Relocating {0}".format(binary))
    binary_path = osp.join(output_library, binary)
    library_dlls = find_dll_dependencies(dumpbin, binary_path)
@@ -275,19 +255,19 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
    while binary_queue != []:
        library, parent = binary_queue.pop(0)
        if library in WINDOWS_ALLOWLIST or library.startswith("api-ms-win"):
            print("Omitting {0}".format(library))
            continue

        library_path = find_program(library)
        if library_path is None:
            print("{0} not found".format(library))
            continue

        if osp.basename(osp.dirname(library_path)) == "system32":
            continue

        print("{0}: {1}".format(library, library_path))
        parent_dependencies = binary_dependencies.get(parent, [])
        parent_dependencies.append(library)
        binary_dependencies[parent] = parent_dependencies
@@ -299,55 +279,56 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
        downstream_dlls = find_dll_dependencies(dumpbin, library_path)
        binary_queue += [(n, library) for n in downstream_dlls]

    print("Copying dependencies to wheel directory")
    package_dir = osp.join(output_dir, "torchvision")
    for library in binary_paths:
        if library != binary:
            library_path = binary_paths[library]
            new_library_path = osp.join(package_dir, library)
            print("{0} -> {1}".format(library, new_library_path))
            shutil.copyfile(library_path, new_library_path)


def compress_wheel(output_dir, wheel, wheel_dir, wheel_name):
    """Create RECORD file and compress wheel distribution."""
    print("Update RECORD file in wheel")
    dist_info = glob.glob(osp.join(output_dir, "*.dist-info"))[0]
    record_file = osp.join(dist_info, "RECORD")

    with open(record_file, "w") as f:
        for root, _, files in os.walk(output_dir):
            for this_file in files:
                full_file = osp.join(root, this_file)
                rel_file = osp.relpath(full_file, output_dir)
                if full_file == record_file:
                    f.write("{0},,\n".format(rel_file))
                else:
                    digest, size = rehash(full_file)
                    f.write("{0},{1},{2}\n".format(rel_file, digest, size))

    print("Compressing wheel")
    base_wheel_name = osp.join(wheel_dir, wheel_name)
    shutil.make_archive(base_wheel_name, "zip", output_dir)
    os.remove(wheel)
    shutil.move("{0}.zip".format(base_wheel_name), wheel)
    shutil.rmtree(output_dir)
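Each regenerated RECORD entry follows the wheel spec's path,hash,size format; a hedged example (digest, size, and dist-info name are made up):

# torchvision/io/image.py,sha256=Qy2BO3lCqoDXHajzyYkqBi1botmT8GV6qI0Qk9SUmFE,4096
# the RECORD file itself is listed with empty hash and size fields:
# torchvision-<version>.dist-info/RECORD,,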
def patch_linux():
    # Get patchelf location
    patchelf = find_program("patchelf")
    if patchelf is None:
        raise FileNotFoundError(
            "Patchelf was not found in the system, please make sure that it is available on the PATH."
        )

    # Find wheel
    print("Finding wheels...")
    wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl"))
    output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process")

    image_binary = "image.so"
    video_binary = "video_reader.so"
    torchvision_binaries = [image_binary, video_binary]
    for wheel in wheels:
        if osp.exists(output_dir):
@@ -355,37 +336,37 @@ def patch_linux():
            os.makedirs(output_dir)

        print("Unzipping wheel...")
        wheel_file = osp.basename(wheel)
        wheel_dir = osp.dirname(wheel)
        print("{0}".format(wheel_file))
        wheel_name, _ = osp.splitext(wheel_file)
        unzip_file(wheel, output_dir)

        print("Finding ELF dependencies...")
        output_library = osp.join(output_dir, "torchvision")
        for binary in torchvision_binaries:
            if osp.exists(osp.join(output_library, binary)):
                relocate_elf_library(patchelf, output_dir, output_library, binary)

        compress_wheel(output_dir, wheel, wheel_dir, wheel_name)


def patch_win():
    # Get dumpbin location
    dumpbin = find_program("dumpbin")
    if dumpbin is None:
        raise FileNotFoundError(
            "Dumpbin was not found in the system, please make sure that it is available on the PATH."
        )

    # Find wheel
    print("Finding wheels...")
    wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl"))
    output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process")

    image_binary = "image.pyd"
    video_binary = "video_reader.pyd"
    torchvision_binaries = [image_binary, video_binary]
    for wheel in wheels:
        if osp.exists(output_dir):
@@ -393,25 +374,24 @@ def patch_win():
            os.makedirs(output_dir)

        print("Unzipping wheel...")
        wheel_file = osp.basename(wheel)
        wheel_dir = osp.dirname(wheel)
        print("{0}".format(wheel_file))
        wheel_name, _ = osp.splitext(wheel_file)
        unzip_file(wheel, output_dir)

        print("Finding DLL/PE dependencies...")
        output_library = osp.join(output_dir, "torchvision")
        for binary in torchvision_binaries:
            if osp.exists(osp.join(output_library, binary)):
                relocate_dll_library(dumpbin, output_dir, output_library, binary)

        compress_wheel(output_dir, wheel, wheel_dir, wheel_name)


if __name__ == "__main__":
    if sys.platform == "linux":
        patch_linux()
    elif sys.platform == "win32":
        patch_win()
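As a usage note, the script path below is an assumption about where this file lives in the repository:

# after building wheels into dist/, run the relocation pass on the build host:
#   python packaging/wheel/relocate.py
# it patches every *.whl found under dist/ in place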
[tool.usort]
first_party_detection = false

[tool.black]
line-length = 120
target-version = ["py36"]

[tool.ufmt]
excludes = [
    "gallery",
]
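With this configuration in place, contributors can run `ufmt format .` to sort imports with usort and then format with black, or `ufmt check .` to verify without modifying files; a hedged note, since the repository's pre-commit hook wires this up automatically and pins the exact black/usort versions.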
@@ -4,8 +4,15 @@ from torchvision.transforms.functional import InterpolationMode
class ClassificationPresetTrain:
    def __init__(
        self,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        hflip_prob=0.5,
        auto_augment_policy=None,
        random_erase_prob=0.0,
    ):
        trans = [transforms.RandomResizedCrop(crop_size)]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
@@ -17,11 +24,13 @@ class ClassificationPresetTrain:
            else:
                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                trans.append(autoaugment.AutoAugment(policy=aa_policy))
        trans.extend(
            [
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
        if random_erase_prob > 0:
            trans.append(transforms.RandomErasing(p=random_erase_prob))
@@ -32,16 +41,24 @@ class ClassificationPresetTrain:
class ClassificationPresetEval:
    def __init__(
        self,
        crop_size,
        resize_size=256,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
    ):

        self.transforms = transforms.Compose(
            [
                transforms.Resize(resize_size, interpolation=interpolation),
                transforms.CenterCrop(crop_size),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

    def __call__(self, img):
        return self.transforms(img)
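A minimal usage sketch for the eval preset (the input image is a hypothetical PIL image):

from PIL import Image

preset = ClassificationPresetEval(crop_size=224, resize_size=256)
img = Image.open("example.jpg")  # hypothetical input
tensor = preset(img)             # normalized float tensor of shape (3, 224, 224)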
import copy
import datetime
import os
import time

import torch
import torch.quantization
import torch.utils.data
import torchvision
import utils
from torch import nn
from train import train_one_epoch, evaluate, load_data
@@ -20,8 +20,7 @@ def main(args):
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed in distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
@@ -33,17 +32,17 @@ def main(args):
    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
@@ -59,12 +58,10 @@ def main(args):
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
@@ -73,21 +70,19 @@ def main(args):
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
        model.eval()
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qconfig(args.backend)
@@ -97,10 +92,9 @@ def main(args):
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return
@@ -115,107 +109,103 @@ def main(args):
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
        lr_scheduler.step()
        with torch.no_grad():
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subsequent epochs, epoch = ", epoch)
                model.apply(torch.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subsequent epochs, epoch = ", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")
            evaluate(model, criterion, data_loader_test, device=device)
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.quantization.convert(quantized_eval_model, inplace=True)

            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))

        model.train()

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
        print("Saving models after epoch ", epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset")
    parser.add_argument("--model", default="mobilenet_v2", help="model")
    parser.add_argument("--backend", default="qnnpack", help="fbgemm or qnnpack")
    parser.add_argument("--device", default="cuda", help="device")

    parser.add_argument("-b", "--batch-size", default=32, type=int, help="batch size for calibration/training")
    parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "--num-observer-update-epochs",
        default=4,
        type=int,
        metavar="N",
        help="number of total epochs to update observers",
    )
    parser.add_argument(
        "--num-batch-norm-update-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to update batch norm stats",
    )
    parser.add_argument(
        "--num-calibration-batches",
        default=32,
        type=int,
        metavar="N",
        help="number of batches of training set for observer calibration",
    )

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", help="path where to save")
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")

    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
@@ -243,11 +233,8 @@ def get_args_parser(add_help=True):
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")

    return parser
import math
from typing import Tuple

import torch
from torch import Tensor
from torchvision.transforms import functional as F
@@ -19,9 +19,7 @@ class RandomMixup(torch.nn.Module):
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."
@@ -45,7 +43,7 @@ class RandomMixup(torch.nn.Module):
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
@@ -74,12 +72,12 @@ class RandomMixup(torch.nn.Module):
        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_classes={num_classes}"
        s += ", p={p}"
        s += ", alpha={alpha}"
        s += ", inplace={inplace}"
        s += ")"
        return s.format(**self.__dict__)
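A minimal usage sketch for the batch-level mixup transform (shapes and alpha are illustrative):

mixup = RandomMixup(num_classes=1000, p=0.5, alpha=0.2)
batch = torch.rand(8, 3, 224, 224)     # float image batch
target = torch.randint(0, 1000, (8,))  # int64 class labels
batch, target = mixup(batch, target)   # target becomes a soft-label tensor of shape (8, 1000)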
@@ -97,9 +95,7 @@ class RandomCutmix(torch.nn.Module):
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."
@@ -123,7 +119,7 @@ class RandomCutmix(torch.nn.Module):
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
@@ -166,10 +162,10 @@ class RandomCutmix(torch.nn.Module):
        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "num_classes={num_classes}"
        s += ", p={p}"
        s += ", alpha={alpha}"
        s += ", inplace={inplace}"
        s += ")"
        return s.format(**self.__dict__)
import copy
import datetime
import errno
import hashlib
import os
import time
from collections import defaultdict, deque, OrderedDict

import torch
import torch.distributed as dist


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
@@ -34,7 +34,7 @@ class SmoothedValue(object):
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
@@ -65,11 +65,8 @@ class SmoothedValue(object):
    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )
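A usage sketch grounded in how the training scripts below instantiate this meter:

iter_time = SmoothedValue(fmt="{avg:.4f}")  # windowed average of step durations
iter_time.update(0.25)                      # hypothetical per-iteration timing
print(str(iter_time))                       # renders the fmt string, e.g. "0.2500"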
class MetricLogger(object):
@@ -89,15 +86,12 @@ class MetricLogger(object):
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
@@ -110,31 +104,28 @@ class MetricLogger(object):
    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
@@ -144,21 +135,28 @@ class MetricLogger(object):
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("{} Total time: {}".format(header, total_time_str))
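A usage sketch matching how the detection engine below drives this logger:

metric_logger = MetricLogger(delimiter="  ")
for images, targets in metric_logger.log_every(data_loader, print_freq=10, header="Epoch: [0]"):
    ...  # training step; progress, ETA, and meters are printed every 10 iterations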
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
@@ -167,9 +165,9 @@ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA.
    """

    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(model, device, ema_avg)

    def update_parameters(self, model):
@@ -179,8 +177,7 @@ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device)))
        self.n_averaged += 1
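A hedged usage sketch of the EMA wrapper inside a training loop (the decay value is illustrative):

model_ema = ExponentialMovingAverage(model, decay=0.999)
for images, targets in data_loader:
    ...  # forward/backward and optimizer.step() on `model`
    model_ema.update_parameters(model)  # blend current weights into the running average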
@@ -216,10 +213,11 @@ def setup_for_distributed(is_master):
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)
@@ -256,28 +254,28 @@ def save_on_master(*args, **kwargs):
def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    setup_for_distributed(args.rank == 0)
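For context, the RANK/WORLD_SIZE/LOCAL_RANK variables read above are set by the process launcher; a hedged example invocation (flags may vary by PyTorch version):

# python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py ...
# or, on newer PyTorch: torchrun --nproc_per_node=8 train.py ...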
@@ -300,9 +298,7 @@ def average_checkpoints(inputs):
        with open(fpath, "rb") as f:
            state = torch.load(
                f,
                map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
@@ -336,7 +332,7 @@ def average_checkpoints(inputs):
    return new_state


def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
    """
    This method can be used to prepare weights files for new models. It receives as
    input a model architecture and a checkpoint from the training script and produces
@@ -382,7 +378,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
    # Deep copy to avoid side-effects on the model object.
    model = copy.deepcopy(model)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Load the weights to the model to validate that everything works
    # and remove unnecessary weights (such as auxiliaries, etc)
@@ -5,10 +5,9 @@ from contextlib import redirect_stdout
import numpy as np
import pycocotools.mask as mask_util
import torch
import utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval


class CocoEvaluator:
@@ -104,8 +103,7 @@ class CocoEvaluator:
        labels = prediction["labels"].tolist()

        rles = [
            mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
        ]
        for rle in rles:
            rle["counts"] = rle["counts"].decode("utf-8")
@@ -141,7 +139,7 @@ class CocoEvaluator:
                {
                    "image_id": original_id,
                    "category_id": labels[k],
                    "keypoints": keypoint,
                    "score": scores[k],
                }
                for k, keypoint in enumerate(keypoints)
import copy
import os

import torch
import torch.utils.data
import torchvision
import transforms as T
from PIL import Image
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO


class FilterAndRemapCocoCategories(object):
    def __init__(self, categories, remap=True):
@@ -56,7 +54,7 @@ class ConvertCocoPolysToMask(object):
        anno = target["annotations"]

        anno = [obj for obj in anno if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
@@ -147,7 +145,7 @@ def convert_to_coco_api(ds):
    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in range(len(ds)):
        # find better way to get target
@@ -155,41 +153,41 @@ def convert_to_coco_api(ds):
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        img_dict = {}
        img_dict["id"] = image_id
        img_dict["height"] = img.shape[-2]
        img_dict["width"] = img.shape[-1]
        dataset["images"].append(img_dict)
        bboxes = targets["boxes"]
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()
        if "masks" in targets:
            masks = targets["masks"]
            # make masks Fortran contiguous for coco_mask
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if "keypoints" in targets:
            keypoints = targets["keypoints"]
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        num_objs = len(bboxes)
        for i in range(num_objs):
            ann = {}
            ann["image_id"] = image_id
            ann["bbox"] = bboxes[i]
            ann["category_id"] = labels[i]
            categories.add(labels[i])
            ann["area"] = areas[i]
            ann["iscrowd"] = iscrowd[i]
            ann["id"] = ann_id
            if "masks" in targets:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if "keypoints" in targets:
                ann["keypoints"] = keypoints[i]
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1
    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds
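For reference, the in-memory dataset built above has this shape (all values are made up):

# {
#     "images": [{"id": 0, "height": 480, "width": 640}],
#     "annotations": [
#         {"image_id": 0, "bbox": [10, 20, 30, 40], "category_id": 3,
#          "area": 1200.0, "iscrowd": 0, "id": 1}
#     ],
#     "categories": [{"id": 3}],
# }
# note the xywh bbox convention: widths/heights were derived via bboxes[:, 2:] -= bboxes[:, :2]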
@@ -220,7 +218,7 @@ class CocoDetection(torchvision.datasets.CocoDetection):
        return img, target


def get_coco(root, image_set, transforms, mode="instances"):
    anno_file_template = "{}_{}2017.json"
    PATHS = {
        "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
import math
import sys
import time

import torch
import torchvision.models.detection.mask_rcnn
import utils
from coco_eval import CocoEvaluator
from coco_utils import get_coco_api_from_dataset


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
@@ -76,7 +76,7 @@ def evaluate(model, data_loader, device):
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"

    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)