            raise RuntimeError(name+" with size {} is not contiguous"
                               .format(tensor.size()))
        if not tensor.is_cuda:
            raise RuntimeError(name+".is_cuda = False. "
                               "Currently, only cuda tensors are supported.")

class Fused_Weight_Norm(Function):
"""
Custom autograd function that implements weight norm, as presented in
`<https://arxiv.org/abs/1602.07868>`_,
along a tensor's slowest or
fastest dimension using fused kernel launches for the forward and backward passes.
Accepts fp32 or fp16 input; the output type will match the input type.
Within the kernels, all calculations are performed in fp32 for numerical stability, regardless
of input/output precision.
We are refactoring our fused kernels to add to Pytorch core, so that Pytorch's built-in weightnorm
will use them transparently. Please use Pytorch's built-in weightnorm implementation for now, to
future-proof your code.
"""

    @staticmethod
    def forward(ctx, input, g, dim=0):
"""
Args:
input(torch.cuda.FloatTensor or torch.cuda.HalfTensor): input tensor corresponding to **v** in the paper. ``input`` should be contiguous.
g(torch.cuda.FloatTensor or torch.cuda.HalfTensor): input tensor corresponding to **g** in the paper. ``g`` should be the same type as ``input``.
dim(int, optional, default=0): Dimension across which to perform weightnorm. Currently, only the first or last dimension of the input tensor is supported.
Returns:
Output tensor corresponding to **w** in the paper. Output type and precision will match
"We are in the process of adding our fused kernels to Pytorch core, "+
"so Pytorch's built-in weightnorm will use them transparently.")

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
"""
Args:
grad_output(torch.cuda.FloatTensor or torch.cuda.HalfTensor): Gradient of loss with respect to output **w**. ``grad_output`` should be contiguous for performance.
Returns:
Gradient of loss with respect to ``input`` and ``g``. The precision of these gradients will match the precision of ``grad_input``.
"""
        check_contig_cuda((grad_output,), ("grad_output",))
        savedInput, savedg = ctx.saved_tensors
        savedNorms = ctx.norms
        # We expect that these .contiguous() calls will be no-ops. They're present for safety.