Unverified Commit 790f1cdc authored by toni057, committed by GitHub

Adding _ops and _weight_size metadata checks to tests (#6996)



* Adding _ops and _weight_size metadata checks to tests

* Fixing wrong ops value

* Changing test_schema_meta_validation to instantiate the model only once

* moving instantiating quantized models inside get_ops

* Small refactor of test_schema_meta_validation logic

* Reverting to previous ops value

* Simplifying unquantized models logic in test_schema_meta_validation

* Update test/test_extended_models.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Toni Blaslov <tblaslov@fb.com>
parent 264b83b5
test/common_extended_utils.py

import os
from collections import defaultdict
from numbers import Number
from typing import Any, List

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map

from torchvision.models._api import Weights

aten = torch.ops.aten
quantized = torch.ops.quantized

def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    elif hasattr(i, "weight"):
        return i.weight().shape
    else:
        raise ValueError(f"Unknown type {type(i)}")


def prod(x):
    res = 1
    for i in x:
        res *= i
    return res

def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    flop = prod(input_shapes[0]) * input_shapes[-1][-1]
    return flop
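
# Illustrative check of the formula above (multiplications only): an (m, n) @ (n, p)
# matmul counts m * n * p flops, e.g. (2, 3) @ (3, 4) -> 2 * 3 * 4 = 24.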

def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
    # Count flop for nn.Linear
    # inputs is a list of length 3.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops
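
# Illustrative: a Linear layer with 128 inputs and 64 outputs applied to a batch of 32
# counts 32 * 128 * 64 = 262,144 flops (bias additions are ignored).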

def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [get_shape(v) for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flop = n * c * t * d
    return flop
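
# Illustrative: bmm of a (8, 16, 32) batch with a (8, 32, 64) batch counts
# 8 * 16 * 32 * 64 = 262,144 flops.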

def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Number:
    """
    Count flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = batch_size * prod(w_shape) * prod(x_shape[2:]).

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flop = batch_size * prod(w_shape) * prod(conv_shape)
    return flop
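
# Illustrative: x (1, 3, 224, 224), w (64, 3, 7, 7), out (1, 64, 112, 112) gives
# 1 * (64 * 3 * 7 * 7) * (112 * 112) = 118,013,952 flops (roughly 0.118 GFLOPs).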

def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    transposed = inputs[6]

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def quant_conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for quantized convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=False)


def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])
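
# conv_backward_flop (below) prices the backward pass with the same forward-style count:
# the grad-input term is treated as a convolution of grad_output with the weight, and the
# grad-weight term as a convolution of the channel-transposed input with grad_output;
# output_mask selects which of the two gradients is actually computed.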

def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0

    if output_mask[0]:
        grad_input_shape = get_shape(outputs[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
    if output_mask[1]:
        grad_weight_shape = get_shape(outputs[1])
        flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)

    return flop_count

flop_mapping = {
    aten.mm: matmul_flop,
    aten.matmul: matmul_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    quantized.conv2d: quant_conv_flop,
    quantized.conv2d_relu: quant_conv_flop,
}

unmapped_ops = set()


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x

class FlopCounterMode(TorchDispatchMode):
    def __init__(self, model=None):
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ["Global"]
        if model is not None:
            for name, module in dict(model.named_children()).items():
                module.register_forward_pre_hook(self.enter_module(name))
                module.register_forward_hook(self.exit_module(name))

    def enter_module(self, name):
        def f(module, inputs):
            self.parents.append(name)
            inputs = normalize_tuple(inputs)
            out = self.create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(self, name):
        def f(module, inputs, outputs):
            assert self.parents[-1] == name
            self.parents.pop()
            outputs = normalize_tuple(outputs)
            return self.create_backwards_push(name)(*outputs)

        return f

    def create_backwards_push(self, name):
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(self, name):
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()

    def __exit__(self, *args):
        # print(f"Total: {sum(self.flop_counts['Global'].values()) / 1e9} GFLOPS")
        # for mod in self.flop_counts.keys():
        #     print(f"Module: ", mod)
        #     for k, v in self.flop_counts[mod].items():
        #         print(f"{k}: {v / 1e9} GFLOPS")
        #     print()
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in flop_mapping:
            flop_count = flop_mapping[func_packet](args, normalize_tuple(out))
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count
        else:
            unmapped_ops.add(func_packet)

        return out

    def get_flops(self):
        return sum(self.flop_counts["Global"].values()) / 1e9
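
# Minimal usage sketch for FlopCounterMode (illustrative only; the model and input
# shape below are arbitrary assumptions, not something the test suite runs):
#
#     model = torchvision.models.resnet18().eval()
#     counter = FlopCounterMode(model)
#     with counter:
#         model(torch.randn(1, 3, 224, 224))
#     print(counter.get_flops())  # total flops recorded under "Global", in GFLOPs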

def get_dims(module_name, height, width):
    # detection models have curated input sizes
    if module_name == "detection":
        # detection models consume a list of 3d images rather than a batched tensor,
        # so no batch dimension here; get_ops below wraps the single image in a list
        dims = (3, height, width)
    elif module_name == "video":
        # hard-coding the time dimension to size 16
        dims = (1, 16, 3, height, width)
    else:
        dims = (1, 3, height, width)
    return dims
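
# Illustrative: get_dims("video", 224, 224) -> (1, 16, 3, 224, 224), while a
# classification model gets (1, 3, 224, 224) and a detection model gets (3, 224, 224).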

def get_ops(model: torch.nn.Module, weight: Weights, height=512, width=512):
    module_name = model.__module__.split(".")[-2]
    dims = get_dims(module_name=module_name, height=height, width=width)
    input_tensor = torch.randn(dims)

    preprocess = weight.transforms()
    if module_name == "optical_flow":
        # optical flow models take two images, so preprocess returns a pair
        inp = preprocess(input_tensor, input_tensor)
    else:
        # wrap the single preprocessed tensor in a list so that the same model(*inp)
        # call works for both optical flow and single-input models
        inp = [preprocess(input_tensor)]

    model.eval()

    flop_counter = FlopCounterMode(model)
    with flop_counter:
        # detection models expect a list of 3d tensors as inputs
        if module_name == "detection":
            model(inp)
        else:
            model(*inp)

    flops = flop_counter.get_flops()

    return round(flops, 3)

def get_weight_size_mb(weight):
    weights_path = os.path.join(os.getenv("HOME"), ".cache/torch/hub/checkpoints", weight.url.split("/")[-1])
    weights_size_mb = os.path.getsize(weights_path) / 1024 / 1024
    return round(weights_size_mb, 3)
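
# Minimal sketch of how the helpers above can be exercised by hand (illustrative only;
# resnet18 / ResNet18_Weights is an arbitrary example, and get_weight_size_mb assumes the
# checkpoint is already in ~/.cache/torch/hub/checkpoints, which constructing the model
# below triggers):
if __name__ == "__main__":
    from torchvision.models import resnet18, ResNet18_Weights

    demo_weights = ResNet18_Weights.DEFAULT
    demo_model = resnet18(weights=demo_weights)

    # forward-pass flop estimate in GFLOPs (the input is run through the weights' transforms first)
    print("ops (GFLOPs):", get_ops(model=demo_model, weight=demo_weights))
    # on-disk size of the downloaded checkpoint in MB
    print("weight size (MB):", get_weight_size_mb(demo_weights))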
test/test_extended_models.py

@@ -4,11 +4,11 @@ import os
 import pytest
 import test_models as TM
 import torch
+from common_extended_utils import get_ops, get_weight_size_mb
 from torchvision import models
 from torchvision.models._api import get_model_weights, Weights, WeightsEnum
 from torchvision.models._utils import handle_legacy_interface
 
 run_if_test_with_extended = pytest.mark.skipif(
     os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
     reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
@@ -131,6 +131,22 @@ def test_naming_conventions(model_fn):
     assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
 
 
+detection_models_input_dims = {
+    "fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
+    "fasterrcnn_mobilenet_v3_large_fpn": (800, 800),
+    "fasterrcnn_resnet50_fpn": (800, 800),
+    "fasterrcnn_resnet50_fpn_v2": (800, 800),
+    "fcos_resnet50_fpn": (800, 800),
+    "keypointrcnn_resnet50_fpn": (1333, 1333),
+    "maskrcnn_resnet50_fpn": (800, 800),
+    "maskrcnn_resnet50_fpn_v2": (800, 800),
+    "retinanet_resnet50_fpn": (800, 800),
+    "retinanet_resnet50_fpn_v2": (800, 800),
+    "ssd300_vgg16": (300, 300),
+    "ssdlite320_mobilenet_v3_large": (320, 320),
+}
+
+
 @pytest.mark.parametrize(
     "model_fn",
     TM.list_model_fns(models)
@@ -182,7 +198,7 @@ def test_schema_meta_validation(model_fn):
         pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
 
     problematic_weights = {}
-    incorrect_params = []
+    incorrect_meta = []
     bad_names = []
     for w in weights_enum:
         actual_fields = set(w.meta.keys())
@@ -195,24 +211,45 @@ def test_schema_meta_validation(model_fn):
         unsupported_fields = set(w.meta.keys()) - permitted_fields
         if missing_fields or unsupported_fields:
             problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}
-        if w == weights_enum.DEFAULT:
+
+        if w == weights_enum.DEFAULT or any(w.meta[k] != weights_enum.DEFAULT.meta[k] for k in ["num_params", "_ops"]):
             if module_name == "quantization":
                 # parameters() count doesn't work well with quantization, so we check against the non-quantized
                 unquantized_w = w.meta.get("unquantized")
-                if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
-                    incorrect_params.append(w)
+                if unquantized_w is not None:
+                    if w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
+                        incorrect_meta.append((w, "num_params"))
+
+                    # the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs
+                    # instead
+                    if w.meta["_ops"] != unquantized_w.meta.get("_ops"):
+                        incorrect_meta.append((w, "_ops"))
             else:
-                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
-                    incorrect_params.append(w)
-        else:
-            if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
-                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
-                    incorrect_params.append(w)
+                # loading the model and using it for parameter and ops verification
+                model = model_fn(weights=w)
+
+                if w.meta.get("num_params") != sum(p.numel() for p in model.parameters()):
+                    incorrect_meta.append((w, "num_params"))
+
+                kwargs = {}
+                if model_name in detection_models_input_dims:
+                    # detection models have non default height and width
+                    height, width = detection_models_input_dims[model_name]
+                    kwargs = {"height": height, "width": width}
+
+                calculated_ops = get_ops(model=model, weight=w, **kwargs)
+                if calculated_ops != w.meta["_ops"]:
+                    incorrect_meta.append((w, "_ops"))
+
         if not w.name.isupper():
             bad_names.append(w)
+
+        if get_weight_size_mb(w) != w.meta.get("_weight_size"):
+            incorrect_meta.append((w, "_weight_size"))
+
     assert not problematic_weights
-    assert not incorrect_params
+    assert not incorrect_meta
     assert not bad_names