Unverified Commit 91b8478c authored by q.yao's avatar q.yao Committed by GitHub
Browse files

Reduce ms_deformable_attn test memory usage (#1407)

parent b484abac
......@@ -151,7 +151,7 @@ def test_gradient_numerical(channels,
N, M, _ = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
shapes = torch.as_tensor([(3, 2), (2, 1)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros(
(1, )), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment