Commit 63e47d29 authored by jiej's avatar jiej
Browse files

[syncBN]

test update to resolve
  https://github.com/NVIDIA/apex/issues/134#issue-403525480

Using identical learning rate for both DDP with sync BN and single process BN.
The previous configure leaves the impression that sync BN requires adjusting lr
in the script, which is not true.
parent c8bc3e62
...@@ -92,6 +92,8 @@ inp_bn = inp_t.clone().requires_grad_() ...@@ -92,6 +92,8 @@ inp_bn = inp_t.clone().requires_grad_()
grad_bn = grad_output_t.clone().detach() grad_bn = grad_output_t.clone().detach()
out_bn = bn(inp_bn) out_bn = bn(inp_bn)
out_bn.backward(grad_bn) out_bn.backward(grad_bn)
for param in bn.parameters():
param.grad = param.grad / args.world_size
bn_opt = optim.SGD(bn.parameters(), lr=1.0) bn_opt = optim.SGD(bn.parameters(), lr=1.0)
sbn = apex.parallel.SyncBatchNorm(feature_size).cuda() sbn = apex.parallel.SyncBatchNorm(feature_size).cuda()
...@@ -103,7 +105,7 @@ if args.fp16: ...@@ -103,7 +105,7 @@ if args.fp16:
if args.fp64: if args.fp64:
sbn.double() sbn.double()
sbn = DDP(sbn) sbn = DDP(sbn)
sbn_opt = optim.SGD(sbn.parameters(), lr=1.0*args.world_size) sbn_opt = optim.SGD(sbn.parameters(), lr=1.0)
inp_sbn = inp_t.clone().requires_grad_() inp_sbn = inp_t.clone().requires_grad_()
grad_sbn = grad_output_t.clone().detach() grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn[start:finish]) out_sbn = sbn(inp_sbn[start:finish])
...@@ -159,11 +161,7 @@ sbn_opt.step() ...@@ -159,11 +161,7 @@ sbn_opt.step()
if args.local_rank == 0: if args.local_rank == 0:
compare("comparing bn vs sbn bias: ", bn.bias, sbn.module.bias, error) compare("comparing bn vs sbn bias: ", bn.bias, sbn.module.bias, error)
compare("comparing bn vs ref bias: ", bn.bias, bias_r.view(-1) - grad_bias_r, error)
sbn_result = compare("comparing sbn vs ref bias: ", sbn.module.bias, bias_r.view(-1) - grad_bias_r, error) and sbn_result
compare("comparing bn vs sbn weight: ", bn.weight, sbn.module.weight, error) compare("comparing bn vs sbn weight: ", bn.weight, sbn.module.weight, error)
compare("comparing bn vs ref weight: ", bn.weight, (weight_r.view(-1) - grad_weight_r), error)
sbn_result = compare("comparing sbn vs ref weight: ", sbn.module.weight, (weight_r.view(-1) - grad_weight_r), error) and sbn_result
if sbn_result: if sbn_result:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment