Commit 009a3e71 authored by Tri Dao's avatar Tri Dao
Browse files

[Training] Fix lightning _PATH import

parent 993d1244
...@@ -9,8 +9,13 @@ from torch.distributed.optim import ZeroRedundancyOptimizer ...@@ -9,8 +9,13 @@ from torch.distributed.optim import ZeroRedundancyOptimizer
from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import _PATH try: # pytorch_lightning <= 1.7
# from lightning_lite.utilities.types import _PATH from pytorch_lightning.utilities.types import _PATH
except ImportError: # pytorch_lightning >= 1.8
try:
from lightning_lite.utilities.types import _PATH
except ImportError: # pytorch_lightning >= 1.9
from lightning_fabric.utilities.types import _PATH
# Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get # Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get
......
...@@ -13,9 +13,14 @@ from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam ...@@ -13,9 +13,14 @@ from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import _PATH
# from lightning_lite.utilities.types import _PATH
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
try: # pytorch_lightning <= 1.7
from pytorch_lightning.utilities.types import _PATH
except ImportError: # pytorch_lightning >= 1.8
try:
from lightning_lite.utilities.types import _PATH
except ImportError: # pytorch_lightning >= 1.9
from lightning_fabric.utilities.types import _PATH
class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment