Commit bf491463 authored by limm

add v0.19.1 release

parent e17f5ea2
import asyncio
import sys
from pathlib import Path
from time import perf_counter
from urllib.parse import urlsplit

import aiofiles
import aiohttp
from torchvision import models
from tqdm.asyncio import tqdm


async def main(download_root):
    download_root.mkdir(parents=True, exist_ok=True)
    urls = {weight.url for name in models.list_models() for weight in iter(models.get_model_weights(name))}

    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session:
        await tqdm.gather(*[download(download_root, session, url) for url in urls])


async def download(download_root, session, url):
    response = await session.get(url, params=dict(source="ci"))
    assert response.ok

    file_name = Path(urlsplit(url).path).name
    async with aiofiles.open(download_root / file_name, "wb") as f:
        async for data in response.content.iter_any():
            await f.write(data)


if __name__ == "__main__":
    download_root = (
        (Path(sys.argv[1]) if len(sys.argv) > 1 else Path("~/.cache/torch/hub/checkpoints")).expanduser().resolve()
    )
    print(f"Downloading model weights to {download_root}")

    start = perf_counter()
    asyncio.get_event_loop().run_until_complete(main(download_root))
    stop = perf_counter()

    minutes, seconds = divmod(stop - start, 60)
    print(f"Download took {minutes:2.0f}m {seconds:2.0f}s")
#!/bin/bash

if [ -z "$1" ]
then
    echo "Commit hash is required to be passed when running this script."
    echo "./fbcode_to_main_sync.sh <commit_hash> <fork_name> <fork_main_branch>"
    exit 1
fi
commit_hash=$1

if [ -z "$2" ]
then
    echo "Fork name is required to be passed when running this script."
    echo "./fbcode_to_main_sync.sh <commit_hash> <fork_name> <fork_main_branch>"
    exit 1
fi
fork_name=$2

if [ -z "$3" ]
then
    fork_main_branch="main"
else
    fork_main_branch=$3
fi

from_branch="fbsync"
git stash
git checkout $from_branch
git pull

# Add a random prefix to the new branch names to keep them unique per run
prefix=$RANDOM

IFS='
'
for line in $(git log --pretty=oneline "$commit_hash"..HEAD)
do
    if [[ $line != *\[fbsync\]* ]]
    then
        echo "Parsing $line"
        hash=$(echo "$line" | cut -f1 -d' ')
        git checkout "$fork_main_branch"
        git checkout -B "cherrypick_${prefix}_${hash}"
        git cherry-pick -x "$hash"
        git push "$fork_name" "cherrypick_${prefix}_${hash}"
        git checkout $from_branch
    fi
done

echo "Please review the PRs, add the [FBCode->GH] prefix to their titles, and publish them."
# In[1]:
import pandas as pd

# In[2]:
data_filename = "data.json"
df = pd.read_json(data_filename).T
df.tail()

# In[3]:
all_labels = {lbl for labels in df["labels"] for lbl in labels}
all_labels

# In[4]:
# Add one column per label
for label in all_labels:
    df[label] = df["labels"].apply(lambda labels_list: label in labels_list)
df.head()

# In[5]:
# Add a clean "module" column. It contains tuples since PRs can have more than one module.
# Maybe we should include "topics" in that column as well?
all_modules = {  # mapping: full name -> clean name
    label: "".join(label.split(" ")[1:]) for label in all_labels if label.startswith("module")
}

# We use an ugly loop, but whatever ¯\_(ツ)_/¯
df["module"] = [[] for _ in range(len(df))]
for i, row in df.iterrows():
    for full_name, clean_name in all_modules.items():
        if full_name in row["labels"]:
            row["module"].append(clean_name)
df["module"] = df.module.apply(tuple)
df.head()

# In[6]:
mod_df = df.set_index("module").sort_index()
mod_df.tail()

# In[7]:
# All improvement PRs
mod_df[mod_df["enhancement"]].head()

# In[8]:
# Improvement PRs for a given module.
# Note: don't filter the module name on the index, as the index contains tuples with non-exclusive values.
# Use the boolean column instead.
mod_df[mod_df["enhancement"] & mod_df["module: transforms"]]

# In[9]:
def format_prs(mod_df, exclude_prototype=True):
    out = []
    for idx, row in mod_df.iterrows():
        if exclude_prototype and "prototype" in row and row["prototype"]:
            continue
        modules = idx
        # Put "documentation" and "tests" last for the sorting to be decent
        for last_module in ("documentation", "tests"):
            if last_module in modules:
                modules = [m for m in modules if m != last_module] + [last_module]
        module = f"[{', '.join(modules)}]"
        module = module.replace("referencescripts", "reference scripts")
        module = module.replace("code", "reference scripts")
        out.append(f"{module} {row['title']}")
    return "\n".join(out)

# In[10]:
included_prs = pd.DataFrame()

# If the labels are accurate, this should generate most of the release notes already.
# We keep track of the included PRs to figure out which ones are missing.
for section_title, module_idx in (
    ("Backward-incompatible changes", "bc-breaking"),
    ("Deprecations", "deprecation"),
    ("New Features", "new feature"),
    ("Improvements", "enhancement"),
    ("Bug Fixes", "bug"),
    ("Code Quality", "code quality"),
):
    if module_idx in mod_df:
        print(f"## {section_title}")
        print()
        tmp_df = mod_df[mod_df[module_idx]]
        included_prs = pd.concat([included_prs, tmp_df])
        print(format_prs(tmp_df))
        print()

# In[11]:
# The missing PRs are these ones... classify them manually.
missing_prs = pd.concat([mod_df, included_prs]).drop_duplicates(subset="pr_number", keep=False)
print(format_prs(missing_prs))

# In[12]:
# Generate the list of contributors
print()
print("## Contributors")

previous_release = "c35d3855ccbfa6a36e6ae6337a1f2c721c1f1e78"
current_release = "5181a854d8b127cf465cd22a67c1b5aaf6ccae05"
print(
    f"{{ git shortlog -s {previous_release}..{current_release} | cut -f2- & git log -s {previous_release}..{current_release} | grep Co-authored | cut -f2- -d: | cut -f1 -d\\< | sed 's/^ *//;s/ *//' ; }} | sort --ignore-case | uniq | tr '\\n' ';' | sed 's/;/, /g;s/,//' | fold -s"
)

# In[13]:
# Utility to extract only the PR numbers from multiple lines, useful to bundle all
# the docs changes for example:
import re

s = """
[] Remove unnecessary dependency from macOS/Conda binaries (#8077)
[rocm] [ROCm] remove HCC references (#8070)
"""
print(", ".join(re.findall("(#\\d+)", s)))
import json
import locale
import os
import re
import subprocess
from collections import namedtuple
from os.path import expanduser

import requests

Features = namedtuple(
    "Features",
    [
        "title",
        "body",
        "pr_number",
        "files_changed",
        "labels",
    ],
)


def dict_to_features(dct):
    return Features(
        title=dct["title"],
        body=dct["body"],
        pr_number=dct["pr_number"],
        files_changed=dct["files_changed"],
        labels=dct["labels"],
    )


def features_to_dict(features):
    return dict(features._asdict())


def run(command):
    """Returns (return-code, stdout, stderr)"""
    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    output, err = p.communicate()
    rc = p.returncode
    enc = locale.getpreferredencoding()
    output = output.decode(enc)
    err = err.decode(enc)
    return rc, output.strip(), err.strip()


def commit_body(commit_hash):
    cmd = f"git log -n 1 --pretty=format:%b {commit_hash}"
    ret, out, err = run(cmd)
    return out if ret == 0 else None


def commit_title(commit_hash):
    cmd = f"git log -n 1 --pretty=format:%s {commit_hash}"
    ret, out, err = run(cmd)
    return out if ret == 0 else None


def commit_files_changed(commit_hash):
    cmd = f"git diff-tree --no-commit-id --name-only -r {commit_hash}"
    ret, out, err = run(cmd)
    return out.split("\n") if ret == 0 else None


def parse_pr_number(body, commit_hash, title):
    regex = r"(#[0-9]+)"
    matches = re.findall(regex, title)
    if len(matches) == 0:
        if "revert" not in title.lower() and "updating submodules" not in title.lower():
            print(f"[{commit_hash}: {title}] Could not parse PR number, ignoring PR")
        return None
    if len(matches) > 1:
        print(f"[{commit_hash}: {title}] Got two PR numbers, using the last one")
        return matches[-1][1:]
    return matches[0][1:]


def get_ghstack_token():
    pattern = "github_oauth = (.*)"
    with open(expanduser("~/.ghstackrc"), "r+") as f:
        config = f.read()
    matches = re.findall(pattern, config)
    if len(matches) == 0:
        raise RuntimeError("Can't find a github oauth token")
    return matches[0]


token = get_ghstack_token()
headers = {"Authorization": f"token {token}"}


def run_query(query):
    request = requests.post("https://api.github.com/graphql", json={"query": query}, headers=headers)
    if request.status_code == 200:
        return request.json()
    else:
        raise Exception(f"Query failed to run by returning code of {request.status_code}. {query}")


def gh_labels(pr_number):
    query = f"""
    {{
      repository(owner: "pytorch", name: "vision") {{
        pullRequest(number: {pr_number}) {{
          labels(first: 10) {{
            edges {{
              node {{
                name
              }}
            }}
          }}
        }}
      }}
    }}
    """
    query = run_query(query)
    edges = query["data"]["repository"]["pullRequest"]["labels"]["edges"]
    return [edge["node"]["name"] for edge in edges]


def get_features(commit_hash, return_dict=False):
    title, body, files_changed = (
        commit_title(commit_hash),
        commit_body(commit_hash),
        commit_files_changed(commit_hash),
    )
    pr_number = parse_pr_number(body, commit_hash, title)
    labels = []
    if pr_number is not None:
        labels = gh_labels(pr_number)
    result = Features(title, body, pr_number, files_changed, labels)
    if return_dict:
        return features_to_dict(result)
    return result


class CommitDataCache:
    def __init__(self, path="results/data.json"):
        self.path = path
        self.data = {}
        if os.path.exists(path):
            self.data = self.read_from_disk()

    def get(self, commit):
        if commit not in self.data.keys():
            # Fetch and cache the data
            self.data[commit] = get_features(commit)
            self.write_to_disk()
        return self.data[commit]

    def read_from_disk(self):
        with open(self.path) as f:
            data = json.load(f)
            data = {commit: dict_to_features(dct) for commit, dct in data.items()}
        return data

    def write_to_disk(self):
        data = {commit: features._asdict() for commit, features in self.data.items()}
        with open(self.path, "w") as f:
            json.dump(data, f)


def get_commits_between(base_version, new_version):
    cmd = f"git merge-base {base_version} {new_version}"
    rc, merge_base, _ = run(cmd)
    assert rc == 0

    # Returns a list of lines like
    # b33e38ec47 Allow a higher-precision step type for Vec256::arange (#34555)
    cmd = f"git log --reverse --oneline {merge_base}..{new_version}"
    rc, commits, _ = run(cmd)
    assert rc == 0

    log_lines = commits.split("\n")
    hashes, titles = zip(*[log_line.split(" ", 1) for log_line in log_lines])
    return hashes, titles


def convert_to_dataframes(feature_list):
    import pandas as pd

    df = pd.DataFrame.from_records(feature_list, columns=Features._fields)
    return df


def main(base_version, new_version):
    hashes, titles = get_commits_between(base_version, new_version)

    cdc = CommitDataCache("data.json")
    for idx, commit in enumerate(hashes):
        if idx % 10 == 0:
            print(f"{idx} / {len(hashes)}")
        cdc.get(commit)

    return cdc


if __name__ == "__main__":
    # d = get_features('2ab93592529243862ce8ad5b6acf2628ef8d0dc8')
    # print(d)
    # hashes, titles = get_commits_between("tags/v0.9.0", "fc852f3b39fe25dd8bf1dedee8f19ea04aa84c15")

    # Usage: change the tags below according to the current release, then save the json with
    # cdc.write_to_disk().
    # Then you can use classify_prs.py (as a notebook)
    # to open the json and generate the release notes semi-automatically.
    cdc = main("tags/v0.9.0", "fc852f3b39fe25dd8bf1dedee8f19ea04aa84c15")
    from IPython import embed

    embed()
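# Note: get_ghstack_token() above reads ~/.ghstackrc (the config file used by ghstack)
# and expects a line of the form:
#   github_oauth = <token>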
@@ -2,14 +2,21 @@
universal=1

[metadata]
license_files = LICENSE

[pep8]
max-line-length = 120

[flake8]
# note: we ignore all 501s (line too long) anyway as they're taken care of by black
max-line-length = 120
ignore = E203, E402, W503, W504, F821, E501, B, C4, EXE
per-file-ignores =
    __init__.py: F401, F403, F405
    ./hubconf.py: F401
    torchvision/models/mobilenet.py: F401, F403
    torchvision/models/quantization/mobilenet.py: F401, F403
    test/smoke_test.py: F401
exclude = venv

[pydocstyle]
...
import distutils.command.clean
import distutils.spawn
import glob
import os
import shutil
import subprocess
import sys

import torch
from pkg_resources import DistributionNotFound, get_distribution, parse_version
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDA_HOME, CUDAExtension


def read(*names, **kwargs):
    with open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp:
        return fp.read()
@@ -31,60 +26,61 @@ def get_dist(pkgname):
cwd = os.path.dirname(os.path.abspath(__file__))
version_txt = os.path.join(cwd, "version.txt")
with open(version_txt) as f:
    version = f.readline().strip()
sha = "Unknown"
package_name = "torchvision"

try:
    sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
except Exception:
    pass

if os.getenv("BUILD_VERSION"):
    version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
    version += "+" + sha[:7]


def write_version_file():
    version_path = os.path.join(cwd, "torchvision", "version.py")
    with open(version_path, "w") as f:
        f.write(f"__version__ = '{version}'\n")
        f.write(f"git_version = {repr(sha)}\n")
        f.write("from torchvision.extension import _check_cuda_version\n")
        f.write("if _check_cuda_version() > 0:\n")
        f.write("    cuda = _check_cuda_version()\n")


pytorch_dep = "torch"
if os.getenv("PYTORCH_VERSION"):
    pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")

requirements = [
    "numpy",
    pytorch_dep,
]

# Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934
pillow_ver = " >= 5.3.0, !=8.3.*"
pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow"
requirements.append(pillow_req + pillow_ver)


def find_library(name, vision_include):
    this_dir = os.path.dirname(os.path.abspath(__file__))
    build_prefix = os.environ.get("BUILD_PREFIX", None)
    is_conda_build = build_prefix is not None

    library_found = False
    conda_installed = False
    lib_folder = None
    include_folder = None
    library_header = f"{name}.h"

    # Lookup in TORCHVISION_INCLUDE or in the package file
    package_path = [os.path.join(this_dir, "torchvision")]
    for folder in vision_include + package_path:
        candidate_path = os.path.join(folder, library_header)
        library_found = os.path.exists(candidate_path)
@@ -92,64 +88,89 @@ def find_library(name, vision_include):
            break

    if not library_found:
        print(f"Running build on conda-build: {is_conda_build}")
        if is_conda_build:
            # Add conda headers/libraries
            if os.name == "nt":
                build_prefix = os.path.join(build_prefix, "Library")
            include_folder = os.path.join(build_prefix, "include")
            lib_folder = os.path.join(build_prefix, "lib")
            library_header_path = os.path.join(include_folder, library_header)
            library_found = os.path.isfile(library_header_path)
            conda_installed = library_found
        else:
            # Check if using Anaconda to produce wheels
            conda = shutil.which("conda")
            is_conda = conda is not None
            print(f"Running build on conda: {is_conda}")
            if is_conda:
                python_executable = sys.executable
                py_folder = os.path.dirname(python_executable)
                if os.name == "nt":
                    env_path = os.path.join(py_folder, "Library")
                else:
                    env_path = os.path.dirname(py_folder)
                lib_folder = os.path.join(env_path, "lib")
                include_folder = os.path.join(env_path, "include")
                library_header_path = os.path.join(include_folder, library_header)
                library_found = os.path.isfile(library_header_path)
                conda_installed = library_found

    if not library_found:
        if sys.platform == "linux":
            library_found = os.path.exists(f"/usr/include/{library_header}")
            library_found = library_found or os.path.exists(f"/usr/local/include/{library_header}")

    return library_found, conda_installed, include_folder, lib_folder


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "torchvision", "csrc")

    main_file = (
        glob.glob(os.path.join(extensions_dir, "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
    )
    source_cpu = (
        glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
        + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
    )
    source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))

    print("Compiling extensions with following flags:")
    force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
    print(f"  FORCE_CUDA: {force_cuda}")
    force_mps = os.getenv("FORCE_MPS", "0") == "1"
    print(f"  FORCE_MPS: {force_mps}")
    debug_mode = os.getenv("DEBUG", "0") == "1"
    print(f"  DEBUG: {debug_mode}")
    use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
    print(f"  TORCHVISION_USE_PNG: {use_png}")
    use_jpeg = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
    print(f"  TORCHVISION_USE_JPEG: {use_jpeg}")
    use_nvjpeg = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
    print(f"  TORCHVISION_USE_NVJPEG: {use_nvjpeg}")
    use_ffmpeg = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1"
    print(f"  TORCHVISION_USE_FFMPEG: {use_ffmpeg}")
    use_video_codec = os.getenv("TORCHVISION_USE_VIDEO_CODEC", "1") == "1"
    print(f"  TORCHVISION_USE_VIDEO_CODEC: {use_video_codec}")

    nvcc_flags = os.getenv("NVCC_FLAGS", "")
    print(f"  NVCC_FLAGS: {nvcc_flags}")

    is_rocm_pytorch = False
    if torch.__version__ >= "1.5":
        from torch.utils.cpp_extension import ROCM_HOME

        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    if is_rocm_pytorch:
        from torch.utils.hipify import hipify_python

        hipify_python.hipify(
            project_directory=this_dir,
            output_directory=this_dir,
@@ -157,68 +178,52 @@ def get_extensions():
            show_detailed=True,
            is_pytorch_extension=True,
        )
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "hip", "*.hip"))
        # Copy over additional files
        for file in glob.glob(r"torchvision/csrc/ops/cuda/*.h"):
            shutil.copy(file, "torchvision/csrc/ops/hip")
    else:
        source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))

    sources = main_file + source_cpu
    extension = CppExtension
    define_macros = []

    extra_compile_args = {"cxx": []}
    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or force_cuda:
        extension = CUDAExtension
        sources += source_cuda
        if not is_rocm_pytorch:
            define_macros += [("WITH_CUDA", None)]
            if nvcc_flags == "":
                nvcc_flags = []
            else:
                nvcc_flags = nvcc_flags.split(" ")
        else:
            define_macros += [("WITH_HIP", None)]
            nvcc_flags = []
        extra_compile_args["nvcc"] = nvcc_flags
    elif torch.backends.mps.is_available() or force_mps:
        sources += source_mps

    if sys.platform == "win32":
        define_macros += [("torchvision_EXPORTS", None)]
        extra_compile_args["cxx"].append("/MP")

    if debug_mode:
        print("Compiling in debug mode")
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["cxx"].append("-O0")
        if "nvcc" in extra_compile_args:
            # we have to remove "-OX" and "-g" flag if exists and append
            nvcc_flags = extra_compile_args["nvcc"]
            extra_compile_args["nvcc"] = [f for f in nvcc_flags if not ("-O" in f or "-g" in f)]
            extra_compile_args["nvcc"].append("-O0")
            extra_compile_args["nvcc"].append("-g")
    else:
        print("Compiling with debug mode OFF")
        extra_compile_args["cxx"].append("-g0")

    sources = [os.path.join(extensions_dir, s) for s in sources]
@@ -226,31 +231,19 @@ def get_extensions():
    ext_modules = [
        extension(
            "torchvision._C",
            sorted(sources),
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    # ------------------- Torchvision extra extensions ------------------------
    vision_include = os.environ.get("TORCHVISION_INCLUDE", None)
    vision_library = os.environ.get("TORCHVISION_LIBRARY", None)
    vision_include = vision_include.split(os.pathsep) if vision_include is not None else []
    vision_library = vision_library.split(os.pathsep) if vision_library is not None else []
    include_dirs += vision_include
    library_dirs = vision_library
@@ -261,158 +254,181 @@ def get_extensions():
    image_link_flags = []

    # Locating libPNG
    libpng = shutil.which("libpng-config")
    pngfix = shutil.which("pngfix")
    png_found = libpng is not None or pngfix is not None

    use_png = use_png and png_found
    if use_png:
        print("Found PNG library")
        if libpng is not None:
            # Linux / Mac
            min_version = "1.6.0"
            png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE)
            png_version = png_version.stdout.strip().decode("utf-8")
            png_version = parse_version(png_version)
            if png_version >= parse_version(min_version):
                print("Building torchvision with PNG image support")
                png_lib = subprocess.run([libpng, "--libdir"], stdout=subprocess.PIPE)
                png_lib = png_lib.stdout.strip().decode("utf-8")
                if "disabled" not in png_lib:
                    image_library += [png_lib]
                png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE)
                png_include = png_include.stdout.strip().decode("utf-8")
                _, png_include = png_include.split("-I")
                image_include += [png_include]
                image_link_flags.append("png")
                print(f"  libpng version: {png_version}")
                print(f"  libpng include path: {png_include}")
            else:
                print("Could not add PNG image support to torchvision:")
                print(f"  libpng minimum version {min_version}, found {png_version}")
                use_png = False
        else:
            # Windows
            png_lib = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "lib")
            png_include = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "include", "libpng16")
            image_library += [png_lib]
            image_include += [png_include]
            image_link_flags.append("libpng")
    else:
        print("Building torchvision without PNG image support")
    image_macros += [("PNG_FOUND", str(int(use_png)))]

    # Locating libjpeg
    (jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include)

    use_jpeg = use_jpeg and jpeg_found
    if use_jpeg:
        print("Building torchvision with JPEG image support")
        print(f"  libjpeg include path: {jpeg_include}")
        print(f"  libjpeg lib path: {jpeg_lib}")
        image_link_flags.append("jpeg")
        if jpeg_conda:
            image_library += [jpeg_lib]
            image_include += [jpeg_include]
    else:
        print("Building torchvision without JPEG image support")
    image_macros += [("JPEG_FOUND", str(int(use_jpeg)))]

    # Locating nvjpeg
    # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
    nvjpeg_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h"))
    )

    use_nvjpeg = use_nvjpeg and nvjpeg_found
    if use_nvjpeg:
        print("Building torchvision with NVJPEG image support")
        image_link_flags.append("nvjpeg")
    else:
        print("Building torchvision without NVJPEG image support")
    image_macros += [("NVJPEG_FOUND", str(int(use_nvjpeg)))]

    image_path = os.path.join(extensions_dir, "io", "image")
    image_src = (
        glob.glob(os.path.join(image_path, "*.cpp"))
        + glob.glob(os.path.join(image_path, "cpu", "*.cpp"))
        + glob.glob(os.path.join(image_path, "cpu", "giflib", "*.c"))
    )

    if is_rocm_pytorch:
        image_src += glob.glob(os.path.join(image_path, "hip", "*.cpp"))
        # we need to exclude this in favor of the hipified source
        image_src.remove(os.path.join(image_path, "image.cpp"))
    else:
        image_src += glob.glob(os.path.join(image_path, "cuda", "*.cpp"))

    ext_modules.append(
        extension(
            "torchvision.image",
            image_src,
            include_dirs=image_include + include_dirs + [image_path],
            library_dirs=image_library + library_dirs,
            define_macros=image_macros,
            libraries=image_link_flags,
            extra_compile_args=extra_compile_args,
        )
    )

    # Locating ffmpeg
    ffmpeg_exe = shutil.which("ffmpeg")
    has_ffmpeg = ffmpeg_exe is not None
    ffmpeg_version = None
    # FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9
    # FIXME: causes crash. See the following GitHub issues for more details.
    # FIXME: https://github.com/pytorch/pytorch/issues/65000
    # FIXME: https://github.com/pytorch/vision/issues/3367
    if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9):
        has_ffmpeg = False
    if has_ffmpeg:
        try:
            # This is to check if ffmpeg is installed properly.
            ffmpeg_version = subprocess.check_output(["ffmpeg", "-version"])
        except subprocess.CalledProcessError:
            print("Building torchvision without ffmpeg support")
            print("  Error fetching ffmpeg version, ignoring ffmpeg.")
            has_ffmpeg = False

    use_ffmpeg = use_ffmpeg and has_ffmpeg
    if use_ffmpeg:
        ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"}

        ffmpeg_bin = os.path.dirname(ffmpeg_exe)
        ffmpeg_root = os.path.dirname(ffmpeg_bin)
        ffmpeg_include_dir = os.path.join(ffmpeg_root, "include")
        ffmpeg_library_dir = os.path.join(ffmpeg_root, "lib")

        gcc = os.environ.get("CC", shutil.which("gcc"))
        platform_tag = subprocess.run([gcc, "-print-multiarch"], stdout=subprocess.PIPE)
        platform_tag = platform_tag.stdout.strip().decode("utf-8")

        if platform_tag:
            # Most probably a Debian-based distribution
            ffmpeg_include_dir = [ffmpeg_include_dir, os.path.join(ffmpeg_include_dir, platform_tag)]
            ffmpeg_library_dir = [ffmpeg_library_dir, os.path.join(ffmpeg_library_dir, platform_tag)]
        else:
            ffmpeg_include_dir = [ffmpeg_include_dir]
            ffmpeg_library_dir = [ffmpeg_library_dir]

        for library in ffmpeg_libraries:
            library_found = False
            for search_path in ffmpeg_include_dir + include_dirs:
                full_path = os.path.join(search_path, library, "*.h")
                library_found |= len(glob.glob(full_path)) > 0

            if not library_found:
                print("Building torchvision without ffmpeg support")
                print(f"  {library} header files were not found, disabling ffmpeg support")
                use_ffmpeg = False
    else:
        print("Building torchvision without ffmpeg support")

    if use_ffmpeg:
        print("Building torchvision with ffmpeg support")
        print(f"  ffmpeg version: {ffmpeg_version}")
        print(f"  ffmpeg include path: {ffmpeg_include_dir}")
        print(f"  ffmpeg library_dir: {ffmpeg_library_dir}")

        # TorchVision base decoder + video reader
        video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader")
        video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))
        base_decoder_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "decoder")
        base_decoder_src = glob.glob(os.path.join(base_decoder_src_dir, "*.cpp"))
        # Torchvision video API
        videoapi_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video")
        videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp"))
        # exclude tests
        base_decoder_src = [x for x in base_decoder_src if "_test.cpp" not in x]

        combined_src = video_reader_src + base_decoder_src + videoapi_src

        ext_modules.append(
            CppExtension(
                "torchvision.video_reader",
                combined_src,
                include_dirs=[
                    base_decoder_src_dir,
@@ -420,29 +436,89 @@ def get_extensions():
                    videoapi_src_dir,
                    extensions_dir,
                    *ffmpeg_include_dir,
                    *include_dirs,
                ],
                library_dirs=ffmpeg_library_dir + library_dirs,
                libraries=[
                    "avcodec",
                    "avformat",
                    "avutil",
                    "swresample",
                    "swscale",
                ],
                extra_compile_args=["-std=c++17"] if os.name != "nt" else ["/std:c++17", "/MP"],
                extra_link_args=["-std=c++17" if os.name != "nt" else "/std:c++17"],
            )
        )

    # Locating video codec
    # CUDA_HOME should be set to the cuda root directory.
    # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to
    # video codec header files and libraries respectively.
    video_codec_found = (
        extension is CUDAExtension
        and CUDA_HOME is not None
        and any([os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in vision_include])
        and any([os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in vision_include])
        and any([os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in library_dirs])
    )

    use_video_codec = use_video_codec and video_codec_found
    if (
        use_video_codec
        and use_ffmpeg
        and any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir])
    ):
        print("Building torchvision with video codec support")
        gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu")
        gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp"))
        cuda_libs = os.path.join(CUDA_HOME, "lib64")
        cuda_inc = os.path.join(CUDA_HOME, "include")

        ext_modules.append(
            extension(
                "torchvision.Decoder",
                gpu_decoder_src,
                include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir,
                library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs],
                libraries=[
                    "avcodec",
                    "avformat",
                    "avutil",
                    "swresample",
                    "swscale",
                    "nvcuvid",
                    "cuda",
                    "cudart",
                    "z",
                    "pthread",
                    "dl",
                    "nppicc",
                ],
                extra_compile_args=extra_compile_args,
            )
        )
    else:
        print("Building torchvision without video codec support")
        if (
            use_video_codec
            and use_ffmpeg
            and not any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir])
        ):
            print(
                "  The installed version of ffmpeg is missing the header file 'bsf.h' which is "
                "required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:"
                " `conda install -c conda-forge ffmpeg`."
            )

    return ext_modules


class clean(distutils.command.clean.clean):
    def run(self):
        with open(".gitignore") as f:
            ignores = f.read()
            for wildcard in filter(None, ignores.split("\n")):
                for filename in glob.glob(wildcard):
                    try:
                        os.remove(filename)
@@ -454,37 +530,37 @@ class clean(distutils.command.clean.clean):
if __name__ == "__main__":
    print(f"Building wheel {package_name}-{version}")

    write_version_file()

    with open("README.md") as f:
        readme = f.read()

    setup(
        # Metadata
        name=package_name,
        version=version,
        author="PyTorch Core Team",
        author_email="soumith@pytorch.org",
        url="https://github.com/pytorch/vision",
        description="image and video datasets and models for torch deep learning",
        long_description=readme,
        long_description_content_type="text/markdown",
        license="BSD",
        # Package info
        packages=find_packages(exclude=("test",)),
        package_data={package_name: ["*.dll", "*.dylib", "*.so"]},
        zip_safe=False,
        install_requires=requirements,
        extras_require={
            "gdown": ["gdown>=4.7.3"],
            "scipy": ["scipy"],
        },
        ext_modules=get_extensions(),
        python_requires=">=3.8",
        cmdclass={
            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
            "clean": clean,
        },
    )
"""This is a temporary module and should be removed as soon as torch.testing.assert_equal is supported."""
# TODO: remove this as soon torch.testing.assert_equal is supported
import functools
import torch.testing
__all__ = ["assert_equal"]
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
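# A minimal usage sketch (hedged: the importing path is an assumption, and the
# tensors are made-up examples):
#
#     import torch
#     from _assert_utils import assert_equal  # wherever this module lives in the test suite
#
#     assert_equal(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.0]))  # passes: exact match
#     assert_equal(torch.ones(2), torch.ones(2) + 1e-9)  # raises: rtol=atol=0 requires exact equality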
test/assets/fakedata/draw_boxes_util.png (binary image changed: 560 bytes -> 855 bytes)
import os
from collections import defaultdict
from numbers import Number
from typing import Any, List

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torchvision.models._api import Weights

aten = torch.ops.aten
quantized = torch.ops.quantized


def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    elif hasattr(i, "weight"):
        return i.weight().shape
    else:
        raise ValueError(f"Unknown type {type(i)}")


def prod(x):
    res = 1
    for i in x:
        res *= i
    return res
def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    flop = prod(input_shapes[0]) * input_shapes[-1][-1]
    return flop


def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
    # Count flop for nn.Linear
    # inputs is a list of length 3.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops
def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [get_shape(v) for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flop = n * c * t * d
    return flop


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Number:
    """
    Count flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = (x_shape[2:] * prod(w_shape) * batch_size).
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flop = batch_size * prod(w_shape) * prod(conv_shape)
    return flop
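# A quick sanity check of the formula above, with hypothetical shapes (not from the
# original file): a 7x7/stride-2 conv mapping (1, 3, 224, 224) to (1, 64, 112, 112)
# counts 1 * prod((64, 3, 7, 7)) * prod((112, 112)) = 9408 * 12544 = 118,013,952
# multiply-accumulates:
#
#     assert conv_flop_count([1, 3, 224, 224], [64, 3, 7, 7], [1, 64, 112, 112]) == 118_013_952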
def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    transposed = inputs[6]

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def quant_conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for quantized convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=False)


def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])


def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0

    if output_mask[0]:
        grad_input_shape = get_shape(outputs[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
    if output_mask[1]:
        grad_weight_shape = get_shape(outputs[1])
        flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)

    return flop_count


def scaled_dot_product_flash_attention_flop(inputs: List[Any], outputs: List[Any]):
    # FIXME: this needs to count the flops of this kernel
    # https://github.com/pytorch/pytorch/blob/207b06d099def9d9476176a1842e88636c1f714f/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L52-L267
    return 0
flop_mapping = {
    aten.mm: matmul_flop,
    aten.matmul: matmul_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    quantized.conv2d: quant_conv_flop,
    quantized.conv2d_relu: quant_conv_flop,
    aten._scaled_dot_product_flash_attention: scaled_dot_product_flash_attention_flop,
}

unmapped_ops = set()


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


class FlopCounterMode(TorchDispatchMode):
    def __init__(self, model=None):
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ["Global"]
        # global mod
        if model is not None:
            for name, module in dict(model.named_children()).items():
                module.register_forward_pre_hook(self.enter_module(name))
                module.register_forward_hook(self.exit_module(name))

    def enter_module(self, name):
        def f(module, inputs):
            self.parents.append(name)
            inputs = normalize_tuple(inputs)
            out = self.create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(self, name):
        def f(module, inputs, outputs):
            assert self.parents[-1] == name
            self.parents.pop()
            outputs = normalize_tuple(outputs)
            return self.create_backwards_push(name)(*outputs)

        return f

    def create_backwards_push(self, name):
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(self, name):
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()

    def __exit__(self, *args):
        # print(f"Total: {sum(self.flop_counts['Global'].values()) / 1e9} GFLOPS")
        # for mod in self.flop_counts.keys():
        #     print(f"Module: ", mod)
        #     for k, v in self.flop_counts[mod].items():
        #         print(f"{k}: {v / 1e9} GFLOPS")
        #     print()
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in flop_mapping:
            flop_count = flop_mapping[func_packet](args, normalize_tuple(out))
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count
        else:
            unmapped_ops.add(func_packet)

        return out

    def get_flops(self):
        return sum(self.flop_counts["Global"].values()) / 1e9
def get_dims(module_name, height, width):
    # detection models have curated input sizes
    if module_name == "detection":
        # we can feed a batch of 1 for detection model instead of a list of 1 image
        dims = (3, height, width)
    elif module_name == "video":
        # hard-coding the time dimension to size 16
        dims = (1, 16, 3, height, width)
    else:
        dims = (1, 3, height, width)

    return dims


def get_ops(model: torch.nn.Module, weight: Weights, height=512, width=512):
    module_name = model.__module__.split(".")[-2]
    dims = get_dims(module_name=module_name, height=height, width=width)

    input_tensor = torch.randn(dims)

    # try:
    preprocess = weight.transforms()
    if module_name == "optical_flow":
        inp = preprocess(input_tensor, input_tensor)
    else:
        # hack to enable mod(*inp) for optical_flow models
        inp = [preprocess(input_tensor)]

    model.eval()

    flop_counter = FlopCounterMode(model)
    with flop_counter:
        # detection models expect a list of 3d tensors as inputs
        if module_name == "detection":
            model(inp)
        else:
            model(*inp)

    flops = flop_counter.get_flops()

    return round(flops, 3)


def get_file_size_mb(weight):
    weights_path = os.path.join(os.getenv("HOME"), ".cache/torch/hub/checkpoints", weight.url.split("/")[-1])
    weights_size_mb = os.path.getsize(weights_path) / 1024 / 1024

    return round(weights_size_mb, 3)
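For illustration, a minimal, hedged usage sketch of the utilities above (the specific weight enum is an arbitrary example, not part of the original file; get_file_size_mb assumes the checkpoint is already in the torch hub cache, e.g. via the download script at the top of this commit):

if __name__ == "__main__":
    from torchvision import models

    weights = models.ResNet50_Weights.IMAGENET1K_V2
    model = models.resnet50(weights=weights)
    print(f"GFLOPs: {get_ops(model, weights)}")  # flops counted via FlopCounterMode, in GFLOPs
    print(f"Size (MB): {get_file_size_mb(weights)}")  # file size of the cached checkpoint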