"csrc/vscode:/vscode.git/clone" did not exist on "fc1b13945c8021c446aabd7b9207286b2239a899"
Commit 7ea81099 authored by chenych's avatar chenych
Browse files

update llama4

parent 84987715
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
......@@ -20,7 +20,7 @@ import os
import sys
import warnings
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional
import torch
from accelerate.utils import DistributedDataParallelKwargs
......@@ -62,9 +62,7 @@ logger = logging.get_logger(__name__)
class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Inherits PPOTrainer.
"""
r"""Inherit PPOTrainer."""
def __init__(
self,
......@@ -72,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]],
callbacks: Optional[list["TrainerCallback"]],
model: "AutoModelForCausalLMWithValueHead",
reward_model: Optional["AutoModelForCausalLMWithValueHead"],
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
......@@ -187,9 +185,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.add_callback(BAdamCallback)
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
r"""Implement training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer."""
if resume_from_checkpoint is not None:
raise ValueError("`resume_from_checkpoint` will be supported in a future version.")
......@@ -221,9 +217,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
logger.info_rank0(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
total_train_batch_size
)
f" Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
)
logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
......@@ -247,9 +241,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.tokenizer.padding_side = "right" # change padding side
queries, responses, rewards = [], [], []
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch_queries, mini_batch_responses = self.get_inputs(
batch[idx : idx + self.config.mini_batch_size]
)
mini_batch = {
"input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
"attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
}
mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
queries.extend(mini_batch_queries)
responses.extend(mini_batch_responses)
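For context, a minimal runnable sketch of the dict-based mini-batch slicing introduced above (tensor shapes and sizes are illustrative, not the trainer's actual values):
import torch
batch = {
    "input_ids": torch.randint(0, 100, (8, 16)),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
}
mini_batch_size = 2
for idx in range(0, batch["input_ids"].size(0), mini_batch_size):
    mini_batch = {k: v[idx : idx + mini_batch_size] for k, v in batch.items()}
    # each mini_batch tensor now has shape (2, 16)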
......@@ -339,21 +335,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return lr_scheduler
@torch.no_grad()
def get_inputs(self, batch: Dict[str, "torch.Tensor"]) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:
r"""
Generates model's responses given queries.
"""
def get_inputs(self, batch: dict[str, "torch.Tensor"]) -> tuple[list["torch.Tensor"], list["torch.Tensor"]]:
r"""Generate model's responses given queries."""
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items():
batch[k] = v[:, start_index:]
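# Illustration (hypothetical values): with pad_token_id = 0 and
# input_ids[0] = [0, 0, 5, 6, 7], start_index is 2, so every tensor in the
# batch is sliced to [:, 2:], dropping the shared left padding before generation.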
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model)
generate_output: "torch.Tensor" = unwrapped_model.generate(
generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
......@@ -381,11 +375,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@torch.no_grad()
def get_rewards(
self,
queries: List["torch.Tensor"],
responses: List["torch.Tensor"],
) -> List["torch.Tensor"]:
r"""
Computes scores using given reward model.
queries: list["torch.Tensor"],
responses: list["torch.Tensor"],
) -> list["torch.Tensor"]:
r"""Compute scores using given reward model.
Both inputs and outputs are put on CPU.
"""
......@@ -394,8 +387,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
return get_rewards_from_server(self.reward_model, messages)
batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="reward")
......@@ -404,7 +397,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model = self.reward_model
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]
values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
......@@ -419,12 +412,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
model: "AutoModelForCausalLMWithValueHead",
queries: "torch.Tensor",
responses: "torch.Tensor",
model_inputs: Dict[str, Any],
model_inputs: dict[str, Any],
return_logits: bool = False,
response_masks: Optional["torch.Tensor"] = None,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r"""
Calculates model outputs in multiple batches.
) -> tuple["torch.Tensor", Optional["torch.Tensor"], "torch.Tensor", "torch.Tensor"]:
r"""Calculate model outputs in multiple batches.
Subclass and override to inject custom behavior.
"""
......@@ -483,8 +475,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
@override
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
r"""Save model checkpoint.
Subclass and override to inject custom behavior.
"""
......@@ -508,5 +499,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.save_checkpoint(output_dir)
elif self.args.should_save:
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
self._save(output_dir, state_dict=unwrapped_model.state_dict())
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
......@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
......@@ -37,7 +37,7 @@ def run_ppo(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
......@@ -53,7 +53,7 @@ def run_ppo(
reward_model = create_reward_model(model, model_args, finetuning_args)
# Initialize our Trainer
ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
ppo_trainer: CustomPPOTrainer = CustomPPOTrainer(
model_args=model_args,
training_args=training_args,
finetuning_args=finetuning_args,
......
......@@ -31,9 +31,7 @@ if TYPE_CHECKING:
class CustomTrainer(Trainer):
r"""
Inherits Trainer for custom optimizer.
"""
r"""Inherit Trainer for custom optimizer."""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
......@@ -72,3 +70,7 @@ class CustomTrainer(Trainer):
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def compute_loss(self, model, inputs, *args, **kwargs):
return super().compute_loss(model, inputs, *args, **kwargs)
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
......@@ -16,7 +16,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from transformers import DataCollatorForLanguageModeling
......@@ -38,7 +38,7 @@ def run_pt(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Optional
import numpy as np
......@@ -26,11 +26,9 @@ if TYPE_CHECKING:
@dataclass
class ComputeAccuracy:
r"""
Computes reward accuracy and supports `batch_eval_metrics`.
"""
r"""Compute reward accuracy and support `batch_eval_metrics`."""
def _dump(self) -> Optional[Dict[str, float]]:
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
......@@ -41,7 +39,7 @@ class ComputeAccuracy:
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
if not chosen_scores.shape:
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
......@@ -18,7 +18,7 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
from transformers import Trainer
......@@ -41,9 +41,7 @@ logger = logging.get_logger(__name__)
class PairwiseTrainer(Trainer):
r"""
Inherits Trainer to compute pairwise loss.
"""
r"""Inherits Trainer to compute pairwise loss."""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
......@@ -88,10 +86,9 @@ class PairwiseTrainer(Trainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Compute pairwise loss. The first n examples are chosen and the last n examples are rejected.
Subclass and override to inject custom behavior.
......@@ -113,8 +110,7 @@ class PairwiseTrainer(Trainer):
return loss
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""
Saves model predictions to `output_dir`.
r"""Save model predictions to `output_dir`.
A custom behavior that is not contained in Seq2SeqTrainer.
"""
......@@ -126,7 +122,7 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
res: list[str] = []
for c_score, r_score in zip(chosen_scores, rejected_scores):
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
......@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.ploting import plot_loss
......@@ -37,7 +37,7 @@ def run_rm(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
......
# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc., THUDM, and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library and THUDM's ChatGLM implementation.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
......@@ -17,7 +17,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Optional
import numpy as np
import torch
......@@ -37,17 +37,15 @@ if is_jieba_available():
if is_nltk_available():
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu # type: ignore
if is_rouge_available():
from rouge_chinese import Rouge
from rouge_chinese import Rouge # type: ignore
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
r"""
Computes the token with the largest likelihood to reduce memory footprint.
"""
r"""Compute the token with the largest likelihood to reduce memory footprint."""
if isinstance(logits, (list, tuple)):
if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
logits = logits[0]
......@@ -62,11 +60,9 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor
@dataclass
class ComputeAccuracy:
r"""
Computes accuracy and supports `batch_eval_metrics`.
"""
r"""Compute accuracy and support `batch_eval_metrics`."""
def _dump(self) -> Optional[Dict[str, float]]:
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
......@@ -77,7 +73,7 @@ class ComputeAccuracy:
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
for i in range(len(preds)):
pred, label = preds[i, :-1], labels[i, 1:]
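# the model predicts token t+1 from tokens up to t, hence predictions are
# compared against labels shifted left by one position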
......@@ -90,15 +86,14 @@ class ComputeAccuracy:
@dataclass
class ComputeSimilarity:
r"""
Computes text similarity scores and supports `batch_eval_metrics`.
r"""Compute text similarity scores and support `batch_eval_metrics`.
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
"""
tokenizer: "PreTrainedTokenizer"
def _dump(self) -> Optional[Dict[str, float]]:
def _dump(self) -> Optional[dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
......@@ -109,7 +104,7 @@ class ComputeSimilarity:
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
......@@ -18,7 +18,7 @@
import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
import torch
......@@ -44,23 +44,24 @@ logger = logging.get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
"""
r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""
def __init__(
self,
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
gen_kwargs: Optional[Dict[str, Any]] = None,
gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
super().__init__(**kwargs)
if processor is not None:
self.model_accepts_loss_kwargs = False
self.finetuning_args = finetuning_args
if gen_kwargs is not None:
# https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
......@@ -95,17 +96,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return super()._get_train_sampler()
@override
def compute_loss(self, model, inputs, *args, **kwargs):
return super().compute_loss(model, inputs, *args, **kwargs)
@override
def prediction_step(
self,
model: "torch.nn.Module",
inputs: Dict[str, Union["torch.Tensor", Any]],
inputs: dict[str, Union["torch.Tensor", Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
ignore_keys: Optional[list[str]] = None,
**gen_kwargs,
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Removes the prompt part in the generated tokens.
) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""Remove the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
"""
......@@ -126,8 +130,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def save_predictions(
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
) -> None:
r"""
Saves model predictions to `output_dir`.
r"""Save model predictions to `output_dir`.
A custom behavior that is not contained in Seq2SeqTrainer.
"""
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
......@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
......@@ -43,7 +43,7 @@ def run_sft(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
from peft import PeftModel
......@@ -32,7 +32,7 @@ if TYPE_CHECKING:
from ..data.data_utils import DatasetModule
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None:
def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: list[str] = []) -> None:
state_dict_a = model_a.state_dict()
state_dict_b = model_b.state_dict()
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
......@@ -43,7 +43,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
linear_modules, extra_modules = set(), set()
for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]):
......@@ -83,7 +83,7 @@ def load_reference_model(
) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device()
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=current_device
)
if not is_trainable:
......@@ -111,7 +111,7 @@ def load_dataset_module(**kwargs) -> "DatasetModule":
def patch_valuehead_model() -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: dict[str, "torch.Tensor"]) -> None:
state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
self.v_head.load_state_dict(state_dict, strict=False)
del state_dict
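Example of the prefix strip above (hypothetical keys): len("v_head.") == 7, so k[7:] drops the prefix.
sd = {"v_head.summary.weight": 1, "v_head.summary.bias": 2, "lm_head.weight": 3}
{k[7:]: sd[k] for k in sd if k.startswith("v_head.")}  # -> {"summary.weight": 1, "summary.bias": 2}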
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
......@@ -21,7 +21,7 @@ import json
import os
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from transformers import Trainer
......@@ -63,12 +63,10 @@ logger = logging.get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer):
r"""
A dummy optimizer used for the GaLore or APOLLO algorithm.
"""
r"""A dummy optimizer used for the GaLore or APOLLO algorithm."""
def __init__(
self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
self, lr: float = 1e-3, optimizer_dict: Optional[dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
) -> None:
dummy_tensor = torch.randn(1, 1)
self.optimizer_dict = optimizer_dict
......@@ -112,8 +110,7 @@ def create_modelcard_and_push(
def create_ref_model(
model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
r"""Create reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
......@@ -148,9 +145,7 @@ def create_ref_model(
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
) -> Optional["AutoModelForCausalLMWithValueHead"]:
r"""
Creates reward model for PPO training.
"""
r"""Create reward model for PPO training."""
if finetuning_args.reward_model_type == "api":
assert finetuning_args.reward_model.startswith("http"), "Please provide a full URL."
logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
......@@ -189,10 +184,8 @@ def create_reward_model(
return reward_model
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
"""
def _get_decay_parameter_names(model: "PreTrainedModel") -> list[str]:
r"""Return a list of names of parameters with weight decay. (weights in non-layernorm layers)."""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
return decay_parameters
......@@ -208,7 +201,7 @@ def _create_galore_optimizer(
else:
galore_targets = finetuning_args.galore_target
galore_params: List["torch.nn.Parameter"] = []
galore_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
for param in module.parameters():
......@@ -224,7 +217,7 @@ def _create_galore_optimizer(
id_galore_params = {id(param) for param in galore_params}
decay_params, nodecay_params = [], [] # they are non-galore parameters
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
trainable_params: list[torch.nn.Parameter] = [] # galore_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
......@@ -251,7 +244,7 @@ def _create_galore_optimizer(
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
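A minimal sketch (hypothetical names and sizes) of the per-parameter optimizer mapping that layer-wise GaLore relies on:
import torch
params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
optimizer_dict = {
    p: torch.optim.AdamW([{"params": [p], "weight_decay": 0.0}], lr=1e-3)
    for p in params
}
# each parameter is stepped by its own optimizer, typically from a
# post-accumulate gradient hook, so no global optimizer.step() is needed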
......@@ -296,7 +289,7 @@ def _create_apollo_optimizer(
else:
apollo_targets = finetuning_args.apollo_target
apollo_params: List["torch.nn.Parameter"] = []
apollo_params: list[torch.nn.Parameter] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
for param in module.parameters():
......@@ -315,7 +308,7 @@ def _create_apollo_optimizer(
id_apollo_params = {id(param) for param in apollo_params}
decay_params, nodecay_params = [], [] # they are non-apollo parameters
trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params
trainable_params: list[torch.nn.Parameter] = [] # apollo_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
......@@ -338,7 +331,7 @@ def _create_apollo_optimizer(
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
optimizer_dict: dict[torch.Tensor, torch.optim.Optimizer] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
......@@ -380,7 +373,7 @@ def _create_loraplus_optimizer(
embedding_lr = finetuning_args.loraplus_lr_embedding
decay_param_names = _get_decay_parameter_names(model)
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
param_dict: dict[str, list[torch.nn.Parameter]] = {
"lora_a": [],
"lora_b": [],
"lora_b_nodecay": [],
......@@ -522,9 +515,25 @@ def create_custom_scheduler(
num_training_steps: int,
optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
if training_args.lr_scheduler_type == "warmup_stable_decay":
num_warmup_steps = training_args.get_warmup_steps(num_training_steps)
remaining_steps = num_training_steps - num_warmup_steps
num_stable_steps = remaining_steps // 3 # use 1/3 for stable by default
num_decay_steps = remaining_steps - num_stable_steps
scheduler_kwargs = training_args.lr_scheduler_kwargs or {}
default_kwargs = {
"num_stable_steps": num_stable_steps,
"num_decay_steps": num_decay_steps,
}
for key, value in default_kwargs.items():
if key not in scheduler_kwargs:
scheduler_kwargs[key] = value
training_args.lr_scheduler_kwargs = scheduler_kwargs
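# For illustration (hypothetical numbers): with num_training_steps = 1200 and
# num_warmup_steps = 120, remaining_steps = 1080, num_stable_steps = 360 and
# num_decay_steps = 720, so warmup + stable + decay == 1200.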
if optimizer is not None and isinstance(optimizer, DummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
scheduler_dict: dict[torch.nn.Parameter, torch.optim.lr_scheduler.LRScheduler] = {}
for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
......@@ -544,13 +553,13 @@ def create_custom_scheduler(
def get_batch_logps(
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
) -> Tuple["torch.Tensor", "torch.Tensor"]:
r"""
Computes the log probabilities of the given labels under the given logits.
) -> tuple["torch.Tensor", "torch.Tensor"]:
r"""Compute the log probabilities of the given labels under the given logits.
Returns:
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
......@@ -564,12 +573,10 @@ def get_batch_logps(
def nested_detach(
tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False,
):
r"""
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
"""
r"""Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."""
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
elif isinstance(tensors, Mapping):
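Hypothetical usage: the nesting structure is preserved and every tensor is detached.
import torch
x = {"a": torch.ones(2, requires_grad=True), "b": (torch.zeros(1, requires_grad=True),)}
y = nested_detach(x)  # same dict/tuple layout, all tensors detached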
......@@ -585,15 +592,22 @@ def nested_detach(
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
r"""
Gets the callback for logging to SwanLab.
"""
r"""Get the callback for logging to SwanLab."""
import swanlab # type: ignore
from swanlab.integration.transformers import SwanLabCallback # type: ignore
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
if finetuning_args.swanlab_lark_webhook_url is not None:
from swanlab.plugin.notification import LarkCallback # type: ignore
lark_callback = LarkCallback(
webhook_url=finetuning_args.swanlab_lark_webhook_url,
secret=finetuning_args.swanlab_lark_secret,
)
swanlab.register_callbacks([lark_callback])
class SwanLabCallbackExtension(SwanLabCallback):
def setup(self, args: "TrainingArguments", state: "TrainerState", model: "PreTrainedModel", **kwargs):
if not state.is_world_process_zero:
......@@ -624,7 +638,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
def get_ray_trainer(
training_function: Callable,
train_loop_config: Dict[str, Any],
train_loop_config: dict[str, Any],
ray_args: "RayArguments",
) -> "TorchTrainer":
if not ray_args.use_ray:
......
......@@ -14,7 +14,7 @@
import os
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
import torch.distributed as dist
......@@ -38,6 +38,7 @@ from .trainer_utils import get_ray_trainer, get_swanlab_callback
if is_ray_available():
import ray
from ray.train.huggingface.transformers import RayTrainReportCallback
......@@ -48,9 +49,9 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def _training_function(config: Dict[str, Any]) -> None:
def _training_function(config: dict[str, Any]) -> None:
args = config.get("args")
callbacks: List[Any] = config.get("callbacks")
callbacks: list[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks.append(LogCallback())
......@@ -77,6 +78,9 @@ def _training_function(config: Dict[str, Any]) -> None:
else:
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
if is_ray_available() and ray.is_initialized():
return  # if ray is initialized, it will destroy the process group on return
try:
if dist.is_initialized():
dist.destroy_process_group()
......@@ -84,7 +88,7 @@ def _training_function(config: Dict[str, Any]) -> None:
logger.warning(f"Failed to destroy process group: {e}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
def run_exp(args: Optional[dict[str, Any]] = None, callbacks: Optional[list["TrainerCallback"]] = None) -> None:
args = read_args(args)
if "-h" in args or "--help" in args:
get_train_args(args)
......@@ -103,7 +107,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
_training_function(config={"args": args, "callbacks": callbacks})
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
def export_model(args: Optional[dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, _ = get_infer_args(args)
if model_args.export_dir is None:
......
......@@ -14,7 +14,8 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Optional
from transformers.utils import is_torch_npu_available
......@@ -37,15 +38,12 @@ if is_gradio_available():
def _escape_html(text: str) -> str:
r"""
Escapes HTML characters.
"""
r"""Escape HTML characters."""
return text.replace("<", "&lt;").replace(">", "&gt;")
def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
r"""
Post-processes the response text.
def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
r"""Post-process the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
......@@ -74,7 +72,7 @@ class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.engine: Optional["BaseEngine"] = None
self.engine: Optional[BaseEngine] = None
if not lazy_init: # read arguments from command line
super().__init__()
......@@ -124,6 +122,7 @@ class WebChatModel(ChatModel):
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),
vllm_enforce_eager=True,
trust_remote_code=True,
)
......@@ -160,14 +159,13 @@ class WebChatModel(ChatModel):
@staticmethod
def append(
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
role: str,
query: str,
escape_html: bool,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
r"""
Adds the user input to chatbot.
) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
r"""Add the user input to chatbot.
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Output: infer.chatbot, infer.messages, infer.query
......@@ -180,8 +178,8 @@ class WebChatModel(ChatModel):
def stream(
self,
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
lang: str,
system: str,
tools: str,
......@@ -193,9 +191,8 @@ class WebChatModel(ChatModel):
temperature: float,
skip_special_tokens: bool,
escape_html: bool,
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
r"""
Generates output text in stream.
) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
r"""Generate output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages
......
......@@ -17,7 +17,7 @@ import os
import signal
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from psutil import Process
from yaml import safe_dump, safe_load
......@@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"
def abort_process(pid: int) -> None:
r"""
Aborts the processes recursively in a bottom-up way.
"""
r"""Abort the processes recursively in a bottom-up way."""
try:
children = Process(pid).children()
if children:
......@@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:
def get_save_dir(*paths: str) -> os.PathLike:
r"""
Gets the path to saved model checkpoints.
"""
r"""Get the path to saved model checkpoints."""
if os.path.sep in paths[-1]:
logger.warning_rank0("Found complex path, some features may be not available.")
return paths[-1]
......@@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:
def _get_config_path() -> os.PathLike:
r"""
Gets the path to user config.
"""
r"""Get the path to user config."""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
r"""
Loads user config if exists.
"""
def load_config() -> dict[str, Union[str, dict[str, Any]]]:
r"""Load user config if exists."""
try:
with open(_get_config_path(), encoding="utf-8") as f:
return safe_load(f)
......@@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
r"""
Saves user config.
"""
r"""Save user config."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
......@@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
def get_model_path(model_name: str) -> str:
r"""
Gets the model path according to the model name.
"""
r"""Get the model path according to the model name."""
user_config = load_config()
path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
if (
use_modelscope()
......@@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str:
def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat/distill/instruct model.
"""
r"""Get the template name if the model is a chat/distill/instruct model."""
return DEFAULT_TEMPLATE.get(model_name, "default")
def get_time() -> str:
r"""
Gets current date and time.
"""
r"""Get current date and time."""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def is_multimodal(model_name: str) -> bool:
r"""
Judges if the model is a vision language model.
"""
r"""Judge if the model is a vision language model."""
return model_name in MULTIMODAL_SUPPORTED_MODELS
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
r"""
Loads dataset_info.json.
"""
def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
r"""Load dataset_info.json."""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
return {}
......@@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {}
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Loads the training configuration from config path.
"""
def load_args(config_path: str) -> Optional[dict[str, Any]]:
r"""Load the training configuration from config path."""
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
......@@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
r"""
Saves the training configuration to config path.
"""
def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
r"""Save the training configuration to config path."""
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args with NoneType or False or empty string value.
"""
def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
r"""Remove args with NoneType or False or empty string value."""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
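Hypothetical input and output for the filter above:
args = {"packing": False, "lora_rank": 8, "adapter": "", "quantization_bit": None}
_clean_cmd(args)  # -> {"packing": False, "lora_rank": 8}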
def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates CLI commands for previewing.
"""
def gen_cmd(args: dict[str, Any]) -> str:
r"""Generate CLI commands for previewing."""
cmd_lines = ["llamafactory-cli train "]
for k, v in _clean_cmd(args).items():
if isinstance(v, dict):
......@@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
return cmd_text
def save_cmd(args: Dict[str, Any]) -> str:
r"""
Saves CLI commands to launch training.
"""
def save_cmd(args: dict[str, Any]) -> str:
r"""Save CLI commands to launch training."""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
......@@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:
def load_eval_results(path: os.PathLike) -> str:
r"""
Gets scores after evaluation.
"""
r"""Get scores after evaluation."""
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
......@@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:
def create_ds_config() -> None:
r"""
Creates deepspeed config in the current directory.
"""
r"""Create deepspeed config in the current directory."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import json
from typing import TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING
from ...data import Role
from ...extras.packages import is_gradio_available
......@@ -31,9 +31,7 @@ if TYPE_CHECKING:
def check_json_schema(text: str, lang: str) -> None:
r"""
Checks if the json schema is valid.
"""
r"""Check if the json schema is valid."""
try:
tools = json.loads(text)
if tools:
......@@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None:
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
) -> tuple["Component", "Component", dict[str, "Component"]]:
lang = engine.manager.get_elem_by_id("top.lang")
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(type="messages", show_copy_button=True)
......
......@@ -14,7 +14,7 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing import TYPE_CHECKING, Any
from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available
......@@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
r"""
Checks if the dataset is a local dataset.
"""
r"""Check if the dataset is a local dataset."""
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
......@@ -59,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
return gr.Button(interactive=False)
def _load_data_file(file_path: str) -> List[Any]:
def _load_data_file(file_path: str) -> list[Any]:
with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
......@@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]:
return list(f)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
r"""
Gets the preview samples from the dataset.
"""
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> tuple[int, list, "gr.Column"]:
r"""Get the preview samples from the dataset."""
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
......@@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
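# With PAGE_SIZE = 2 (hypothetical) and data = ["a", "b", "c", "d", "e"],
# page_index = 1 yields data[2:4] == ["c", "d"].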
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> dict[str, "Component"]:
data_preview_btn = gr.Button(interactive=False, scale=1)
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR
......@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_eval_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
......
......@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Generator, List, Union
from collections.abc import Generator
from typing import TYPE_CHECKING, Union
from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
......@@ -35,7 +36,7 @@ if TYPE_CHECKING:
GPTQ_BITS = ["8", "4", "3", "2"]
def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False)
else:
......@@ -47,7 +48,7 @@ def save_model(
model_name: str,
model_path: str,
finetuning_type: str,
checkpoint_path: Union[str, List[str]],
checkpoint_path: Union[str, list[str]],
template: str,
export_size: int,
export_quantization_bit: str,
......@@ -106,7 +107,7 @@ def save_model(
yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from ...extras.packages import is_gradio_available
from ..common import is_multimodal
......@@ -29,12 +29,12 @@ if TYPE_CHECKING:
from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
with gr.Row():
infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
infer_backend = gr.Dropdown(choices=["huggingface", "vllm", "sglang"], value="huggingface")
infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
with gr.Row():
......