Unverified Commit ff31b6e3 authored by Ross Wightman's avatar Ross Wightman Committed by GitHub
Browse files

Add concrete_args to feature extraction tracing. (#8393)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 305330f0
......@@ -204,6 +204,7 @@ def get_graph_node_names(
model: nn.Module,
tracer_kwargs: Optional[Dict[str, Any]] = None,
suppress_diff_warning: bool = False,
concrete_args: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
"""
Dev utility to return node names in order of execution. See note on node
......@@ -232,10 +233,13 @@ def get_graph_node_names(
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
provided dictionary.
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
not be treated as Proxies. According to the `Pytorch docs
<https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.trace>`_,
this parameter's API may not be guaranteed.
Returns:
tuple(list, list): a list of node names from tracing the model in
......@@ -249,9 +253,9 @@ def get_graph_node_names(
tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
is_training = model.training
train_tracer = NodePathTracer(**tracer_kwargs)
train_tracer.trace(model.train())
train_tracer.trace(model.train(), concrete_args=concrete_args)
eval_tracer = NodePathTracer(**tracer_kwargs)
eval_tracer.trace(model.eval())
eval_tracer.trace(model.eval(), concrete_args=concrete_args)
train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values())
if not suppress_diff_warning:
......@@ -334,6 +338,7 @@ def create_feature_extractor(
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Optional[Dict[str, Any]] = None,
suppress_diff_warning: bool = False,
concrete_args: Optional[Dict[str, Any]] = None,
) -> fx.GraphModule:
"""
Creates a new graph module that returns intermediate nodes from a given
......@@ -398,6 +403,10 @@ def create_feature_extractor(
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
not be treated as Proxies. According to the `Pytorch docs
<https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer.trace>`_,
this parameter's API may not be guaranteed.
Examples::
......@@ -482,7 +491,7 @@ def create_feature_extractor(
# Instantiate our NodePathTracer and use that to trace the model
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
graph = tracer.trace(model, concrete_args=concrete_args)
name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
graph_module = fx.GraphModule(tracer.root, graph, name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment