"docs/source/api/vscode:/vscode.git/clone" did not exist on "d6eecf90a1fc258de3c494209ea89141c2f4bfbe"
Commit c7bd7dfe authored by Tao Xu's avatar Tao Xu Committed by Facebook GitHub Bot
Browse files

enable the diffusion visualization evaluators to run on multiple datasets

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/527

- Add model.reset_generation_counter() to enable the diffusion visualization evaluators to run on multiple test datasets.
  - Before this fix, the visualization evaluators will only run on the 1st test dataset since self.generation_counter will set to <0 after running on the 1st test datasaet. Thus the visualization evaluators will skip for all the other test sets since self.generation_counter < 0.
- Use the ddim for upsampler by default for better results

Reviewed By: zechenghe

Differential Revision: D45058672

fbshipit-source-id: 2f7919bf6ecd2e5f6f242ce3e7891cb3dc8d6af4
parent d032c02c
...@@ -339,7 +339,13 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -339,7 +339,13 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
return d2_build_lr_scheduler(cfg, optimizer) return d2_build_lr_scheduler(cfg, optimizer)
def _create_evaluators( def _create_evaluators(
self, cfg, dataset_name, output_folder, train_iter, model_tag self,
cfg,
dataset_name,
output_folder,
train_iter,
model_tag,
model=None,
): ):
evaluator = self.get_evaluator(cfg, dataset_name, output_folder=output_folder) evaluator = self.get_evaluator(cfg, dataset_name, output_folder=output_folder)
...@@ -400,7 +406,14 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner): ...@@ -400,7 +406,14 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
data_loader = self.build_detection_test_loader(cfg, dataset_name) data_loader = self.build_detection_test_loader(cfg, dataset_name)
evaluator = self._create_evaluators( evaluator = self._create_evaluators(
cfg, dataset_name, output_folder, train_iter, model_tag cfg,
dataset_name,
output_folder,
train_iter,
model_tag,
model.module
if isinstance(model, nn.parallel.DistributedDataParallel)
else model,
) )
results_per_dataset = inference_on_dataset(model, data_loader, evaluator) results_per_dataset = inference_on_dataset(model, data_loader, evaluator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment