Commit c614a99d authored by Yanjia0, committed by binmakeswell

[NFC] polish colossalai/auto_parallel/offload/amp_optimizer.py code style (#4255)

parent 85774f0c
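
The change is purely cosmetic, as the [NFC] tag indicates: the import statements are regrouped into a consistent alphabetical order, and a missing newline at the end of the file is added.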
-from typing import Dict, Tuple
 from enum import Enum
+from typing import Dict, Tuple
 import torch
 from torch.optim import Optimizer
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.utils import get_current_device
 from .base_offload_module import BaseOffloadModule
-from .region_manager import RegionManager
 from .region import Region
+from .region_manager import RegionManager
 
 
 class OptimState(Enum):
     SCALED = 0
     UNSCALED = 1
 
 
 class AMPOptimizer(ColossalaiOptimizer):
     """
     A wrapper for Optimizer.
     Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
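
For readers unfamiliar with the SCALED/UNSCALED states above: the wrapper scales the loss before backward so that small fp16 gradients do not underflow, then must divide the gradients back down exactly once before the inner optimizer steps. Below is a minimal, self-contained sketch of that bookkeeping in plain PyTorch; the class name ToyAMPOptimizer, the fixed init_scale, and the method names are illustrative inventions, not the ColossalAI implementation (which delegates scale management to DynamicGradScaler).

```python
from enum import Enum

import torch
from torch.optim import Optimizer


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class ToyAMPOptimizer:
    """Toy wrapper: tracks whether gradients still carry the loss scale
    and unscales them exactly once before stepping."""

    def __init__(self, optim: Optimizer, init_scale: float = 2.0**16):
        self.optim = optim
        self.scale = init_scale
        self.optim_state = OptimState.UNSCALED

    def backward(self, loss: torch.Tensor):
        # Scale the loss so tiny fp16 gradients survive the backward pass.
        (loss * self.scale).backward()
        self.optim_state = OptimState.SCALED

    def _unscale_grads(self):
        for group in self.optim.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad.div_(self.scale)
        self.optim_state = OptimState.UNSCALED

    def step(self):
        # Unscale lazily, and only once, no matter how step() is reached.
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        self.optim.step()


# Usage: one training step on a tiny model.
model = torch.nn.Linear(4, 1)
opt = ToyAMPOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
loss = model(torch.randn(2, 4)).pow(2).mean()
opt.backward(loss)
opt.step()
```

A production implementation additionally checks for inf/nan gradients and grows or shrinks the scale dynamically, which is what DynamicGradScaler provides.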
@@ -174,4 +175,4 @@ class AMPOptimizer(ColossalaiOptimizer):
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optim.load_state_dict(self.optim.state_dict())
\ No newline at end of file
+        self.optim.load_state_dict(self.optim.state_dict())
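
The recast idiom in this hunk deserves a note: torch.optim.Optimizer.load_state_dict casts floating-point state tensors to the dtype and device of the parameter they belong to, so feeding an optimizer its own state_dict() realigns stale per-param state after the parameters have been cast or moved. A small demonstration with stock PyTorch (the toy model and momentum buffer are just for illustration):

```python
import torch

model = torch.nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so per-param state (the momentum buffer) exists in fp32.
model(torch.randn(2, 4)).sum().backward()
optim.step()

# Cast the parameters in place; the momentum buffers are still fp32.
model.half()
buf = next(iter(optim.state.values()))["momentum_buffer"]
print(buf.dtype)  # torch.float32

# Round-tripping the optimizer's own state dict recasts each state
# tensor to match the parameter it belongs to.
optim.load_state_dict(optim.state_dict())
buf = next(iter(optim.state.values()))["momentum_buffer"]
print(buf.dtype)  # torch.float16
```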