import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import host_subplot
import json

filename = "/public/home/zhaoying1/work/Baichuan2-main/fine-tune/slurm_script/output/trainer_state.json"
with open(filename, "r", encoding="utf-8") as file:
    data = json.load(file)

# Fall back to an empty list so a file without "log_history" yields empty
# curves instead of a TypeError when iterating over None.
log_history_array = data.get("log_history") or []

# Keep only records that actually carry a training loss. Records without a
# "loss" value (presumably eval / end-of-training summary entries written by
# the trainer -- verify against the file) would otherwise add a None loss
# and a spurious point to the curve.
train_records = [item for item in log_history_array if item.get("loss") is not None]

step_list = [item.get("step") for item in train_records]
loss_list = [item["loss"] for item in train_records]

print("Step list:", step_list)
print("Loss list:", loss_list)



def plot_acc_loss(step_list, loss_list,
                  title="baichuan2_7bbase_ft_96c_bs1_acum1_fp16_lr2e-5",
                  save_path=None, dpi=600):
    """Plot a loss-vs-step curve, save it as an image, then display it.

    Args:
        step_list: x values (training steps), aligned with ``loss_list``.
        loss_list: y values (loss), one per entry of ``step_list``.
        title: Plot title; also used to derive the default output filename.
        save_path: Output image path. Defaults to ``title + ".jpg"``.
        dpi: Resolution passed to ``plt.savefig``.
    """
    if save_path is None:
        save_path = title + ".jpg"

    host = host_subplot(111)  # 1 row, 1 column, first subplot
    # Leave margin on the right; presumably kept from an earlier layout that
    # drew a second y-axis sharing the x axis.
    plt.subplots_adjust(right=0.8)

    host.set_xlabel("steps")
    host.set_ylabel("loss")

    p1, = host.plot(step_list, loss_list, label="loss")
    host.legend(loc=5)  # loc=5 == "right"

    # Color the y-axis label to match the plotted curve.
    host.axis["left"].label.set_color(p1.get_color())

    plt.title(title)
    # Save BEFORE show(): with interactive backends, show() blocks until the
    # window is closed and the figure is then torn down, so a savefig() issued
    # afterwards writes an empty image.
    plt.savefig(save_path, dpi=dpi)
    plt.show()

# Render, save and display the loss curve built from the training log.
plot_acc_loss(step_list=step_list, loss_list=loss_list)
