# amp.py (205 bytes) — committed by suily.
import torch
import torch.nn as nn 
import torch.cuda.amp as amp


from src.core import register
import src.misc.dist as dist 


# Star-imports from this module expose only the registered scaler below.
__all__ = ['GradScaler']

# PyTorch's CUDA automatic-mixed-precision gradient scaler, reached through
# the `grad_scaler` submodule of the `torch.cuda.amp` package imported above.
_TorchCudaGradScaler = amp.grad_scaler.GradScaler

# Publish the scaler under the name `GradScaler`, passed through the project's
# `register` hook so it can be looked up by the rest of the codebase.
# NOTE(review): whatever `register(...)` returns is what gets bound here; whether
# that is the class object unchanged is not visible from this file — confirm in
# `src.core`.
GradScaler = register(_TorchCudaGradScaler)

# Drop the temporary alias so the module namespace matches the public API.
del _TorchCudaGradScaler