import hipdnn
import torch


def example_adamw(epochs=10, device="cuda"):
    """Train a tiny two-layer MLP on random data using ``hipdnn.TorchAdamW``.

    This is a minimal usage example. Each epoch draws a fresh random batch,
    runs a forward pass, computes the MSE loss, backpropagates, calls
    ``optimizer.step()`` and prints the loss for that batch.

    Args:
        epochs: Number of training iterations, each on one freshly drawn
            random batch. Defaults to 10.
        device: Device on which the model parameters and the batches are
            created. Defaults to ``"cuda"``.
    """
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 20, device=device),
        torch.nn.ReLU(),
        torch.nn.Linear(20, 1, device=device),
    )
    # NOTE(review): TorchAdamW is presumably hipdnn's AdamW-style optimizer
    # taking torch.optim-like arguments; confirm against the hipdnn docs.
    optimizer = hipdnn.TorchAdamW(
        model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=0.01
    )

    for epoch in range(epochs):
        # Synthetic data: 32 samples with 10 features and 1 regression target.
        inputs = torch.randn(32, 10, device=device)
        targets = torch.randn(32, 1, device=device)

        # Forward pass
        outputs = model(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)

        # Backward pass: clear stale gradients before computing new ones.
        optimizer.zero_grad()
        loss.backward()

        # Optimization step
        optimizer.step()

        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


if __name__ == "__main__":
    example_adamw()