".circleci/unittest/linux/vscode:/vscode.git/clone" did not exist on "285753619da8bc411ab75b0a4550398d5a7697a7"
Commit 7ea81099 authored by chenych's avatar chenych
Browse files

update llama4

parent 84987715
This diff is collapsed.
This diff is collapsed.
......@@ -39,7 +39,7 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional
import numpy as np
import torch
......@@ -59,7 +59,7 @@ if TYPE_CHECKING:
class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
......@@ -69,7 +69,7 @@ class Evaluator:
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
......@@ -88,7 +88,7 @@ class Evaluator:
)
with open(mapping, encoding="utf-8") as f:
categorys: Dict[str, Dict[str, str]] = json.load(f)
categorys: dict[str, dict[str, str]] = json.load(f)
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
......@@ -136,7 +136,7 @@ class Evaluator:
pbar.close()
self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
score_info = "\n".join(
[
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
......
......@@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple
from ..data import Role
from ..extras.constants import CHOICES
......@@ -25,20 +24,19 @@ class EvalTemplate:
choice: str
answer: str
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r"""
def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
r"""Parse eval example.
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
output: a tuple of (prompt, response).
"""
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
def format_example(
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
) -> List[Dict[str, str]]:
r"""
Converts dataset examples to messages.
"""
self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str
) -> list[dict[str, str]]:
r"""Convert dataset examples to messages."""
messages = []
for k in range(len(support_set)):
prompt, response = self._parse_example(support_set[k])
......@@ -52,7 +50,7 @@ class EvalTemplate:
return messages
eval_templates: Dict[str, "EvalTemplate"] = {}
eval_templates: dict[str, "EvalTemplate"] = {}
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
......
This diff is collapsed.
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
......@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
VERSION = "0.9.2"
VERSION = "0.9.3.dev0"
def print_env() -> None:
......
This diff is collapsed.
This diff is collapsed.
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
......@@ -97,3 +97,7 @@ def is_uvicorn_available():
def is_vllm_available():
return _is_package_available("vllm")
def is_sglang_available():
return _is_package_available("sglang")
......@@ -15,7 +15,7 @@
import json
import math
import os
from typing import Any, Dict, List
from typing import Any
from transformers.trainer import TRAINER_STATE_NAME
......@@ -31,10 +31,8 @@ if is_matplotlib_available():
logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
def smooth(scalars: list[float]) -> list[float]:
r"""EMA implementation according to TensorBoard."""
if len(scalars) == 0:
return []
......@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""
Plots loss curves in LlamaBoard.
"""
def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""Plot loss curves in LlamaBoard."""
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
......@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
return fig
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""
def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
r"""Plot loss curves and saves the image."""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
data = json.load(f)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment