# test_adamw.py
# Example script: trains a small torch model using the hipdnn.TorchAdamW optimizer.
# (Initial commit by yanjl1.)
import hipdnn
import torch


def example_adamw(num_epochs: int = 10, device: str = "cuda") -> None:
    """Train a small MLP on random data with ``hipdnn.TorchAdamW``.

    Builds a ``Linear(10, 20) -> ReLU -> Linear(20, 1)`` model on ``device``.
    For each epoch it draws a fresh random batch of 32 samples and targets,
    computes the MSE loss, back-propagates, takes one optimizer step, and
    prints the loss for that epoch.

    Args:
        num_epochs: Number of training iterations to run, one random batch
            per iteration. Defaults to 10.
        device: Device on which the model parameters and the random batches
            are created. Defaults to ``"cuda"``.
    """
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 20, device=device),
        torch.nn.ReLU(),
        torch.nn.Linear(20, 1, device=device),
    )
    # NOTE(review): presumably hipdnn.TorchAdamW follows the torch.optim
    # optimizer interface; only zero_grad() and step() are relied on here.
    optimizer = hipdnn.TorchAdamW(
        model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=0.01
    )

    for epoch in range(num_epochs):
        inputs = torch.randn(32, 10, device=device)
        targets = torch.randn(32, 1, device=device)

        # Forward pass.
        outputs = model(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)

        # Backward pass: clear gradients left over from the previous
        # iteration before accumulating the new ones.
        optimizer.zero_grad()
        loss.backward()

        # Optimization step.
        # NOTE(review): the original kept an alternative call,
        # `optimizer.step_batch()`, commented out here; its semantics are
        # not visible in this file.
        optimizer.step()

        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    example_adamw()