Unverified Commit 7990ae18 authored by Jedrzej Kosinski's avatar Jedrzej Kosinski Committed by GitHub
Browse files

Fix error when more cond masks passed in than batch size (#3353)

parent 16eabdf7
......@@ -34,7 +34,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask[:input_x.shape[0],area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment