Unverified Commit 52ba090f authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

Default tracer args (#5637)

* set default tracer kwargs always

* simplify code

* torchvision/models/feature_extraction.py

* Adress PR comments

* fix doc format

* fix formatting

* fix doc error
parent 56fb0bf5
...@@ -184,6 +184,23 @@ def _get_leaf_modules_for_ops() -> List[type]: ...@@ -184,6 +184,23 @@ def _get_leaf_modules_for_ops() -> List[type]:
return result return result
def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
default_autowrap_modules = (math, torchvision.ops)
default_leaf_modules = _get_leaf_modules_for_ops()
result_tracer_kwargs = {} if original_tr_kwargs is None else original_tr_kwargs
result_tracer_kwargs["autowrap_modules"] = (
tuple(set(result_tracer_kwargs["autowrap_modules"] + default_autowrap_modules))
if "autowrap_modules" in result_tracer_kwargs
else default_autowrap_modules
)
result_tracer_kwargs["leaf_modules"] = (
list(set(result_tracer_kwargs["leaf_modules"] + default_leaf_modules))
if "leaf_modules" in result_tracer_kwargs
else default_leaf_modules
)
return result_tracer_kwargs
def get_graph_node_names( def get_graph_node_names(
model: nn.Module, model: nn.Module,
tracer_kwargs: Optional[Dict[str, Any]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None,
...@@ -212,7 +229,11 @@ def get_graph_node_names( ...@@ -212,7 +229,11 @@ def get_graph_node_names(
tracer_kwargs (dict, optional): a dictionary of keywork arguments for tracer_kwargs (dict, optional): a dictionary of keywork arguments for
``NodePathTracer`` (they are eventually passed onto ``NodePathTracer`` (they are eventually passed onto
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_). `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
By default it will be set to wrap and make leaf nodes all torchvision ops. By default it will be set to wrap and make leaf nodes all torchvision ops:
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
provided dictionary.
suppress_diff_warning (bool, optional): whether to suppress a warning suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of when there are discrepancies between the train and eval version of
the graph. Defaults to False. the graph. Defaults to False.
...@@ -226,14 +247,7 @@ def get_graph_node_names( ...@@ -226,14 +247,7 @@ def get_graph_node_names(
>>> model = torchvision.models.resnet18() >>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model) >>> train_nodes, eval_nodes = get_graph_node_names(model)
""" """
if tracer_kwargs is None: tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
tracer_kwargs = {
"autowrap_modules": (
math,
torchvision.ops,
),
"leaf_modules": _get_leaf_modules_for_ops(),
}
is_training = model.training is_training = model.training
train_tracer = NodePathTracer(**tracer_kwargs) train_tracer = NodePathTracer(**tracer_kwargs)
train_tracer.trace(model.train()) train_tracer.trace(model.train())
...@@ -378,7 +392,10 @@ def create_feature_extractor( ...@@ -378,7 +392,10 @@ def create_feature_extractor(
tracer_kwargs (dict, optional): a dictionary of keywork arguments for tracer_kwargs (dict, optional): a dictionary of keywork arguments for
``NodePathTracer`` (which passes them onto it's parent class ``NodePathTracer`` (which passes them onto it's parent class
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_). `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
By default it will be set to wrap and make leaf nodes all torchvision ops. By default it will be set to wrap and make leaf nodes all torchvision ops:
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
provided dictionary.
suppress_diff_warning (bool, optional): whether to suppress a warning suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of when there are discrepancies between the train and eval version of
the graph. Defaults to False. the graph. Defaults to False.
...@@ -423,14 +440,7 @@ def create_feature_extractor( ...@@ -423,14 +440,7 @@ def create_feature_extractor(
>>> 'autowrap_functions': [leaf_function]}) >>> 'autowrap_functions': [leaf_function]})
""" """
if tracer_kwargs is None: tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
tracer_kwargs = {
"autowrap_modules": (
math,
torchvision.ops,
),
"leaf_modules": _get_leaf_modules_for_ops(),
}
is_training = model.training is_training = model.training
if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]): if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment