Commit bf491463 authored by limm's avatar limm
Browse files

add v0.19.1 release

parent e17f5ea2
import asyncio
import sys
from pathlib import Path
from time import perf_counter
from urllib.parse import urlsplit
import aiofiles
import aiohttp
from torchvision import models
from tqdm.asyncio import tqdm
async def main(download_root):
    """Download every torchvision model weight file into *download_root*."""
    download_root.mkdir(parents=True, exist_ok=True)
    # Deduplicate via a set: several models can share the same weight URL.
    urls = set()
    for model_name in models.list_models():
        for weight in iter(models.get_model_weights(model_name)):
            urls.add(weight.url)
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session:
        await tqdm.gather(*(download(download_root, session, url) for url in urls))
async def download(download_root, session, url):
    """Stream the file at *url* into *download_root*, keeping its basename."""
    # "source=ci" marks this as automated traffic for the download server.
    response = await session.get(url, params=dict(source="ci"))
    assert response.ok
    target = download_root / Path(urlsplit(url).path).name
    async with aiofiles.open(target, "wb") as fh:
        async for chunk in response.content.iter_any():
            await fh.write(chunk)
if __name__ == "__main__":
    # Optional CLI argument overrides the default torch hub checkpoint cache.
    download_root = (
        (Path(sys.argv[1]) if len(sys.argv) > 1 else Path("~/.cache/torch/hub/checkpoints")).expanduser().resolve()
    )
    print(f"Downloading model weights to {download_root}")
    start = perf_counter()
    # asyncio.run() replaces the deprecated get_event_loop().run_until_complete()
    # pattern (get_event_loop() in a fresh main thread is deprecated since 3.10).
    asyncio.run(main(download_root))
    stop = perf_counter()
    minutes, seconds = divmod(stop - start, 60)
    print(f"Download took {minutes:2.0f}m {seconds:2.0f}s")
#!/bin/bash
# Cherry-picks every non-[fbsync] commit since <commit_hash> from the fbsync
# branch onto the fork's main branch, pushing one branch per commit for review.
# Usage: ./fbcode_to_main_sync.sh <commit_hash> <fork_name> [<fork_main_branch>]
#
# Fix: all parameter expansions are quoted — the original `[ -z $1 ]` is a
# syntax error when $1 is empty-with-spaces and mis-parses multi-word args.

if [ -z "$1" ]
then
    echo "Commit hash is required to be passed when running this script."
    echo "./fbcode_to_main_sync.sh <commit_hash> <fork_name> <fork_main_branch>"
    exit 1
fi
commit_hash=$1
if [ -z "$2" ]
then
    echo "Fork name is required to be passed when running this script."
    echo "./fbcode_to_main_sync.sh <commit_hash> <fork_name> <fork_main_branch>"
    exit 1
fi
fork_name=$2
if [ -z "$3" ]
then
    fork_main_branch="main"
else
    fork_main_branch=$3
fi
from_branch="fbsync"
git stash
git checkout "$from_branch"
git pull
# Add random prefix in the new branch name to keep it unique per run
prefix=$RANDOM
# Split `git log` output on newlines only, so each iteration sees a full line.
IFS='
'
for line in $(git log --pretty=oneline "$commit_hash"..HEAD)
do
    if [[ $line != *\[fbsync\]* ]]
    then
        echo "Parsing $line"
        hash=$(echo "$line" | cut -f1 -d' ')
        git checkout "$fork_main_branch"
        git checkout -B "cherrypick_${prefix}_${hash}"
        git cherry-pick -x "$hash"
        git push "$fork_name" "cherrypick_${prefix}_${hash}"
        git checkout "$from_branch"
    fi
done
echo "Please review the PRs, add [FBCode->GH] prefix in the title and publish them."
# In[1]:
import pandas as pd

# In[2]:
# Load the PR metadata dumped by retrieve_prs_data.py (one record per PR);
# transpose so PRs become rows.
data_filename = "data.json"
df = pd.read_json(data_filename).T
df.tail()

# In[3]:
# Union of every GitHub label seen across all PRs.
all_labels = {lbl for labels in df["labels"] for lbl in labels}
all_labels

# In[4]:
# Add one column per label
# (boolean: True when the PR carries that label)
for label in all_labels:
    df[label] = df["labels"].apply(lambda labels_list: label in labels_list)
df.head()
# In[5]:
# Add a clean "module" column. It contains tuples since PRs can have more than one module.
# Maybe we should include "topics" in that column as well?
all_modules = {  # mapping: full name -> clean name
    label: "".join(label.split(" ")[1:]) for label in all_labels if label.startswith("module")
}
# We use an ugly loop, but whatever ¯\_(ツ)_/¯
df["module"] = [[] for _ in range(len(df))]
for _, row in df.iterrows():
    row["module"].extend(clean for full, clean in all_modules.items() if full in row["labels"])
# Tuples are hashable, so the column can later be used as an index.
df["module"] = df.module.apply(tuple)
df.head()
# In[6]:
# Re-index by the module tuple so PRs group by area.
mod_df = df.set_index("module").sort_index()
mod_df.tail()

# In[7]:
# All improvement PRs
mod_df[mod_df["enhancement"]].head()

# In[8]:
# improvements of a given module
# note: don't filter module name on the index as the index contain tuples with non-exclusive values
# Use the boolean column instead
mod_df[mod_df["enhancement"] & mod_df["module: transforms"]]
# In[9]:
def format_prs(mod_df, exclude_prototype=True):
    """Render one "[modules] title" line per PR row, newline-joined.

    Rows whose "prototype" flag is truthy are skipped when *exclude_prototype*
    is True. The DataFrame index is expected to hold tuples of module names.
    """
    lines = []
    for modules, row in mod_df.iterrows():
        if exclude_prototype and "prototype" in row and row["prototype"]:
            continue
        ordered = list(modules)
        # Move "documentation" and "tests" to the end of the module list.
        for last_module in ("documentation", "tests"):
            if last_module in ordered:
                ordered = [m for m in ordered if m != last_module] + [last_module]
        prefix = "[{}]".format(", ".join(ordered))
        prefix = prefix.replace("referencescripts", "reference scripts")
        prefix = prefix.replace("code", "reference scripts")
        lines.append(f"{prefix} {row['title']}")
    return "\n".join(lines)
# In[10]:
included_prs = pd.DataFrame()
# If labels are accurate, this should generate most of the release notes already.
# We keep track of the included PRs to figure out which ones are missing.
sections = (
    ("Backward-incompatible changes", "bc-breaking"),
    ("Deprecations", "deprecation"),
    ("New Features", "new feature"),
    ("Improvements", "enhancement"),
    ("Bug Fixes", "bug"),
    ("Code Quality", "code quality"),
)
for section_title, label_col in sections:
    if label_col not in mod_df:
        continue
    print(f"## {section_title}")
    print()
    section_df = mod_df[mod_df[label_col]]
    included_prs = pd.concat([included_prs, section_df])
    print(format_prs(section_df))
    print()

# In[11]:
# Missing PRs are these ones... classify them manually
missing_prs = pd.concat([mod_df, included_prs]).drop_duplicates(subset="pr_number", keep=False)
print(format_prs(missing_prs))
# In[12]:
# Generate list of contributors
# (prints a ready-to-paste shell one-liner; run it in a checkout of the repo)
print()
print("## Contributors")
previous_release = "c35d3855ccbfa6a36e6ae6337a1f2c721c1f1e78"
current_release = "5181a854d8b127cf465cd22a67c1b5aaf6ccae05"
print(
    f"{{ git shortlog -s {previous_release}..{current_release} | cut -f2- & git log -s {previous_release}..{current_release} | grep Co-authored | cut -f2- -d: | cut -f1 -d\\< | sed 's/^ *//;s/ *//' ; }} | sort --ignore-case | uniq | tr '\\n' ';' | sed 's/;/, /g;s/,//' | fold -s"
)

# In[13]:
# Utility to extract PR numbers only from multiple lines, useful to bundle all
# the docs changes for example:
import re

s = """
[] Remove unnecessary dependency from macOS/Conda binaries (#8077)
[rocm] [ROCm] remove HCC references (#8070)
"""
print(", ".join(re.findall("(#\\d+)", s)))
import json
import locale
import os
import re
import subprocess
from collections import namedtuple
from os.path import expanduser
import requests
# Minimal record describing one commit/PR, as cached on disk in data.json.
Features = namedtuple("Features", ["title", "body", "pr_number", "files_changed", "labels"])
def dict_to_features(dct):
    """Rebuild a Features record from its dict form (inverse of features_to_dict)."""
    return Features(**{field: dct[field] for field in Features._fields})


def features_to_dict(features):
    """Convert a Features record into a plain dict, suitable for JSON dumping."""
    return dict(features._asdict())
def run(command):
    """Run *command* through the shell. Returns (return-code, stdout, stderr)."""
    proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    raw_out, raw_err = proc.communicate()
    # Decode with the platform's preferred encoding and strip trailing newlines.
    encoding = locale.getpreferredencoding()
    return proc.returncode, raw_out.decode(encoding).strip(), raw_err.decode(encoding).strip()
def commit_body(commit_hash):
    """Return the commit message body of *commit_hash*, or None if git fails."""
    rc, out, _ = run(f"git log -n 1 --pretty=format:%b {commit_hash}")
    return out if rc == 0 else None


def commit_title(commit_hash):
    """Return the subject line of *commit_hash*, or None if git fails."""
    rc, out, _ = run(f"git log -n 1 --pretty=format:%s {commit_hash}")
    return out if rc == 0 else None


def commit_files_changed(commit_hash):
    """Return the list of paths touched by *commit_hash*, or None if git fails."""
    rc, out, _ = run(f"git diff-tree --no-commit-id --name-only -r {commit_hash}")
    return out.split("\n") if rc == 0 else None
def parse_pr_number(body, commit_hash, title):
    """Extract the PR number (string, without '#') from a commit *title*.

    Returns None when no number is found; reverts and submodule updates are
    expected to lack one, anything else triggers a warning print. When several
    numbers appear, the last one wins.
    """
    matches = re.findall(r"(#[0-9]+)", title)
    if not matches:
        lowered = title.lower()
        if "revert" not in lowered and "updating submodules" not in lowered:
            print(f"[{commit_hash}: {title}] Could not parse PR number, ignoring PR")
        return None
    if len(matches) > 1:
        print(f"[{commit_hash}: {title}] Got two PR numbers, using the last one")
    # matches[-1] equals matches[0] in the single-match case; strip the '#'.
    return matches[-1][1:]
def get_ghstack_token():
    """Read the GitHub OAuth token out of ~/.ghstackrc.

    Raises RuntimeError when no `github_oauth = ...` line is present.
    """
    with open(expanduser("~/.ghstackrc"), "r+") as f:
        contents = f.read()
    found = re.findall("github_oauth = (.*)", contents)
    if not found:
        raise RuntimeError("Can't find a github oauth token")
    return found[0]
# Module-level: fetch the token once at import time; `headers` is reused by
# every run_query() call below. NOTE(review): importing this module therefore
# requires a valid ~/.ghstackrc.
token = get_ghstack_token()
headers = {"Authorization": f"token {token}"}
def run_query(query):
    """POST a GraphQL *query* to the GitHub API and return the decoded JSON.

    Raises on any non-200 HTTP status.
    """
    reply = requests.post("https://api.github.com/graphql", json={"query": query}, headers=headers)
    if reply.status_code != 200:
        raise Exception(f"Query failed to run by returning code of {reply.status_code}. {query}")
    return reply.json()
def gh_labels(pr_number):
    """Return the label names (first 10) attached to pytorch/vision PR *pr_number*."""
    # GraphQL ignores whitespace, so the exact indentation of the query
    # string is irrelevant to the API.
    query = f"""
    {{
      repository(owner: "pytorch", name: "vision") {{
        pullRequest(number: {pr_number}) {{
          labels(first: 10) {{
            edges {{
              node {{
                name
              }}
            }}
          }}
        }}
      }}
    }}
    """
    result = run_query(query)
    edges = result["data"]["repository"]["pullRequest"]["labels"]["edges"]
    return [edge["node"]["name"] for edge in edges]
def get_features(commit_hash, return_dict=False):
    """Collect title, body, changed files, PR number and labels for a commit.

    Returns a Features record, or its dict form when *return_dict* is True.
    """
    title = commit_title(commit_hash)
    body = commit_body(commit_hash)
    files_changed = commit_files_changed(commit_hash)
    pr_number = parse_pr_number(body, commit_hash, title)
    # Labels live on GitHub, so only hit the API when a PR number was found.
    labels = gh_labels(pr_number) if pr_number is not None else []
    result = Features(title, body, pr_number, files_changed, labels)
    return features_to_dict(result) if return_dict else result
class CommitDataCache:
    """Disk-backed cache mapping commit hash -> Features, serialized as JSON."""

    def __init__(self, path="results/data.json"):
        self.path = path
        self.data = {}
        # Warm the cache from disk when a previous run left data behind.
        if os.path.exists(path):
            self.data = self.read_from_disk()

    def get(self, commit):
        """Return cached Features for *commit*, fetching and persisting on a miss."""
        if commit not in self.data:
            self.data[commit] = get_features(commit)
            self.write_to_disk()
        return self.data[commit]

    def read_from_disk(self):
        """Load the JSON cache file and revive each entry as a Features record."""
        with open(self.path) as f:
            raw = json.load(f)
        return {commit: dict_to_features(dct) for commit, dct in raw.items()}

    def write_to_disk(self):
        """Serialize the whole in-memory cache back to the JSON file."""
        serializable = {commit: features._asdict() for commit, features in self.data.items()}
        with open(self.path, "w") as f:
            json.dump(serializable, f)
def get_commits_between(base_version, new_version):
    """Return (hashes, titles) for commits in *new_version* but not *base_version*."""
    rc, merge_base, _ = run(f"git merge-base {base_version} {new_version}")
    assert rc == 0
    # Returns a list of something like
    # b33e38ec47 Allow a higher-precision step type for Vec256::arange (#34555)
    rc, commits, _ = run(f"git log --reverse --oneline {merge_base}..{new_version}")
    assert rc == 0
    pairs = [log_line.split(" ", 1) for log_line in commits.split("\n")]
    hashes, titles = zip(*pairs)
    return hashes, titles
def convert_to_dataframes(feature_list):
    """Build a pandas DataFrame with one row per Features record."""
    # Imported lazily so the module works without pandas installed.
    import pandas as pd

    return pd.DataFrame.from_records(feature_list, columns=Features._fields)
def main(base_version, new_version):
    """Fetch and cache Features for every commit between the two revisions."""
    hashes, _ = get_commits_between(base_version, new_version)
    cdc = CommitDataCache("data.json")
    total = len(hashes)
    for idx, commit in enumerate(hashes):
        # Progress marker every 10 commits; the GitHub label lookup is slow.
        if idx % 10 == 0:
            print(f"{idx} / {total}")
        cdc.get(commit)
    return cdc
if __name__ == "__main__":
    # d = get_features('2ab93592529243862ce8ad5b6acf2628ef8d0dc8')
    # print(d)
    # hashes, titles = get_commits_between("tags/v0.9.0", "fc852f3b39fe25dd8bf1dedee8f19ea04aa84c15")
    # Usage: change the tags below accordingly to the current release, then save the json with
    # cdc.write_to_disk().
    # Then you can use classify_prs.py (as a notebook)
    # to open the json and generate the release notes semi-automatically.
    cdc = main("tags/v0.9.0", "fc852f3b39fe25dd8bf1dedee8f19ea04aa84c15")

    # Drop into an interactive shell so the cache can be inspected manually.
    from IPython import embed

    embed()
...@@ -2,14 +2,21 @@ ...@@ -2,14 +2,21 @@
universal=1 universal=1
[metadata] [metadata]
license_file = LICENSE license_files = LICENSE
[pep8] [pep8]
max-line-length = 120 max-line-length = 120
[flake8] [flake8]
# note: we ignore all 501s (line too long) anyway as they're taken care of by black
max-line-length = 120 max-line-length = 120
ignore = F401,E402,F403,W503,W504,F821 ignore = E203, E402, W503, W504, F821, E501, B, C4, EXE
per-file-ignores =
__init__.py: F401, F403, F405
./hubconf.py: F401
torchvision/models/mobilenet.py: F401, F403
torchvision/models/quantization/mobilenet.py: F401, F403
test/smoke_test.py: F401
exclude = venv exclude = venv
[pydocstyle] [pydocstyle]
......
import os
import io
import sys
from setuptools import setup, find_packages
from pkg_resources import parse_version, get_distribution, DistributionNotFound
import subprocess
import distutils.command.clean import distutils.command.clean
import distutils.spawn import distutils.spawn
import glob import glob
import os
import shutil import shutil
import subprocess
import sys
import torch import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from pkg_resources import DistributionNotFound, get_distribution, parse_version
from torch.utils.hipify import hipify_python from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDA_HOME, CUDAExtension
def read(*names, **kwargs): def read(*names, **kwargs):
with io.open( with open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp:
os.path.join(os.path.dirname(__file__), *names),
encoding=kwargs.get("encoding", "utf8")
) as fp:
return fp.read() return fp.read()
...@@ -31,60 +26,61 @@ def get_dist(pkgname): ...@@ -31,60 +26,61 @@ def get_dist(pkgname):
cwd = os.path.dirname(os.path.abspath(__file__)) cwd = os.path.dirname(os.path.abspath(__file__))
version_txt = os.path.join(cwd, 'version.txt') version_txt = os.path.join(cwd, "version.txt")
with open(version_txt, 'r') as f: with open(version_txt) as f:
version = f.readline().strip() version = f.readline().strip()
sha = 'Unknown' sha = "Unknown"
package_name = 'torchvision' package_name = "torchvision"
try: try:
sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
except Exception: except Exception:
pass pass
if os.getenv('BUILD_VERSION'): if os.getenv("BUILD_VERSION"):
version = os.getenv('BUILD_VERSION') version = os.getenv("BUILD_VERSION")
elif sha != 'Unknown': elif sha != "Unknown":
version += '+' + sha[:7] version += "+" + sha[:7]
def write_version_file(): def write_version_file():
version_path = os.path.join(cwd, 'torchvision', 'version.py') version_path = os.path.join(cwd, "torchvision", "version.py")
with open(version_path, 'w') as f: with open(version_path, "w") as f:
f.write("__version__ = '{}'\n".format(version)) f.write(f"__version__ = '{version}'\n")
f.write("git_version = {}\n".format(repr(sha))) f.write(f"git_version = {repr(sha)}\n")
f.write("from torchvision.extension import _check_cuda_version\n") f.write("from torchvision.extension import _check_cuda_version\n")
f.write("if _check_cuda_version() > 0:\n") f.write("if _check_cuda_version() > 0:\n")
f.write(" cuda = _check_cuda_version()\n") f.write(" cuda = _check_cuda_version()\n")
pytorch_dep = 'torch' pytorch_dep = "torch"
if os.getenv('PYTORCH_VERSION'): if os.getenv("PYTORCH_VERSION"):
pytorch_dep += "==" + os.getenv('PYTORCH_VERSION') pytorch_dep += "==" + os.getenv("PYTORCH_VERSION")
requirements = [ requirements = [
'numpy', "numpy",
pytorch_dep, pytorch_dep,
] ]
pillow_ver = ' >= 5.3.0' # Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934
pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow' pillow_ver = " >= 5.3.0, !=8.3.*"
pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow"
requirements.append(pillow_req + pillow_ver) requirements.append(pillow_req + pillow_ver)
def find_library(name, vision_include): def find_library(name, vision_include):
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
build_prefix = os.environ.get('BUILD_PREFIX', None) build_prefix = os.environ.get("BUILD_PREFIX", None)
is_conda_build = build_prefix is not None is_conda_build = build_prefix is not None
library_found = False library_found = False
conda_installed = False conda_installed = False
lib_folder = None lib_folder = None
include_folder = None include_folder = None
library_header = '{0}.h'.format(name) library_header = f"{name}.h"
# Lookup in TORCHVISION_INCLUDE or in the package file # Lookup in TORCHVISION_INCLUDE or in the package file
package_path = [os.path.join(this_dir, 'torchvision')] package_path = [os.path.join(this_dir, "torchvision")]
for folder in vision_include + package_path: for folder in vision_include + package_path:
candidate_path = os.path.join(folder, library_header) candidate_path = os.path.join(folder, library_header)
library_found = os.path.exists(candidate_path) library_found = os.path.exists(candidate_path)
...@@ -92,64 +88,89 @@ def find_library(name, vision_include): ...@@ -92,64 +88,89 @@ def find_library(name, vision_include):
break break
if not library_found: if not library_found:
print('Running build on conda-build: {0}'.format(is_conda_build)) print(f"Running build on conda-build: {is_conda_build}")
if is_conda_build: if is_conda_build:
# Add conda headers/libraries # Add conda headers/libraries
if os.name == 'nt': if os.name == "nt":
build_prefix = os.path.join(build_prefix, 'Library') build_prefix = os.path.join(build_prefix, "Library")
include_folder = os.path.join(build_prefix, 'include') include_folder = os.path.join(build_prefix, "include")
lib_folder = os.path.join(build_prefix, 'lib') lib_folder = os.path.join(build_prefix, "lib")
library_header_path = os.path.join( library_header_path = os.path.join(include_folder, library_header)
include_folder, library_header)
library_found = os.path.isfile(library_header_path) library_found = os.path.isfile(library_header_path)
conda_installed = library_found conda_installed = library_found
else: else:
# Check if using Anaconda to produce wheels # Check if using Anaconda to produce wheels
conda = distutils.spawn.find_executable('conda') conda = shutil.which("conda")
is_conda = conda is not None is_conda = conda is not None
print('Running build on conda: {0}'.format(is_conda)) print(f"Running build on conda: {is_conda}")
if is_conda: if is_conda:
python_executable = sys.executable python_executable = sys.executable
py_folder = os.path.dirname(python_executable) py_folder = os.path.dirname(python_executable)
if os.name == 'nt': if os.name == "nt":
env_path = os.path.join(py_folder, 'Library') env_path = os.path.join(py_folder, "Library")
else: else:
env_path = os.path.dirname(py_folder) env_path = os.path.dirname(py_folder)
lib_folder = os.path.join(env_path, 'lib') lib_folder = os.path.join(env_path, "lib")
include_folder = os.path.join(env_path, 'include') include_folder = os.path.join(env_path, "include")
library_header_path = os.path.join( library_header_path = os.path.join(include_folder, library_header)
include_folder, library_header)
library_found = os.path.isfile(library_header_path) library_found = os.path.isfile(library_header_path)
conda_installed = library_found conda_installed = library_found
if not library_found: if not library_found:
if sys.platform == 'linux': if sys.platform == "linux":
library_found = os.path.exists('/usr/include/{0}'.format( library_found = os.path.exists(f"/usr/include/{library_header}")
library_header)) library_found = library_found or os.path.exists(f"/usr/local/include/{library_header}")
library_found = library_found or os.path.exists(
'/usr/local/include/{0}'.format(library_header))
return library_found, conda_installed, include_folder, lib_folder return library_found, conda_installed, include_folder, lib_folder
def get_extensions(): def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc') extensions_dir = os.path.join(this_dir, "torchvision", "csrc")
main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + glob.glob(os.path.join(extensions_dir, 'ops', main_file = (
'*.cpp')) glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
)
source_cpu = ( source_cpu = (
glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) + + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
glob.glob(os.path.join(extensions_dir, 'ops', 'quantized', 'cpu', '*.cpp')) + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
) )
source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
print("Compiling extensions with following flags:")
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
print(f" FORCE_CUDA: {force_cuda}")
force_mps = os.getenv("FORCE_MPS", "0") == "1"
print(f" FORCE_MPS: {force_mps}")
debug_mode = os.getenv("DEBUG", "0") == "1"
print(f" DEBUG: {debug_mode}")
use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
print(f" TORCHVISION_USE_PNG: {use_png}")
use_jpeg = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
print(f" TORCHVISION_USE_JPEG: {use_jpeg}")
use_nvjpeg = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
print(f" TORCHVISION_USE_NVJPEG: {use_nvjpeg}")
use_ffmpeg = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1"
print(f" TORCHVISION_USE_FFMPEG: {use_ffmpeg}")
use_video_codec = os.getenv("TORCHVISION_USE_VIDEO_CODEC", "1") == "1"
print(f" TORCHVISION_USE_VIDEO_CODEC: {use_video_codec}")
nvcc_flags = os.getenv("NVCC_FLAGS", "")
print(f" NVCC_FLAGS: {nvcc_flags}")
is_rocm_pytorch = False is_rocm_pytorch = False
if torch.__version__ >= '1.5':
if torch.__version__ >= "1.5":
from torch.utils.cpp_extension import ROCM_HOME from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)
if is_rocm_pytorch: if is_rocm_pytorch:
from torch.utils.hipify import hipify_python
hipify_python.hipify( hipify_python.hipify(
project_directory=this_dir, project_directory=this_dir,
output_directory=this_dir, output_directory=this_dir,
...@@ -157,68 +178,52 @@ def get_extensions(): ...@@ -157,68 +178,52 @@ def get_extensions():
show_detailed=True, show_detailed=True,
is_pytorch_extension=True, is_pytorch_extension=True,
) )
source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'hip', '*.hip')) source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "hip", "*.hip"))
# Copy over additional files # Copy over additional files
for file in glob.glob(r"torchvision/csrc/ops/cuda/*.h"): for file in glob.glob(r"torchvision/csrc/ops/cuda/*.h"):
shutil.copy(file, "torchvision/csrc/ops/hip") shutil.copy(file, "torchvision/csrc/ops/hip")
else: else:
source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'cuda', '*.cu')) source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))
source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', '*.cpp'))
sources = main_file + source_cpu sources = main_file + source_cpu
extension = CppExtension extension = CppExtension
compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1'
if compile_cpp_tests:
test_dir = os.path.join(this_dir, 'test')
models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models')
test_file = glob.glob(os.path.join(test_dir, '*.cpp'))
source_models = glob.glob(os.path.join(models_dir, '*.cpp'))
test_file = [os.path.join(test_dir, s) for s in test_file]
source_models = [os.path.join(models_dir, s) for s in source_models]
tests = test_file + source_models
tests_include_dirs = [test_dir, models_dir]
define_macros = [] define_macros = []
extra_compile_args = {'cxx': []} extra_compile_args = {"cxx": []}
if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) \ if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or force_cuda:
or os.getenv('FORCE_CUDA', '0') == '1':
extension = CUDAExtension extension = CUDAExtension
sources += source_cuda sources += source_cuda
if not is_rocm_pytorch: if not is_rocm_pytorch:
define_macros += [('WITH_CUDA', None)] define_macros += [("WITH_CUDA", None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '') if nvcc_flags == "":
if nvcc_flags == '':
nvcc_flags = [] nvcc_flags = []
else: else:
nvcc_flags = nvcc_flags.split(' ') nvcc_flags = nvcc_flags.split(" ")
else: else:
define_macros += [('WITH_HIP', None)] define_macros += [("WITH_HIP", None)]
nvcc_flags = [] nvcc_flags = []
extra_compile_args["nvcc"] = nvcc_flags extra_compile_args["nvcc"] = nvcc_flags
elif torch.backends.mps.is_available() or force_mps:
sources += source_mps
if sys.platform == 'win32': if sys.platform == "win32":
define_macros += [('torchvision_EXPORTS', None)] define_macros += [("torchvision_EXPORTS", None)]
extra_compile_args["cxx"].append("/MP")
extra_compile_args['cxx'].append('/MP')
debug_mode = os.getenv('DEBUG', '0') == '1'
if debug_mode: if debug_mode:
print("Compile in debug mode") print("Compiling in debug mode")
extra_compile_args['cxx'].append("-g") extra_compile_args["cxx"].append("-g")
extra_compile_args['cxx'].append("-O0") extra_compile_args["cxx"].append("-O0")
if "nvcc" in extra_compile_args: if "nvcc" in extra_compile_args:
# we have to remove "-OX" and "-g" flag if exists and append # we have to remove "-OX" and "-g" flag if exists and append
nvcc_flags = extra_compile_args["nvcc"] nvcc_flags = extra_compile_args["nvcc"]
extra_compile_args["nvcc"] = [ extra_compile_args["nvcc"] = [f for f in nvcc_flags if not ("-O" in f or "-g" in f)]
f for f in nvcc_flags if not ("-O" in f or "-g" in f)
]
extra_compile_args["nvcc"].append("-O0") extra_compile_args["nvcc"].append("-O0")
extra_compile_args["nvcc"].append("-g") extra_compile_args["nvcc"].append("-g")
else:
print("Compiling with debug mode OFF")
extra_compile_args["cxx"].append("-g0")
sources = [os.path.join(extensions_dir, s) for s in sources] sources = [os.path.join(extensions_dir, s) for s in sources]
...@@ -226,31 +231,19 @@ def get_extensions(): ...@@ -226,31 +231,19 @@ def get_extensions():
ext_modules = [ ext_modules = [
extension( extension(
'torchvision._C', "torchvision._C",
sorted(sources), sorted(sources),
include_dirs=include_dirs, include_dirs=include_dirs,
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
) )
] ]
if compile_cpp_tests:
ext_modules.append(
extension(
'torchvision._C_tests',
tests,
include_dirs=tests_include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
)
# ------------------- Torchvision extra extensions ------------------------ # ------------------- Torchvision extra extensions ------------------------
vision_include = os.environ.get('TORCHVISION_INCLUDE', None) vision_include = os.environ.get("TORCHVISION_INCLUDE", None)
vision_library = os.environ.get('TORCHVISION_LIBRARY', None) vision_library = os.environ.get("TORCHVISION_LIBRARY", None)
vision_include = (vision_include.split(os.pathsep) vision_include = vision_include.split(os.pathsep) if vision_include is not None else []
if vision_include is not None else []) vision_library = vision_library.split(os.pathsep) if vision_library is not None else []
vision_library = (vision_library.split(os.pathsep)
if vision_library is not None else [])
include_dirs += vision_include include_dirs += vision_include
library_dirs = vision_library library_dirs = vision_library
...@@ -261,158 +254,181 @@ def get_extensions(): ...@@ -261,158 +254,181 @@ def get_extensions():
image_link_flags = [] image_link_flags = []
# Locating libPNG # Locating libPNG
libpng = distutils.spawn.find_executable('libpng-config') libpng = shutil.which("libpng-config")
pngfix = distutils.spawn.find_executable('pngfix') pngfix = shutil.which("pngfix")
png_found = libpng is not None or pngfix is not None png_found = libpng is not None or pngfix is not None
print('PNG found: {0}'.format(png_found))
if png_found: use_png = use_png and png_found
if use_png:
print("Found PNG library")
if libpng is not None: if libpng is not None:
# Linux / Mac # Linux / Mac
png_version = subprocess.run([libpng, '--version'], min_version = "1.6.0"
stdout=subprocess.PIPE) png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE)
png_version = png_version.stdout.strip().decode('utf-8') png_version = png_version.stdout.strip().decode("utf-8")
print('libpng version: {0}'.format(png_version))
png_version = parse_version(png_version) png_version = parse_version(png_version)
if png_version >= parse_version("1.6.0"): if png_version >= parse_version(min_version):
print('Building torchvision with PNG image support') print("Building torchvision with PNG image support")
png_lib = subprocess.run([libpng, '--libdir'], png_lib = subprocess.run([libpng, "--libdir"], stdout=subprocess.PIPE)
stdout=subprocess.PIPE) png_lib = png_lib.stdout.strip().decode("utf-8")
png_lib = png_lib.stdout.strip().decode('utf-8') if "disabled" not in png_lib:
if 'disabled' not in png_lib:
image_library += [png_lib] image_library += [png_lib]
png_include = subprocess.run([libpng, '--I_opts'], png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE)
stdout=subprocess.PIPE) png_include = png_include.stdout.strip().decode("utf-8")
png_include = png_include.stdout.strip().decode('utf-8') _, png_include = png_include.split("-I")
_, png_include = png_include.split('-I')
print('libpng include path: {0}'.format(png_include))
image_include += [png_include] image_include += [png_include]
image_link_flags.append('png') image_link_flags.append("png")
print(f" libpng version: {png_version}")
print(f" libpng include path: {png_include}")
else: else:
print('libpng installed version is less than 1.6.0, ' print("Could not add PNG image support to torchvision:")
'disabling PNG support') print(f" libpng minimum version {min_version}, found {png_version}")
png_found = False use_png = False
else: else:
# Windows # Windows
png_lib = os.path.join( png_lib = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "lib")
os.path.dirname(os.path.dirname(pngfix)), 'lib') png_include = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "include", "libpng16")
png_include = os.path.join(os.path.dirname(
os.path.dirname(pngfix)), 'include', 'libpng16')
image_library += [png_lib] image_library += [png_lib]
image_include += [png_include] image_include += [png_include]
image_link_flags.append('libpng') image_link_flags.append("libpng")
else:
print("Building torchvision without PNG image support")
image_macros += [("PNG_FOUND", str(int(use_png)))]
# Locating libjpeg # Locating libjpeg
(jpeg_found, jpeg_conda, (jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include)
jpeg_include, jpeg_lib) = find_library('jpeglib', vision_include)
use_jpeg = use_jpeg and jpeg_found
print('JPEG found: {0}'.format(jpeg_found)) if use_jpeg:
image_macros += [('PNG_FOUND', str(int(png_found)))] print("Building torchvision with JPEG image support")
image_macros += [('JPEG_FOUND', str(int(jpeg_found)))] print(f" libjpeg include path: {jpeg_include}")
if jpeg_found: print(f" libjpeg lib path: {jpeg_lib}")
print('Building torchvision with JPEG image support') image_link_flags.append("jpeg")
image_link_flags.append('jpeg')
if jpeg_conda: if jpeg_conda:
image_library += [jpeg_lib] image_library += [jpeg_lib]
image_include += [jpeg_include] image_include += [jpeg_include]
else:
print("Building torchvision without JPEG image support")
image_macros += [("JPEG_FOUND", str(int(use_jpeg)))]
# Locating nvjpeg # Locating nvjpeg
# Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
nvjpeg_found = ( nvjpeg_found = (
extension is CUDAExtension and extension is CUDAExtension
CUDA_HOME is not None and and CUDA_HOME is not None
os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h"))
) )
print('NVJPEG found: {0}'.format(nvjpeg_found)) use_nvjpeg = use_nvjpeg and nvjpeg_found
image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] if use_nvjpeg:
if nvjpeg_found: print("Building torchvision with NVJPEG image support")
print('Building torchvision with NVJPEG image support') image_link_flags.append("nvjpeg")
image_link_flags.append('nvjpeg') else:
print("Building torchvision without NVJPEG image support")
image_macros += [("NVJPEG_FOUND", str(int(use_nvjpeg)))]
image_path = os.path.join(extensions_dir, "io", "image")
image_src = (
glob.glob(os.path.join(image_path, "*.cpp"))
+ glob.glob(os.path.join(image_path, "cpu", "*.cpp"))
+ glob.glob(os.path.join(image_path, "cpu", "giflib", "*.c"))
)
image_path = os.path.join(extensions_dir, 'io', 'image') if is_rocm_pytorch:
image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) image_src += glob.glob(os.path.join(image_path, "hip", "*.cpp"))
+ glob.glob(os.path.join(image_path, 'cuda', '*.cpp'))) # we need to exclude this in favor of the hipified source
image_src.remove(os.path.join(image_path, "image.cpp"))
else:
image_src += glob.glob(os.path.join(image_path, "cuda", "*.cpp"))
if png_found or jpeg_found: ext_modules.append(
ext_modules.append(extension( extension(
'torchvision.image', "torchvision.image",
image_src, image_src,
include_dirs=image_include + include_dirs + [image_path], include_dirs=image_include + include_dirs + [image_path],
library_dirs=image_library + library_dirs, library_dirs=image_library + library_dirs,
define_macros=image_macros, define_macros=image_macros,
libraries=image_link_flags, libraries=image_link_flags,
extra_compile_args=extra_compile_args extra_compile_args=extra_compile_args,
)) )
)
ffmpeg_exe = distutils.spawn.find_executable('ffmpeg') # Locating ffmpeg
ffmpeg_exe = shutil.which("ffmpeg")
has_ffmpeg = ffmpeg_exe is not None has_ffmpeg = ffmpeg_exe is not None
print("FFmpeg found: {}".format(has_ffmpeg)) ffmpeg_version = None
# FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9
# FIXME: causes crash. See the following GitHub issues for more details.
# FIXME: https://github.com/pytorch/pytorch/issues/65000
# FIXME: https://github.com/pytorch/vision/issues/3367
if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9):
has_ffmpeg = False
if has_ffmpeg: if has_ffmpeg:
ffmpeg_libraries = { try:
'libavcodec', # This is to check if ffmpeg is installed properly.
'libavformat', ffmpeg_version = subprocess.check_output(["ffmpeg", "-version"])
'libavutil', except subprocess.CalledProcessError:
'libswresample', print("Building torchvision without ffmpeg support")
'libswscale' print(" Error fetching ffmpeg version, ignoring ffmpeg.")
} has_ffmpeg = False
use_ffmpeg = use_ffmpeg and has_ffmpeg
if use_ffmpeg:
ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"}
ffmpeg_bin = os.path.dirname(ffmpeg_exe) ffmpeg_bin = os.path.dirname(ffmpeg_exe)
ffmpeg_root = os.path.dirname(ffmpeg_bin) ffmpeg_root = os.path.dirname(ffmpeg_bin)
ffmpeg_include_dir = os.path.join(ffmpeg_root, 'include') ffmpeg_include_dir = os.path.join(ffmpeg_root, "include")
ffmpeg_library_dir = os.path.join(ffmpeg_root, 'lib') ffmpeg_library_dir = os.path.join(ffmpeg_root, "lib")
gcc = distutils.spawn.find_executable('gcc') gcc = os.environ.get("CC", shutil.which("gcc"))
platform_tag = subprocess.run( platform_tag = subprocess.run([gcc, "-print-multiarch"], stdout=subprocess.PIPE)
[gcc, '-print-multiarch'], stdout=subprocess.PIPE) platform_tag = platform_tag.stdout.strip().decode("utf-8")
platform_tag = platform_tag.stdout.strip().decode('utf-8')
if platform_tag: if platform_tag:
# Most probably a Debian-based distribution # Most probably a Debian-based distribution
ffmpeg_include_dir = [ ffmpeg_include_dir = [ffmpeg_include_dir, os.path.join(ffmpeg_include_dir, platform_tag)]
ffmpeg_include_dir, ffmpeg_library_dir = [ffmpeg_library_dir, os.path.join(ffmpeg_library_dir, platform_tag)]
os.path.join(ffmpeg_include_dir, platform_tag)
]
ffmpeg_library_dir = [
ffmpeg_library_dir,
os.path.join(ffmpeg_library_dir, platform_tag)
]
else: else:
ffmpeg_include_dir = [ffmpeg_include_dir] ffmpeg_include_dir = [ffmpeg_include_dir]
ffmpeg_library_dir = [ffmpeg_library_dir] ffmpeg_library_dir = [ffmpeg_library_dir]
has_ffmpeg = True
for library in ffmpeg_libraries: for library in ffmpeg_libraries:
library_found = False library_found = False
for search_path in ffmpeg_include_dir + include_dirs: for search_path in ffmpeg_include_dir + include_dirs:
full_path = os.path.join(search_path, library, '*.h') full_path = os.path.join(search_path, library, "*.h")
library_found |= len(glob.glob(full_path)) > 0 library_found |= len(glob.glob(full_path)) > 0
if not library_found: if not library_found:
print(f'{library} header files were not found, disabling ffmpeg support') print("Building torchvision without ffmpeg support")
has_ffmpeg = False print(f" {library} header files were not found, disabling ffmpeg support")
use_ffmpeg = False
else:
print("Building torchvision without ffmpeg support")
if has_ffmpeg: if use_ffmpeg:
print("ffmpeg include path: {}".format(ffmpeg_include_dir)) print("Building torchvision with ffmpeg support")
print("ffmpeg library_dir: {}".format(ffmpeg_library_dir)) print(f" ffmpeg version: {ffmpeg_version}")
print(f" ffmpeg include path: {ffmpeg_include_dir}")
print(f" ffmpeg library_dir: {ffmpeg_library_dir}")
# TorchVision base decoder + video reader # TorchVision base decoder + video reader
video_reader_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'video_reader') video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader")
video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp")) video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))
base_decoder_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'decoder') base_decoder_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "decoder")
base_decoder_src = glob.glob( base_decoder_src = glob.glob(os.path.join(base_decoder_src_dir, "*.cpp"))
os.path.join(base_decoder_src_dir, "*.cpp"))
# Torchvision video API # Torchvision video API
videoapi_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'video') videoapi_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video")
videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp")) videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp"))
# exclude tests # exclude tests
base_decoder_src = [x for x in base_decoder_src if '_test.cpp' not in x] base_decoder_src = [x for x in base_decoder_src if "_test.cpp" not in x]
combined_src = video_reader_src + base_decoder_src + videoapi_src combined_src = video_reader_src + base_decoder_src + videoapi_src
ext_modules.append( ext_modules.append(
CppExtension( CppExtension(
'torchvision.video_reader', "torchvision.video_reader",
combined_src, combined_src,
include_dirs=[ include_dirs=[
base_decoder_src_dir, base_decoder_src_dir,
...@@ -420,29 +436,89 @@ def get_extensions(): ...@@ -420,29 +436,89 @@ def get_extensions():
videoapi_src_dir, videoapi_src_dir,
extensions_dir, extensions_dir,
*ffmpeg_include_dir, *ffmpeg_include_dir,
*include_dirs *include_dirs,
], ],
library_dirs=ffmpeg_library_dir + library_dirs, library_dirs=ffmpeg_library_dir + library_dirs,
libraries=[ libraries=[
'avcodec', "avcodec",
'avformat', "avformat",
'avutil', "avutil",
'swresample', "swresample",
'swscale', "swscale",
],
extra_compile_args=["-std=c++17"] if os.name != "nt" else ["/std:c++17", "/MP"],
extra_link_args=["-std=c++17" if os.name != "nt" else "/std:c++17"],
)
)
# Locating video codec
# CUDA_HOME should be set to the cuda root directory.
# TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to
# video codec header files and libraries respectively.
video_codec_found = (
extension is CUDAExtension
and CUDA_HOME is not None
and any([os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in vision_include])
and any([os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in vision_include])
and any([os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in library_dirs])
)
use_video_codec = use_video_codec and video_codec_found
if (
use_video_codec
and use_ffmpeg
and any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir])
):
print("Building torchvision with video codec support")
gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu")
gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp"))
cuda_libs = os.path.join(CUDA_HOME, "lib64")
cuda_inc = os.path.join(CUDA_HOME, "include")
ext_modules.append(
extension(
"torchvision.Decoder",
gpu_decoder_src,
include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir,
library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs],
libraries=[
"avcodec",
"avformat",
"avutil",
"swresample",
"swscale",
"nvcuvid",
"cuda",
"cudart",
"z",
"pthread",
"dl",
"nppicc",
], ],
extra_compile_args=["-std=c++14"] if os.name != 'nt' else ['/std:c++14', '/MP'], extra_compile_args=extra_compile_args,
extra_link_args=["-std=c++14" if os.name != 'nt' else '/std:c++14'],
) )
) )
else:
print("Building torchvision without video codec support")
if (
use_video_codec
and use_ffmpeg
and not any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir])
):
print(
" The installed version of ffmpeg is missing the header file 'bsf.h' which is "
" required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:"
" `conda install -c conda-forge ffmpeg`."
)
return ext_modules return ext_modules
class clean(distutils.command.clean.clean): class clean(distutils.command.clean.clean):
def run(self): def run(self):
with open('.gitignore', 'r') as f: with open(".gitignore") as f:
ignores = f.read() ignores = f.read()
for wildcard in filter(None, ignores.split('\n')): for wildcard in filter(None, ignores.split("\n")):
for filename in glob.glob(wildcard): for filename in glob.glob(wildcard):
try: try:
os.remove(filename) os.remove(filename)
...@@ -454,37 +530,37 @@ class clean(distutils.command.clean.clean): ...@@ -454,37 +530,37 @@ class clean(distutils.command.clean.clean):
if __name__ == "__main__": if __name__ == "__main__":
print("Building wheel {}-{}".format(package_name, version)) print(f"Building wheel {package_name}-{version}")
write_version_file() write_version_file()
with open('README.rst') as f: with open("README.md") as f:
readme = f.read() readme = f.read()
setup( setup(
# Metadata # Metadata
name=package_name, name=package_name,
version=version, version=version,
author='PyTorch Core Team', author="PyTorch Core Team",
author_email='soumith@pytorch.org', author_email="soumith@pytorch.org",
url='https://github.com/pytorch/vision', url="https://github.com/pytorch/vision",
description='image and video datasets and models for torch deep learning', description="image and video datasets and models for torch deep learning",
long_description=readme, long_description=readme,
license='BSD', long_description_content_type="text/markdown",
license="BSD",
# Package info # Package info
packages=find_packages(exclude=('test',)), packages=find_packages(exclude=("test",)),
package_data={ package_data={package_name: ["*.dll", "*.dylib", "*.so"]},
package_name: ['*.dll', '*.dylib', '*.so']
},
zip_safe=False, zip_safe=False,
install_requires=requirements, install_requires=requirements,
extras_require={ extras_require={
"gdown": ["gdown>=4.7.3"],
"scipy": ["scipy"], "scipy": ["scipy"],
}, },
ext_modules=get_extensions(), ext_modules=get_extensions(),
python_requires=">=3.8",
cmdclass={ cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True), "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
'clean': clean, "clean": clean,
} },
) )
"""This is a temporary module and should be removed as soon as torch.testing.assert_equal is supported."""
# TODO: remove this as soon torch.testing.assert_equal is supported
import functools
import torch.testing
__all__ = ["assert_equal"]
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
test/assets/fakedata/draw_boxes_util.png

560 Bytes | W: | H:

test/assets/fakedata/draw_boxes_util.png

855 Bytes | W: | H:

test/assets/fakedata/draw_boxes_util.png
test/assets/fakedata/draw_boxes_util.png
test/assets/fakedata/draw_boxes_util.png
test/assets/fakedata/draw_boxes_util.png
  • 2-up
  • Swipe
  • Onion skin
import math
import os
from collections import defaultdict
from numbers import Number
from typing import Any, List

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map

from torchvision.models._api import Weights
# Shorthand handles to the operator namespaces used as keys in flop_mapping below.
aten = torch.ops.aten
quantized = torch.ops.quantized
def get_shape(i):
    """Return the shape of a tensor, or of the packed weight of a module-like object."""
    if isinstance(i, torch.Tensor):
        return i.shape
    if hasattr(i, "weight"):
        # Quantized conv/linear parameter objects expose their tensor via a weight() accessor.
        return i.weight().shape
    raise ValueError(f"Unknown type {type(i)}")
def prod(x):
    """Return the product of the elements of *x* (1 for an empty iterable).

    Uses the stdlib math.prod instead of a hand-rolled accumulation loop.
    """
    return math.prod(x)
def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """Count flops for matmul."""
    shapes = [get_shape(v) for v in inputs]
    # Matmul takes exactly two operands whose inner dimensions must agree.
    assert len(shapes) == 2, shapes
    assert shapes[0][-1] == shapes[1][-2], shapes
    # One multiply per element of the left operand, per output column.
    return prod(shapes[0]) * shapes[-1][-1]
def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """Count flops for fully connected (nn.Linear) layers lowered to aten.addmm."""
    # inputs is (bias, mat1, mat2); only the two matrices matter for the count.
    mat1_shape, mat2_shape = [get_shape(v) for v in inputs[1:3]]
    assert len(mat1_shape) == 2, mat1_shape
    assert len(mat2_shape) == 2, mat2_shape
    # mat1: [batch size, input dim]; the output dim is mat2's second axis.
    batch_size, input_dim = mat1_shape
    output_dim = mat2_shape[1]
    return batch_size * input_dim * output_dim
def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """Count flops for the bmm (batched matrix multiply) operation."""
    # Exactly two batched operands.
    assert len(inputs) == 2, len(inputs)
    lhs_shape, rhs_shape = [get_shape(v) for v in inputs]
    n, c, t = lhs_shape
    d = rhs_shape[-1]
    return n * c * t * d
def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Number:
    """Count multiplication flops for a (possibly transposed) convolution.

    Only multiplications are counted; additions and the bias term are ignored.
    For a transposed convolution the spatial extent of the *input* (rather than
    the output) determines the count, i.e.
    flops = prod(x_shape[2:]) * prod(w_shape) * batch_size.

    Args:
        x_shape (list(int)): shape of the input before convolution.
        w_shape (list(int)): shape of the filter.
        out_shape (list(int)): shape of the output after convolution.
        transposed (bool): whether the convolution is transposed.

    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    # Spatial dims come from the input for transposed convs, else from the output.
    spatial_shape = x_shape[2:] if transposed else out_shape[2:]
    return batch_size * prod(w_shape) * prod(spatial_shape)
def conv_flop(inputs: List[Any], outputs: List[Any]):
    """Count flops for convolution (aten.convolution / aten._convolution)."""
    x, w = inputs[:2]
    # Positional argument 6 of aten.convolution is the `transposed` flag.
    return conv_flop_count(get_shape(x), get_shape(w), get_shape(outputs[0]), transposed=inputs[6])
def quant_conv_flop(inputs: List[Any], outputs: List[Any]):
    """Count flops for quantized convolution (never transposed)."""
    x, w = inputs[:2]
    return conv_flop_count(get_shape(x), get_shape(w), get_shape(outputs[0]), transposed=False)
def transpose_shape(shape):
    """Return *shape* as a list with the first two dimensions swapped."""
    swapped = list(shape)
    swapped[0], swapped[1] = swapped[1], swapped[0]
    return swapped
def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    """Count flops for convolution_backward (grad-input and grad-weight passes)."""
    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
    output_mask = inputs[-1]  # which of (grad_input, grad_weight, grad_bias) are requested
    fwd_transposed = inputs[7]
    total = 0
    if output_mask[0]:
        # Grad w.r.t. input: a convolution of grad_out with the weights, with
        # the transposition sense flipped relative to the forward pass.
        total += conv_flop_count(grad_out_shape, w_shape, get_shape(outputs[0]), not fwd_transposed)
    if output_mask[1]:
        # Grad w.r.t. weight: correlates the channel-transposed input with grad_out.
        total += conv_flop_count(transpose_shape(x_shape), grad_out_shape, get_shape(outputs[1]), fwd_transposed)
    return total
def scaled_dot_product_flash_attention_flop(inputs: List[Any], outputs: List[Any]):
    """Placeholder flop count for the fused flash-attention kernel.

    FIXME: this needs to count the flops of this kernel:
    https://github.com/pytorch/pytorch/blob/207b06d099def9d9476176a1842e88636c1f714f/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L52-L267
    """
    return 0
# Maps an aten/quantized operator overload packet to the function that counts its flops.
flop_mapping = {
    aten.mm: matmul_flop,
    aten.matmul: matmul_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    quantized.conv2d: quant_conv_flop,
    quantized.conv2d_relu: quant_conv_flop,
    aten._scaled_dot_product_flash_attention: scaled_dot_product_flash_attention_flop,
}
# Operators seen by the dispatcher that have no flop formula above (for diagnostics).
unmapped_ops = set()
def normalize_tuple(x):
    """Wrap *x* in a one-element tuple unless it already is a tuple."""
    return x if isinstance(x, tuple) else (x,)
class FlopCounterMode(TorchDispatchMode):
    """Dispatch mode that accumulates flop counts per module while a model runs.

    Counts live in ``self.flop_counts``, keyed first by module name (with a
    synthetic "Global" root that collects everything) and then by operator
    overload packet. If a model is supplied, forward hooks are registered on
    its direct children so each op is attributed to the submodule currently
    executing; custom autograd Functions push/pop the same module names at the
    matching points of the backward pass.
    """

    def __init__(self, model=None):
        # parent-module name stack entries -> {op packet: flop count}
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        # Stack of currently-active module names; "Global" is always at the bottom.
        self.parents = ["Global"]
        # Hooks are registered on direct children only, so attribution is at
        # the top-level-submodule granularity.
        if model is not None:
            for name, module in dict(model.named_children()).items():
                module.register_forward_pre_hook(self.enter_module(name))
                module.register_forward_hook(self.exit_module(name))

    def enter_module(self, name):
        """Return a forward pre-hook that pushes *name* onto the parent stack."""

        def f(module, inputs):
            self.parents.append(name)
            inputs = normalize_tuple(inputs)
            # Thread the inputs through an autograd Function so the stack is
            # also popped at the corresponding point of the backward pass.
            out = self.create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(self, name):
        """Return a forward hook that pops *name* off the parent stack."""

        def f(module, inputs, outputs):
            assert self.parents[-1] == name
            self.parents.pop()
            outputs = normalize_tuple(outputs)
            # Mirror image of enter_module: re-push *name* when the backward
            # pass flows back into this module.
            return self.create_backwards_push(name)(*outputs)

        return f

    def create_backwards_push(self, name):
        """Build an identity autograd Function that pushes *name* on backward."""

        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                # Clone tensors so this node is part of the autograd graph.
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(self, name):
        """Build an identity autograd Function that pops *name* on backward."""

        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                # Clone tensors so this node is part of the autograd graph.
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def __enter__(self):
        # Reset counts and enable the dispatch mode.
        # NOTE(review): returns None (not self), so use `with counter:` rather
        # than `with FlopCounterMode(m) as counter:` — the latter binds None.
        self.flop_counts.clear()
        super().__enter__()

    def __exit__(self, *args):
        # print(f"Total: {sum(self.flop_counts['Global'].values()) / 1e9} GFLOPS")
        # for mod in self.flop_counts.keys():
        # print(f"Module: ", mod)
        # for k, v in self.flop_counts[mod].items():
        # print(f"{k}: {v / 1e9} GFLOPS")
        # print()
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Execute *func*, then attribute its flops to every module on the stack."""
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in flop_mapping:
            flop_count = flop_mapping[func_packet](args, normalize_tuple(out))
            # Credit the op to every enclosing module, including "Global".
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count
        else:
            unmapped_ops.add(func_packet)
        return out

    def get_flops(self):
        """Return the total flops recorded under "Global", in GFLOPs (units of 1e9)."""
        return sum(self.flop_counts["Global"].values()) / 1e9
def get_dims(module_name, height, width):
    """Return the input tensor dims appropriate for a model family."""
    if module_name == "detection":
        # Detection models take a list of 3D images, so a single image with no
        # batch dimension stands in for a batch of one.
        return (3, height, width)
    if module_name == "video":
        # Video models get a clip with the time dimension hard-coded to 16 frames.
        return (1, 16, 3, height, width)
    return (1, 3, height, width)
def get_ops(model: torch.nn.Module, weight: Weights, height=512, width=512):
    """Run one inference pass under FlopCounterMode and return GFLOPs, rounded to 3 digits."""
    # The family name is the second-to-last component of the model's module path,
    # e.g. torchvision.models.detection.* -> "detection".
    module_name = model.__module__.split(".")[-2]
    input_tensor = torch.randn(get_dims(module_name=module_name, height=height, width=width))
    preprocess = weight.transforms()
    if module_name == "optical_flow":
        inp = preprocess(input_tensor, input_tensor)
    else:
        # Wrap in a list so mod(*inp) also works alongside the optical-flow case.
        inp = [preprocess(input_tensor)]
    model.eval()
    counter = FlopCounterMode(model)
    with counter:
        if module_name == "detection":
            # Detection models expect a list of 3D tensors as input.
            model(inp)
        else:
            model(*inp)
    return round(counter.get_flops(), 3)
def get_file_size_mb(weight):
    """Return the size in MiB (rounded to 3 decimals) of a weight's cached checkpoint.

    The checkpoint is expected to already exist in the torch hub cache
    (~/.cache/torch/hub/checkpoints), named after the last path component of
    ``weight.url``.
    """
    # os.path.expanduser works even when $HOME is unset (e.g. on Windows),
    # whereas os.getenv("HOME") would return None and crash os.path.join.
    checkpoint_dir = os.path.join(os.path.expanduser("~"), ".cache/torch/hub/checkpoints")
    weights_path = os.path.join(checkpoint_dir, weight.url.split("/")[-1])
    weights_size_mb = os.path.getsize(weights_path) / 1024 / 1024
    return round(weights_size_mb, 3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment