Commit b9336b1e authored by Michael Carilli's avatar Michael Carilli
Browse files

Fix use of multi_tensor_l2norm, remove test using deprecated syntax

parent 47da14a0
...@@ -263,10 +263,11 @@ def post_backward_with_master_weights_FusedAdam(self, scaler): ...@@ -263,10 +263,11 @@ def post_backward_with_master_weights_FusedAdam(self, scaler):
norm_groups = [] norm_groups = []
skip = False skip = False
for grad_group in stash.grads: for grad_group in stash.grads:
norm = multi_tensor_applier( norm, _ = multi_tensor_applier(
stash.multi_tensor_l2norm, stash.multi_tensor_l2norm,
stash.dummy_overflow_buf, stash.dummy_overflow_buf,
[grad_group]) [grad_group],
False)
# Still syncing here for now. # Still syncing here for now.
norm = float(norm) norm = float(norm)
norm_groups.append(norm) norm_groups.append(norm)
......
...@@ -137,26 +137,6 @@ class TestTensorCasts(unittest.TestCase): ...@@ -137,26 +137,6 @@ class TestTensorCasts(unittest.TestCase):
fn = lambda x: x.sum() fn = lambda x: x.sum()
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h)) run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
class TestDisabledCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=False)
common_init(self)
def test_disabled_linear(self):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
input_shape = (self.b, self.h)
for fn in [m, f]:
x = torch.randn(input_shape, dtype=torch.float).requires_grad_()
y = fn(x)
self.assertEqual(y.type(), FLOAT)
y.sum().backward()
self.assertEqual(x.grad.type(), FLOAT)
x = torch.randn(input_shape, dtype=torch.half).requires_grad_()
self.assertRaises(RuntimeError, fn, x)
# TODO: maybe more tests on disabled casting? # TODO: maybe more tests on disabled casting?
if __name__ == '__main__': if __name__ == '__main__':
......
#!/bin/bash #!/bin/bash
DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/" # DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/"
# DATADIR="/opt/home/apex/examples/imagenet/" # DATADIR="/opt/home/apex/examples/imagenet/"
cp ../common/* . cp ../common/* .
bash run_test.sh single_gpu $1 $DATADIR yes bash run_test.sh single_gpu $1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment