Commit 03ba8eb1 authored by Lawrence McAfee

commented clip_grads.py.

parent 087cbff1
@@ -38,6 +38,8 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
     Arguments:
         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
             single Tensor that will have gradients normalized
+        grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
+            Tensor that will be used for calculating the grad norm.
         max_norm (float or int): max norm of the gradients
         norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
             infinity norm.
@@ -50,6 +52,8 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
+    if isinstance(grads_for_norm, torch.Tensor):
+        grads_for_norm = [grads_for_norm]
     # Grads.
     grads = [ p.grad.detach() for p in parameters if p.grad is not None ]
...
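
The new isinstance check mirrors the existing handling of parameters: callers may pass either a single Tensor or an iterable of Tensors, and both forms are normalized to a list before the function iterates over them. A minimal, standalone sketch of that pattern (the helper name _as_tensor_list is hypothetical and not part of clip_grads.py):

import torch

def _as_tensor_list(tensors):
    # Hypothetical helper illustrating the pattern this commit documents:
    # accept a single Tensor or an iterable of Tensors and always return
    # a list, so downstream code can iterate uniformly.
    if isinstance(tensors, torch.Tensor):
        return [tensors]
    return list(tensors)

# Both call forms behave the same way afterwards.
single = torch.randn(4)
many = [torch.randn(4) for _ in range(3)]
assert len(_as_tensor_list(single)) == 1
assert len(_as_tensor_list(many)) == 3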