Unverified Commit 878ba512 authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

Merge pull request #138 from NVIDIA/sbn_test_cases

[syncBN]
parents 95fe7f6a d0624f4f
......@@ -92,6 +92,10 @@ inp_bn = inp_t.clone().requires_grad_()
grad_bn = grad_output_t.clone().detach()
out_bn = bn(inp_bn)
out_bn.backward(grad_bn)
# compensating the averaging over processes done by DDP
# in order to produce mathematically equivalent result
for param in bn.parameters():
param.grad = param.grad / args.world_size
bn_opt = optim.SGD(bn.parameters(), lr=1.0)
sbn = apex.parallel.SyncBatchNorm(feature_size).cuda()
......@@ -103,7 +107,7 @@ if args.fp16:
if args.fp64:
sbn.double()
sbn = DDP(sbn)
sbn_opt = optim.SGD(sbn.parameters(), lr=1.0*args.world_size)
sbn_opt = optim.SGD(sbn.parameters(), lr=1.0)
inp_sbn = inp_t.clone().requires_grad_()
grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn[start:finish])
......@@ -159,11 +163,7 @@ sbn_opt.step()
if args.local_rank == 0:
compare("comparing bn vs sbn bias: ", bn.bias, sbn.module.bias, error)
compare("comparing bn vs ref bias: ", bn.bias, bias_r.view(-1) - grad_bias_r, error)
sbn_result = compare("comparing sbn vs ref bias: ", sbn.module.bias, bias_r.view(-1) - grad_bias_r, error) and sbn_result
compare("comparing bn vs sbn weight: ", bn.weight, sbn.module.weight, error)
compare("comparing bn vs ref weight: ", bn.weight, (weight_r.view(-1) - grad_weight_r), error)
sbn_result = compare("comparing sbn vs ref weight: ", sbn.module.weight, (weight_r.view(-1) - grad_weight_r), error) and sbn_result
if sbn_result:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment