Unverified Commit cc766aa5 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[feature] Add a torch AMP benchmark option and test job (#175)

* oss benchmark: add an --amp option
* add a circleCI test
parent 0d1f058b
...@@ -126,6 +126,11 @@ run_oss_gloo: &run_oss_gloo ...@@ -126,6 +126,11 @@ run_oss_gloo: &run_oss_gloo
command: | command: |
python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3 python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3
run_oss_amp: &run_oss_amp
- run:
name: Run OSS with Torch AMP
command: |
python benchmarks/oss.py --amp --epochs 3 --optim_type oss_sharded_ddp
# ------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------
# Jobs to run # Jobs to run
...@@ -316,6 +321,8 @@ jobs: ...@@ -316,6 +321,8 @@ jobs:
- <<: *run_oss_gloo - <<: *run_oss_gloo
- <<: *run_oss_amp
......
...@@ -140,9 +140,15 @@ def train( ...@@ -140,9 +140,15 @@ def train(
next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item() next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
) )
) )
if not args.cpu and args.amp:
# Automatically computes the FW pass in half precision
with torch.cuda.amp.autocast():
outputs = model(batch["inputs"]) outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"]) loss = loss_fn(outputs, batch["label"])
else:
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
loss.backward() loss.backward()
if optim_type == OptimType.oss_sharded_ddp: if optim_type == OptimType.oss_sharded_ddp:
...@@ -244,7 +250,8 @@ if __name__ == "__main__": ...@@ -244,7 +250,8 @@ if __name__ == "__main__":
parser.add_argument("--profile", action="store_true", default=False) parser.add_argument("--profile", action="store_true", default=False)
parser.add_argument("--cpu", action="store_true", default=False) parser.add_argument("--cpu", action="store_true", default=False)
parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101") parser.add_argument("--torchvision_model", type=str, help="Any torchvision model name (str)", default="resnet101")
parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
parser.add_argument("--amp", action="store_true", default=False, help="Activate torch AMP")
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment