Unverified commit c4af33b6, authored by Crutcher Dunnavant and committed by GitHub
Browse files

[minor] Remove false tuples (#994)

parent ca18b50c
...@@ -33,7 +33,7 @@ def test_write_read(): ...@@ -33,7 +33,7 @@ def test_write_read():
_init() _init()
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
ref_tensor = torch.rand((128), dtype=torch.float32) ref_tensor = torch.rand(128, dtype=torch.float32)
test_tensor = torch.zeros_like(ref_tensor) test_tensor = torch.zeros_like(ref_tensor)
assert not torch.equal(ref_tensor, test_tensor) assert not torch.equal(ref_tensor, test_tensor)
so.write(ref_tensor, f.name) so.write(ref_tensor, f.name)
...@@ -45,7 +45,7 @@ def test_ssd_handle_dispatch_fwd(): ...@@ -45,7 +45,7 @@ def test_ssd_handle_dispatch_fwd():
_init() _init()
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((128)) orig_tensor = torch.randn(128)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor) ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0) ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True) ssd_handle.to_file(release_tensor_after_write=True)
...@@ -91,7 +91,7 @@ def test_ssd_handle_dispatch_bwd_hook(): ...@@ -91,7 +91,7 @@ def test_ssd_handle_dispatch_bwd_hook():
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor) ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0) ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True) ssd_handle.to_file(release_tensor_after_write=True)
one = torch.ones((1), requires_grad=True).cuda() one = torch.ones(1, requires_grad=True).cuda()
orig_copy = ssd_handle.data orig_copy = ssd_handle.data
cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True) cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True)
...@@ -235,7 +235,7 @@ def test_ssd_flat_parameter_basic(): ...@@ -235,7 +235,7 @@ def test_ssd_flat_parameter_basic():
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32)) refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False) ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0) ssd_flat_param.set_file_params(f.name, 0)
...@@ -261,7 +261,7 @@ def test_ssd_flat_parameter_view_modify(): ...@@ -261,7 +261,7 @@ def test_ssd_flat_parameter_view_modify():
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False) refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False) refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=False) refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False) ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0) ssd_flat_param.set_file_params(f.name, 0)
ssd_flat_param.flush_on_dirty = False ssd_flat_param.flush_on_dirty = False
...@@ -300,7 +300,7 @@ def test_ssd_flat_parameter_view_bwd(): ...@@ -300,7 +300,7 @@ def test_ssd_flat_parameter_view_bwd():
.requires_grad_() .requires_grad_()
) )
refc_param = ( refc_param = (
torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=True) torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=True)
.to("cpu") .to("cpu")
.detach() .detach()
.requires_grad_() .requires_grad_()
...@@ -318,7 +318,7 @@ def test_ssd_flat_parameter_view_bwd(): ...@@ -318,7 +318,7 @@ def test_ssd_flat_parameter_view_bwd():
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called)) grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.data = cuda_copy ssd_flat_param.data = cuda_copy
one = torch.ones((1), requires_grad=True, device=ssd_flat_param.device) one = torch.ones(1, requires_grad=True, device=ssd_flat_param.device)
y1 = ssd_flat_param.views[0] + one y1 = ssd_flat_param.views[0] + one
y2 = cuda_copy + 1 y2 = cuda_copy + 1
...@@ -412,7 +412,7 @@ def test_ssd_flat_parameter_direct_to_file(): ...@@ -412,7 +412,7 @@ def test_ssd_flat_parameter_direct_to_file():
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32)) refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors( ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0 [refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
) )
......
...@@ -82,7 +82,7 @@ def run_one_step( ...@@ -82,7 +82,7 @@ def run_one_step(
# Any model works. Add one different buffer per rank # Any model works. Add one different buffer per rank
model = _get_mlp() model = _get_mlp()
model.register_buffer("test_buffer", torch.ones((1)) * rank) model.register_buffer("test_buffer", torch.ones(1) * rank)
model.to(device) model.to(device)
next(model.parameters()).requires_grad = False # Test non-trainable parameters next(model.parameters()).requires_grad = False # Test non-trainable parameters
......
...@@ -131,7 +131,7 @@ def run_ddp_parity( ...@@ -131,7 +131,7 @@ def run_ddp_parity(
# Any model works. Add one different buffer per rank # Any model works. Add one different buffer per rank
model = _get_mlp_emb(multiple_fw) model = _get_mlp_emb(multiple_fw)
model.register_buffer("test_buffer", torch.ones((1)) * rank) model.register_buffer("test_buffer", torch.ones(1) * rank)
model.to(device) model.to(device)
# Make sure that the model starts with non-trainable, so that we check for the buckets to be # Make sure that the model starts with non-trainable, so that we check for the buckets to be
...@@ -289,7 +289,7 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b ...@@ -289,7 +289,7 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_b
BATCHS = 20 BATCHS = 20
model = _get_mlp_emb() model = _get_mlp_emb()
model.register_buffer("test_buffer", torch.ones((1)) * rank) model.register_buffer("test_buffer", torch.ones(1) * rank)
model.to(device) model.to(device)
n_half_params = len(list(model.parameters())) // 2 n_half_params = len(list(model.parameters())) // 2
optim_settings = {"lr": 1e-3, "momentum": 0.99} optim_settings = {"lr": 1e-3, "momentum": 0.99}
......
...@@ -146,7 +146,7 @@ class TestSingleRank(unittest.TestCase): ...@@ -146,7 +146,7 @@ class TestSingleRank(unittest.TestCase):
# Move the model to device after OSS was constructed # Move the model to device after OSS was constructed
x.to(DEVICE) x.to(DEVICE)
x(torch.zeros((1), device=DEVICE)).backward() x(torch.zeros(1, device=DEVICE)).backward()
# Check that OSS detects that the device changed # Check that OSS detects that the device changed
o.step() o.step()
...@@ -853,7 +853,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph ...@@ -853,7 +853,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph
trunk = torch.nn.Sequential( trunk = torch.nn.Sequential(
torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden) torch.nn.Linear(in_channels, hidden), torch.nn.Linear(hidden, hidden), torch.nn.Linear(hidden, hidden)
) )
trunk.register_buffer("test_buffer", torch.ones((1)) * rank) trunk.register_buffer("test_buffer", torch.ones(1) * rank)
trunk.to(device) trunk.to(device)
head = torch.nn.Linear(hidden, out_channels).to(device) head = torch.nn.Linear(hidden, out_channels).to(device)
......
...@@ -35,7 +35,7 @@ def test_apply_to_tensors(devices): ...@@ -35,7 +35,7 @@ def test_apply_to_tensors(devices):
def get_a_tensor(): def get_a_tensor():
"""Return a random tensor on random device.""" """Return a random tensor on random device."""
dev = random.choice(devices) dev = random.choice(devices)
shape = random.choice(((1), (2, 3), (4, 5, 6), (7, 8, 9, 10))) shape = random.choice((1, (2, 3), (4, 5, 6), (7, 8, 9, 10)))
t = torch.rand(shape).to(dev) t = torch.rand(shape).to(dev)
nonlocal expected nonlocal expected
expected += t.numel() expected += t.numel()
...@@ -45,7 +45,7 @@ def test_apply_to_tensors(devices): ...@@ -45,7 +45,7 @@ def test_apply_to_tensors(devices):
data = [1, "str"] data = [1, "str"]
data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3}) data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
data.insert(0, set(["x", get_a_tensor(), get_a_tensor()])) data.insert(0, set(["x", get_a_tensor(), get_a_tensor()]))
data.append(([1], get_a_tensor(), (1), [get_a_tensor()], set((1, 2)))) data.append(([1], get_a_tensor(), 1, [get_a_tensor()], set((1, 2))))
od = OrderedDict() od = OrderedDict()
od["k"] = "value" od["k"] = "value"
data.append(od) data.append(od)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment