Commit b965585d authored by xyupeng's avatar xyupeng Committed by Frank Lee
Browse files

[NFC] polish colossalai/amp/torch_amp/torch_amp.py code style (#2290)

parent d1e5bafc
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import torch.nn as nn
import torch.cuda.amp as torch_amp import torch.cuda.amp as torch_amp
import torch.nn as nn
from torch import Tensor from torch import Tensor
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer from torch.optim import Optimizer
from ._grad_scaler import GradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32 from colossalai.utils import clip_grad_norm_fp32
from ._grad_scaler import GradScaler
class TorchAMPOptimizer(ColossalaiOptimizer): class TorchAMPOptimizer(ColossalaiOptimizer):
"""A wrapper class which integrate Pytorch AMP with an optimizer """A wrapper class which integrate Pytorch AMP with an optimizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment