Unverified Commit 886aa327 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[chore] OSS perf test, super minor (#495)

parent 8c405c51
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
def get_golden_real_stats(): def get_golden_real_stats():
return { return {
"reference_speed": 650, "reference_speed": 660,
"reference_memory": 1000, "reference_memory": 1000,
"reference_loss": 0.026, "reference_loss": 0.026,
} }
......
...@@ -150,7 +150,8 @@ def train( ...@@ -150,7 +150,8 @@ def train(
if optim_type == OptimType.oss_sharded_ddp: if optim_type == OptimType.oss_sharded_ddp:
optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9) optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
model = ShardedDDP(model, optimizer) # Single node run typically, no need for reduce buckets
model = ShardedDDP(model, optimizer, reduce_buffer_size=0)
else: else:
device_ids = None if args.cpu else [rank] device_ids = None if args.cpu else [rank]
model = DDP(model, device_ids=device_ids, find_unused_parameters=False) # type: ignore model = DDP(model, device_ids=device_ids, find_unused_parameters=False) # type: ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment