Unverified Commit 051ed9e6 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Unpin `pytorch_lightning<1.2` (#3598)

parent b7f374ce
......@@ -6,7 +6,7 @@ torch == 1.6.0+cpu ; sys_platform != "darwin"
torch == 1.6.0 ; sys_platform == "darwin"
torchvision == 0.7.0+cpu ; sys_platform != "darwin"
torchvision == 0.7.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.1.1, < 1.2
pytorch-lightning >= 1.1.1
onnx
peewee
graphviz
......@@ -165,13 +165,18 @@ class _SupervisedLearningModule(LightningModule):
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
class _AccuracyWithLogits(pl.metrics.Accuracy):
def update(self, pred, target):
return super().update(nn.functional.softmax(pred), target)
@serialize_cls
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'acc': pl.metrics.Accuracy},
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
......
......@@ -31,7 +31,7 @@ jobs:
python3 -m pip install scikit-learn==0.24.1
python3 -m pip install torchvision==0.7.0
python3 -m pip install torch==1.6.0
python3 -m pip install 'pytorch-lightning>=1.1.1,<1.2'
python3 -m pip install 'pytorch-lightning>=1.1.1'
python3 -m pip install keras==2.1.6
python3 -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
python3 -m pip install thop
......
......@@ -28,7 +28,7 @@ jobs:
python -m pip install scikit-learn==0.24.1
python -m pip install keras==2.1.6
python -m pip install torch==1.6.0 torchvision==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install 'pytorch-lightning>=1.1.1,<1.2'
python -m pip install 'pytorch-lightning>=1.1.1'
python -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
displayName: Install extra dependencies
......
......@@ -3,6 +3,7 @@ import pytest
import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.runtime.platform.test
import pytorch_lightning
import torch
import torch.nn as nn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment