plot.py 701 Bytes
Newer Older
yuguo960516's avatar
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import matplotlib.pyplot as plt


def read_losses(path):
    """Read one loss value per line from a text file.

    Blank or whitespace-only lines are skipped; every other line is parsed
    with ``float`` after surrounding whitespace is stripped.

    Args:
        path: Path of the text file to read.

    Returns:
        A list of floats in file order (empty if the file has no values).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line is not a valid float literal.
    """
    with open(path, "r", encoding="utf-8") as lines:
        # A blank line (e.g. a trailing empty line) carries no value and
        # would make float() raise, so it is filtered out before parsing.
        return [float(text) for text in map(str.strip, lines) if text]


def main():
    """Plot the losses from both log files as two curves on one figure.

    Reads ``of_losses.txt`` (labelled "oneflow") and ``torch_losses.txt``
    (labelled "pytorch") from the current working directory and shows the
    resulting figure.
    """
    of_losses = read_losses("of_losses.txt")
    torch_losses = read_losses("torch_losses.txt")

    # Each curve gets x values derived from its own length, so the two logs
    # can be plotted together even when they hold different numbers of
    # iterations.
    plt.plot(range(len(of_losses)), of_losses, label="oneflow")
    plt.plot(range(len(torch_losses)), torch_losses, label="pytorch")

    # Set the x axis label of the current axes.
    plt.xlabel("iter - axis")
    # Set the y axis label of the current axes.
    plt.ylabel("loss - axis")
    # Set a title of the current axes.
    plt.title("compare ")
    # Show a legend on the plot.
    plt.legend()
    # Display the figure.
    plt.show()


if __name__ == "__main__":
    main()