Unverified Commit ff31b6e3 authored by Ross Wightman's avatar Ross Wightman Committed by GitHub
Browse files

Add concrete_args to feature extraction tracing. (#8393)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 305330f0
...@@ -204,6 +204,7 @@ def get_graph_node_names( ...@@ -204,6 +204,7 @@ def get_graph_node_names(
model: nn.Module, model: nn.Module,
tracer_kwargs: Optional[Dict[str, Any]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None,
suppress_diff_warning: bool = False, suppress_diff_warning: bool = False,
concrete_args: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]: ) -> Tuple[List[str], List[str]]:
""" """
Dev utility to return node names in order of execution. See note on node Dev utility to return node names in order of execution. See note on node
...@@ -232,10 +233,13 @@ def get_graph_node_names( ...@@ -232,10 +233,13 @@ def get_graph_node_names(
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),} {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
provided dictionary. provided dictionary.
suppress_diff_warning (bool, optional): whether to suppress a warning suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of when there are discrepancies between the train and eval version of
the graph. Defaults to False. the graph. Defaults to False.
concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
not be treated as Proxies. According to the `Pytorch docs
<https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.trace>`_,
this parameter's API may not be guaranteed.
Returns: Returns:
tuple(list, list): a list of node names from tracing the model in tuple(list, list): a list of node names from tracing the model in
...@@ -249,9 +253,9 @@ def get_graph_node_names( ...@@ -249,9 +253,9 @@ def get_graph_node_names(
tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs) tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
is_training = model.training is_training = model.training
train_tracer = NodePathTracer(**tracer_kwargs) train_tracer = NodePathTracer(**tracer_kwargs)
train_tracer.trace(model.train()) train_tracer.trace(model.train(), concrete_args=concrete_args)
eval_tracer = NodePathTracer(**tracer_kwargs) eval_tracer = NodePathTracer(**tracer_kwargs)
eval_tracer.trace(model.eval()) eval_tracer.trace(model.eval(), concrete_args=concrete_args)
train_nodes = list(train_tracer.node_to_qualname.values()) train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values()) eval_nodes = list(eval_tracer.node_to_qualname.values())
if not suppress_diff_warning: if not suppress_diff_warning:
...@@ -334,6 +338,7 @@ def create_feature_extractor( ...@@ -334,6 +338,7 @@ def create_feature_extractor(
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Optional[Dict[str, Any]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None,
suppress_diff_warning: bool = False, suppress_diff_warning: bool = False,
concrete_args: Optional[Dict[str, Any]] = None,
) -> fx.GraphModule: ) -> fx.GraphModule:
""" """
Creates a new graph module that returns intermediate nodes from a given Creates a new graph module that returns intermediate nodes from a given
...@@ -398,6 +403,10 @@ def create_feature_extractor( ...@@ -398,6 +403,10 @@ def create_feature_extractor(
suppress_diff_warning (bool, optional): whether to suppress a warning suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of when there are discrepancies between the train and eval version of
the graph. Defaults to False. the graph. Defaults to False.
concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
not be treated as Proxies. According to the `Pytorch docs
<https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.trace>`_,
this parameter's API may not be guaranteed.
Examples:: Examples::
...@@ -482,7 +491,7 @@ def create_feature_extractor( ...@@ -482,7 +491,7 @@ def create_feature_extractor(
# Instantiate our NodePathTracer and use that to trace the model # Instantiate our NodePathTracer and use that to trace the model
tracer = NodePathTracer(**tracer_kwargs) tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model) graph = tracer.trace(model, concrete_args=concrete_args)
name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
graph_module = fx.GraphModule(tracer.root, graph, name) graph_module = fx.GraphModule(tracer.root, graph, name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment