Gradient clipping with fused kernels (#1405)
* Gradient clipping routine with fused kernels Identical API as PyTorch. Falls back to PyTorch impl when not computing L2 norm. * Add unit test for gradient clipping * Add fp16 case to gradient clipping unit test * Tweaks to grad clipping unit test Review suggestions from @crcrpar * Debug gradient clipping tests When checking that incorrect results produce assertion errors, make sure to generate a discrepancy outside the range of numerical error.
Showing
Please register or sign in to comment