Unverified Commit 151d150b authored by sarunyap's avatar sarunyap Committed by GitHub
Browse files

Fix bn_addrelu's bitmask type error (#67)

This patch converts torch.cuda.LongTensor's argument of bn_addrelu's
bitmask to int to fix the type error.
parent 8f5ae436
...@@ -66,7 +66,8 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function): ...@@ -66,7 +66,8 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function):
if is_train: if is_train:
if IS_ROCM_PYTORCH: if IS_ROCM_PYTORCH:
nhw = x.shape[0] * x.shape[1] * x.shape[2] nhw = x.shape[0] * x.shape[1] * x.shape[2]
bitmask = torch.cuda.LongTensor(((nhw + 3) & ~3) * grid_dim_y) shape = int(((nhw + 3) & ~3) * grid_dim_y)
bitmask = torch.cuda.LongTensor(shape)
else: else:
bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y) bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask) ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment