Commit 9dc1600b authored by Hang Zhang's avatar Hang Zhang Committed by Facebook GitHub Bot
Browse files

Facebook: Reward Function in D2Go

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/116

Reviewed By: newstzpz

Differential Revision: D30860098

fbshipit-source-id: 5c9422dd91d305193f9b43869f12423660217010
parent c2b397b1
......@@ -164,6 +164,9 @@ class TorchscriptWrapper(nn.Module):
# TODO: set int8 backend accordingly if needed
return self.module(*args, **kwargs)
def get_wrapped_models(self):
return self.module
def load_torchscript(model_path):
extra_files = {}
......@@ -261,6 +264,9 @@ class TracingAdapterModelWrapper(nn.Module):
flattened_outputs = self.traced_model(*flattened_inputs)
return self.outputs_schema(flattened_outputs)
def get_wrapped_models(self):
return self.traced_model
def tracing_adapter_wrap_load(old_f):
def new_f(cls, save_path, **load_kwargs):
......
......@@ -16,6 +16,9 @@ def create_runner(
"""Constructs a runner instance if class is a d2go runner. Returns class
type if class is a Lightning module.
"""
if class_full_name is None:
runner_class = GeneralizedRCNNRunner
else:
runner_module_name, runner_class_name = class_full_name.rsplit(".", 1)
runner_module = importlib.import_module(runner_module_name)
runner_class = getattr(runner_module, runner_class_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment