Commit 9dc1600b authored by Hang Zhang's avatar Hang Zhang Committed by Facebook GitHub Bot
Browse files

Facebook: Reward Function in D2Go

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/116

Reviewed By: newstzpz

Differential Revision: D30860098

fbshipit-source-id: 5c9422dd91d305193f9b43869f12423660217010
parent c2b397b1
...@@ -164,6 +164,9 @@ class TorchscriptWrapper(nn.Module): ...@@ -164,6 +164,9 @@ class TorchscriptWrapper(nn.Module):
# TODO: set int8 backend accordingly if needed # TODO: set int8 backend accordingly if needed
return self.module(*args, **kwargs) return self.module(*args, **kwargs)
def get_wrapped_models(self):
return self.module
def load_torchscript(model_path): def load_torchscript(model_path):
extra_files = {} extra_files = {}
...@@ -261,6 +264,9 @@ class TracingAdapterModelWrapper(nn.Module): ...@@ -261,6 +264,9 @@ class TracingAdapterModelWrapper(nn.Module):
flattened_outputs = self.traced_model(*flattened_inputs) flattened_outputs = self.traced_model(*flattened_inputs)
return self.outputs_schema(flattened_outputs) return self.outputs_schema(flattened_outputs)
def get_wrapped_models(self):
return self.traced_model
def tracing_adapter_wrap_load(old_f): def tracing_adapter_wrap_load(old_f):
def new_f(cls, save_path, **load_kwargs): def new_f(cls, save_path, **load_kwargs):
......
...@@ -16,9 +16,12 @@ def create_runner( ...@@ -16,9 +16,12 @@ def create_runner(
"""Constructs a runner instance if class is a d2go runner. Returns class """Constructs a runner instance if class is a d2go runner. Returns class
type if class is a Lightning module. type if class is a Lightning module.
""" """
runner_module_name, runner_class_name = class_full_name.rsplit(".", 1) if class_full_name is None:
runner_module = importlib.import_module(runner_module_name) runner_class = GeneralizedRCNNRunner
runner_class = getattr(runner_module, runner_class_name) else:
runner_module_name, runner_class_name = class_full_name.rsplit(".", 1)
runner_module = importlib.import_module(runner_module_name)
runner_class = getattr(runner_module, runner_class_name)
if issubclass(runner_class, LightningModule): if issubclass(runner_class, LightningModule):
# Return runner class for Lightning module since it requires config # Return runner class for Lightning module since it requires config
# to construct # to construct
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment