"vscode:/vscode.git/clone" did not exist on "78700893eeb5455cfd4e8fede7d236ae521f20f4"
test_batchnorm1d.py 484 Bytes
Newer Older
jjsjann123's avatar
jjsjann123 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import apex

# Parity check: apex's SyncBatchNorm versus torch.nn.BatchNorm1d.
# Both modules are given the same parameters/buffers and the same (8, 4)
# input; the forward result and the running statistics afterwards must agree.

# Module under test, moved to the GPU. Its affine parameters are overwritten
# in place with uniform_() samples so the comparison does not run on the
# values the constructor left behind.
model = apex.parallel.SyncBatchNorm(4)
model = model.cuda()
for param in (model.weight, model.bias):
    param.data.uniform_()

# Random input batch, generated on the host and then copied to the GPU.
data = torch.rand(8, 4).cuda()

# Reference module: stock BatchNorm1d loaded with the state of the module
# under test, so both start from identical weight, bias and buffers.
model_ref = torch.nn.BatchNorm1d(4).cuda()
shared_state = model.state_dict()
model_ref.load_state_dict(shared_state)

# Independent copy of the input so the two forward passes share no storage.
data_ref = data.clone()

# One forward pass through each implementation.
output = model(data)
output_ref = model_ref(data_ref)

# Outputs first, then running_mean, then running_var (default tolerances).
assert torch.allclose(output, output_ref)
for stat in ("running_mean", "running_var"):
    assert torch.allclose(getattr(model, stat), getattr(model_ref, stat))