"""Plot the training-loss curve recorded in a ``trainer_state.json`` file.

When run as a script it reads ``filename``, prints the extracted steps and
losses, then draws the loss curve and saves it as ``<RUN_NAME>.jpg``.
"""
import json

filename = "/public/home/zhaoying1/work/Baichuan2-main/fine-tune/slurm_script/output/trainer_state.json"
# Run label used for the plot title and the default output image name.
RUN_NAME = "baichuan2_7bbase_ft_96c_bs1_acum1_fp16_lr2e-5"


def extract_loss_history(data):
    """Return ``(step_list, loss_list)`` from a parsed trainer-state dict.

    Only ``data["log_history"]`` entries that carry a ``"loss"`` value are
    kept. Other entries would otherwise become ``None`` points on the curve;
    presumably these are evaluation or end-of-training summary records, so
    confirm this for your run. A missing or empty ``log_history`` yields two
    empty lists.
    """
    step_list = []
    loss_list = []
    for item in data.get("log_history") or []:
        loss = item.get("loss")
        if loss is None:
            continue
        step_list.append(item.get("step"))
        loss_list.append(loss)
    return step_list, loss_list


def load_loss_history(path=filename):
    """Read the JSON file at *path* and return ``(step_list, loss_list)``."""
    with open(path, "r", encoding="utf-8") as file:
        data = json.load(file)
    return extract_loss_history(data)


def plot_acc_loss(step_list, loss_list, title=RUN_NAME, output_path=None, show=True):
    """Plot loss against step, save the figure, then optionally display it.

    Args:
        step_list: x values (training steps).
        loss_list: y values (loss values), same length as *step_list*.
        title: plot title; also the base name of the default output file.
        output_path: image file to write; defaults to ``title + ".jpg"``.
        show: whether to call ``plt.show()`` after saving.
    """
    # Imported lazily so the loading helpers above work without matplotlib.
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import host_subplot

    if output_path is None:
        output_path = title + ".jpg"

    host = host_subplot(111)  # 1 row, 1 column, first subplot
    # Pull in the right boundary of the plot area (leaves room for a twin y axis).
    plt.subplots_adjust(right=0.8)

    host.set_xlabel("steps")
    host.set_ylabel("loss")

    (p1,) = host.plot(step_list, loss_list, label="loss")
    host.legend(loc=5)  # location code 5 == "right"
    # Colour the y-axis label to match the curve.
    host.axis["left"].label.set_color(p1.get_color())

    plt.title(title)
    # Save BEFORE show(): show() may close/clear the figure, so saving
    # afterwards (as the original did) can write an empty image.
    plt.savefig(output_path, dpi=600)
    if show:
        plt.show()


if __name__ == "__main__":
    step_list, loss_list = load_loss_history(filename)
    print("Step list:", step_list)
    print("Loss list:", loss_list)
    plot_acc_loss(step_list, loss_list)