test_simple.py 1.63 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
aiss's avatar
aiss committed
5

aiss's avatar
aiss committed
6
7
import torch
from pytorch_lightning import LightningModule, Trainer
aiss's avatar
aiss committed
8
from pytorch_lightning.strategies import DeepSpeedStrategy
aiss's avatar
aiss committed
9
10
11
12
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    """In-memory dataset of ``length`` random vectors, each with ``size`` features.

    All samples are drawn once with ``torch.randn`` at construction time, so
    indexing the same position repeatedly returns the same tensor row.

    Args:
        size: Number of features per sample (second dimension of ``data``).
        length: Number of samples (first dimension of ``data``).
    """

    def __init__(self, size, length):
        # Draw the whole (length, size) table up front rather than per item.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.data[index]


class BoringModel(LightningModule):
    """Minimal LightningModule: a single ``Linear(32, 2)`` layer fed random data.

    There are no targets; every step's "loss" is simply the sum of the model's
    outputs for the batch. This is enough to drive Lightning's train/validation
    loop end to end.
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def _batch_loss(self, batch):
        """Return the summed model output for ``batch`` (stand-in loss)."""
        # Call through ``self(...)`` (not ``self.forward``) so module hooks still run.
        return self(batch).sum()

    def _random_loader(self):
        """Build a fresh loader over 64 random 32-feature samples, 2 per batch."""
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def training_step(self, batch, batch_idx):
        step_loss = self._batch_loss(batch)
        self.log("train_loss", step_loss)
        return {"loss": step_loss}

    def validation_step(self, batch, batch_idx):
        self.log("valid_loss", self._batch_loss(batch))

    def test_step(self, batch, batch_idx):
        self.log("test_loss", self._batch_loss(batch))

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return self._random_loader()

    def val_dataloader(self):
        return self._random_loader()


def test_lightning_model():
    """Test that DeepSpeed can fit a simple LightningModule.

    Runs one epoch of ``BoringModel`` under ``DeepSpeedStrategy`` with
    ``precision=16`` on a single GPU (``accelerator="gpu", devices=1``).
    The model supplies its own train/val dataloaders, so no
    LightningDataModule is involved. The test passes if ``fit`` completes
    without raising.
    """
    model = BoringModel()
    trainer = Trainer(strategy=DeepSpeedStrategy(), max_epochs=1, precision=16, accelerator="gpu", devices=1)
    trainer.fit(model)