Commit 754aa7c8 authored by Xue Fuzhao's avatar Xue Fuzhao Committed by Frank Lee
Browse files

[NFC] polish tests/test_layers/test_3d/checks_3d/check_layer_3d.py code style (#1731)

parent ff373a11
......@@ -784,7 +784,7 @@ def check_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
......@@ -837,7 +837,7 @@ def check_vocab_parallel_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment