"vscode:/vscode.git/clone" did not exist on "6060ff4fc2659f7be4480fa03188e8682b55f1ce"
ploting.py 3.09 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import os
from typing import Any

from transformers.trainer import TRAINER_STATE_NAME

from . import logging
from .packages import is_matplotlib_available


# matplotlib is an optional dependency: it is imported only when available so that
# this module can still be imported without it. The plotting functions in this
# module reference ``plt`` directly, so calling them without matplotlib installed
# raises NameError.
if is_matplotlib_available():
    import matplotlib.figure
    import matplotlib.pyplot as plt


logger = logging.get_logger(__name__)


def smooth(scalars: list[float]) -> list[float]:
    r"""EMA implementation according to TensorBoard.

    Each output point is ``prev * weight + (1 - weight) * current``, where
    ``prev`` starts at the first input value and then tracks the previous
    smoothed output. The weight is derived from the series length via a scaled
    sigmoid. Short series are therefore smoothed lightly, and longer series
    more heavily.

    Args:
        scalars: The raw values to smooth.

    Returns:
        A new list with the same length as ``scalars``. Empty input yields ``[]``.
    """
    if not scalars:
        return []

    # For n >= 1, 1 / (1 + exp(-0.05 * n)) lies in (0.5, 1), so the weight below
    # lies in (0, 0.9) and grows towards 0.9 as the series gets longer.
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)
    smoothed = []
    last = scalars[0]
    for next_val in scalars:
        last = last * weight + (1 - weight) * next_val
        smoothed.append(last)
    return smoothed


def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
    r"""Plot loss curves in LlamaBoard.

    Args:
        trainer_log: Log records. Every record whose ``"loss"`` value is not
            ``None`` contributes one point at ``record["current_steps"]``.
            Records without a loss are skipped.

    Returns:
        A matplotlib figure with the raw loss curve (translucent) and its
        EMA-smoothed counterpart (see ``smooth``).
    """
    plt.close("all")  # close figures left over from earlier calls so they don't pile up
    plt.switch_backend("agg")  # non-interactive backend: render off-screen
    fig = plt.figure()
    ax = fig.add_subplot(111)

    steps, losses = [], []
    for log in trainer_log:
        loss = log.get("loss")
        # Compare against None rather than relying on truthiness: a genuine loss
        # of 0.0 is falsy but must still be plotted.
        if loss is not None:
            steps.append(log["current_steps"])
            losses.append(loss)

    ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
    ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
    ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    return fig


def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
    r"""Plot metric curves from a saved trainer state and save each as a PNG.

    Reads ``TRAINER_STATE_NAME`` from ``save_dictionary``. For every key in
    ``keys``, it draws the raw values found in ``log_history`` against their
    ``step``, plus an EMA-smoothed curve (see ``smooth``). Each figure is
    written to ``<save_dictionary>/training_<key>.png``, with any ``/`` in the
    key replaced by ``_``.

    Args:
        save_dictionary: Directory that holds the trainer state file. Figures
            are saved into the same directory. The parameter name is kept as-is
            for caller compatibility.
        keys: ``log_history`` entry keys to plot. Keys that never appear are
            skipped with a warning. The default list is only read, never
            mutated, so sharing it across calls is safe.
    """
    plt.switch_backend("agg")  # non-interactive backend: render straight to files
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for entry in data["log_history"]:
            if key in entry:
                steps.append(entry["step"])
                metrics.append(entry[key])

        if not metrics:
            logger.warning_rank0(f"No metric {key} to plot.")
            continue

        figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
        fig = plt.figure()
        try:
            ax = fig.add_subplot(111)
            ax.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
            ax.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
            ax.set_title(f"training {key} of {save_dictionary}")
            ax.set_xlabel("step")
            ax.set_ylabel(key)
            ax.legend()
            fig.savefig(figure_path, format="png", dpi=100)
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            # Without this, each plotted key (and each call) would leak a figure.
            plt.close(fig)

        print("Figure saved at:", figure_path)