Commit ad98cc5f authored by jjsjann123's avatar jjsjann123 Committed by mcarilli
Browse files

removing nvtx range used for debugging (#485)

parent 325f5a0b
......@@ -8,7 +8,6 @@ class SyncBatchnormFunction(Function):
@staticmethod
def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):
torch.cuda.nvtx.range_push("sync_BN_fw")
input = input.contiguous()
world_size = 0
......@@ -64,13 +63,11 @@ class SyncBatchnormFunction(Function):
else:
out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
torch.cuda.nvtx.range_pop()
return out
@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
torch.cuda.nvtx.range_push("sync_BN_bw")
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
......@@ -113,5 +110,4 @@ class SyncBatchnormFunction(Function):
if weight is None or not ctx.needs_input_grad[3]:
grad_bias = None
torch.cuda.nvtx.range_pop()
return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment