Commit bf395ce5 authored by Kai Zhang's avatar Kai Zhang Committed by Facebook GitHub Bot
Browse files

fix for checking device type

Summary: Currently we are checking if MODEL.DEVICE is "gpu", but actually we DEVICE could also be "cuda". This diff checks if device is "cpu" instead.

Reviewed By: wat3rBro

Differential Revision: D28689547

fbshipit-source-id: 7512d32b7c08b0dcdc6487c6c2f1703655e64b19
parent 0ab6d3f1
......@@ -69,12 +69,12 @@ def _get_trainer_callbacks(cfg: CfgNode) -> List[Callback]:
callbacks.append(QuantizationAwareTraining.from_config(cfg))
return callbacks
def _get_accelerator(use_gpu: bool) -> str:
return "ddp" if use_gpu else "ddp_cpu"
def _get_accelerator(use_cpu: bool) -> str:
return "ddp_cpu" if use_cpu else "ddp"
def get_trainer_params(cfg: CfgNode, num_machines: int, num_processes: int) -> Dict[str, Any]:
use_gpu = cfg.MODEL.DEVICE.lower() == "gpu"
use_cpu = cfg.MODEL.DEVICE.lower() == "cpu"
return {
# training loop is bounded by max steps, use a large max_epochs to make
# sure max_steps is met first
......@@ -84,9 +84,9 @@ def get_trainer_params(cfg: CfgNode, num_machines: int, num_processes: int) -> D
if cfg.TEST.EVAL_PERIOD > 0
else cfg.SOLVER.MAX_ITER,
"num_nodes": num_machines,
"gpus": num_processes if use_gpu else None,
"gpus": None if use_cpu else num_processes,
"num_processes": num_processes,
"accelerator": _get_accelerator(use_gpu),
"accelerator": _get_accelerator(use_cpu),
"callbacks": _get_trainer_callbacks(cfg),
"logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
"num_sanity_val_steps": 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment