Commit 8293100a authored by luopl

update to 0.9.2.dev0

parent 2778a3d0
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import cal_effective_tokens
+from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
@@ -48,6 +48,7 @@ def run_dpo(
    data_collator = PairwiseDataCollatorWithPadding(
        template=template,
+        model=model,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        **tokenizer_module,
@@ -65,12 +66,6 @@ def run_dpo(
    # Update arguments
    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
-    effective_token_num = 0.0
-    if finetuning_args.include_effective_tokens_per_second:
-        for data in dataset_module["train_dataset"]:
-            effective_token_num += len(data["chosen_input_ids"])
-            effective_token_num += len(data["rejected_input_ids"])
    # Initialize our Trainer
    trainer = CustomDPOTrainer(
        model=model,
@@ -86,13 +81,12 @@ def run_dpo(
    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
-        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
-            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
-                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="rm"
            )
+        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
...
@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
from transformers import Trainer
@@ -28,9 +28,9 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
if TYPE_CHECKING:
@@ -50,6 +50,9 @@ class CustomKTOTrainer(KTOTrainer):
        disable_dropout: bool = True,
        **kwargs,
    ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
@@ -77,6 +80,7 @@ class CustomKTOTrainer(KTOTrainer):
        self.ftx_gamma = finetuning_args.pref_ftx
        Trainer.__init__(self, model=model, **kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")
@@ -119,6 +123,9 @@ class CustomKTOTrainer(KTOTrainer):
        r"""
        Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
        """
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
        return Trainer._get_train_sampler(self)
    @override
@@ -135,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
        r"""
        Runs forward pass and computes the log probabilities.
        """
-        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        batch = nested_detach(batch, clone=True)  # avoid error
        model_inputs = {
            "input_ids": batch[f"{prefix}input_ids"],
            "attention_mask": batch[f"{prefix}attention_mask"],
@@ -245,17 +252,18 @@ class CustomKTOTrainer(KTOTrainer):
        return losses, metrics
    @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
        r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
        """
        loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
            if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
            else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps
        return loss
...
@@ -47,6 +47,7 @@ def run_kto(
    data_collator = KTODataCollatorWithPadding(
        template=template,
+        model=model,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        **tokenizer_module,
...
@@ -46,7 +46,7 @@ def run_ppo(
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
    tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
-    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
+    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, model=model, **tokenizer_module)
    # Create reference model and reward model
    ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
...
@@ -13,19 +13,19 @@
# limitations under the License.
from types import MethodType
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+import torch
from transformers import Trainer
from typing_extensions import override
-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ...extras.packages import is_transformers_version_greater_than
+from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
-    import torch
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedModel, ProcessorMixin
    from ...hparams import FinetuningArguments
@@ -38,15 +38,15 @@ class CustomTrainer(Trainer):
    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))
-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
@@ -67,17 +67,26 @@ class CustomTrainer(Trainer):
        return super().create_scheduler(num_training_steps, optimizer)
    @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+        return super()._get_train_sampler()
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
        r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
        """
        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
+            # other model should not scale the loss
            if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
            else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps
        return loss
@@ -25,8 +25,8 @@ from transformers import Trainer
from typing_extensions import override
from ...extras import logging
-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -48,7 +48,11 @@ class PairwiseTrainer(Trainer):
    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
        super().__init__(**kwargs)
+        self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
        self.finetuning_args = finetuning_args
        self.can_return_loss = True  # override property to return eval_loss
        self.add_callback(FixValueHeadModelCallback)
@@ -56,9 +60,6 @@ class PairwiseTrainer(Trainer):
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))
-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
@@ -78,6 +79,13 @@ class PairwiseTrainer(Trainer):
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)
+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+        return super()._get_train_sampler()
    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
@@ -100,8 +108,8 @@ class PairwiseTrainer(Trainer):
        loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
-        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
+        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
+            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0-4.46.1
        if return_outputs:
            return loss, (loss, chosen_scores, rejected_scores)
...
@@ -44,7 +44,9 @@ def run_rm(
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)
+    data_collator = PairwiseDataCollatorWithPadding(
+        template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
+    )
    # Update arguments
    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
...
@@ -27,14 +27,14 @@ from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
-from ..callbacks import PissaConvertCallback, SaveProcessorCallback
+from ...extras.packages import is_transformers_version_greater_than
+from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
    from torch.utils.data import Dataset
-    from transformers import ProcessorMixin
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from transformers.trainer import PredictionOutput
    from ...hparams import FinetuningArguments
@@ -51,15 +51,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
    ) -> None:
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+        else:
+            self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))
-        if finetuning_args.pissa_convert:
-            self.add_callback(PissaConvertCallback)
        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
@@ -80,18 +82,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        return super().create_scheduler(num_training_steps, optimizer)
    @override
-    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+        return super()._get_train_sampler()
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
        r"""
-        Fixes the loss value for transformers 4.46.0.
-        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
        """
        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
+            # other model should not scale the loss
            if return_outputs:
-                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
            else:
-                return loss / self.args.gradient_accumulation_steps
+                loss = loss / self.args.gradient_accumulation_steps
        return loss
@@ -102,41 +113,30 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        inputs: Dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
    ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""
        Removes the prompt part in the generated tokens.
        Subclass and override to inject custom behavior.
        """
-        labels = inputs["labels"] if "labels" in inputs else None
-        if self.args.predict_with_generate:
-            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
-            labels = labels.detach().clone() if labels is not None else None  # backup labels
-            prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-            if prompt_len > label_len:
-                inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
-            if label_len > prompt_len:  # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
-                inputs["labels"] = inputs["labels"][:, :prompt_len]
-        loss, generated_tokens, _ = super().prediction_step(  # ignore the returned labels (may be truncated)
-            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        if self.args.predict_with_generate:  # do not pass labels to model when generate
+            labels = inputs.pop("labels", None)
+        else:
+            labels = inputs.get("labels")
+        loss, generated_tokens, _ = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
        if generated_tokens is not None and self.args.predict_with_generate:
-            generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
+            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
            generated_tokens = generated_tokens.contiguous()
        return loss, generated_tokens, labels
-    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
-        r"""
-        Pads the tensor to the same length as the target tensor.
-        """
-        assert self.tokenizer.pad_token_id is not None, "Pad token is required."
-        padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
-        padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
-        return padded_tensor.contiguous()  # in contiguous memory
-    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
+    def save_predictions(
+        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
+    ) -> None:
        r"""
        Saves model predictions to `output_dir`.
@@ -149,24 +149,23 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
        labels = np.where(
-            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
+            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
        )
        preds = np.where(
-            predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
+            predict_results.predictions != IGNORE_INDEX,
+            predict_results.predictions,
+            self.processing_class.pad_token_id,
        )
        for i in range(len(preds)):
-            pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
+            pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
            if len(pad_len):  # move pad token to last
                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
-        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
-        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
+        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
+        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
-        with open(output_prediction_file, "w", encoding="utf-8") as writer:
-            res: List[str] = []
-            for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
-                res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
-            writer.write("\n".join(res))
+        with open(output_prediction_file, "w", encoding="utf-8") as f:
+            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
+                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import cal_effective_tokens, get_logits_processor
+from ...extras.logging import get_logger
+from ...extras.misc import calculate_tps, get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
@@ -33,6 +34,9 @@ if TYPE_CHECKING:
    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+logger = get_logger(__name__)
def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
@@ -52,6 +56,7 @@ def run_sft(
    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
+        model=model if not training_args.predict_with_generate else None,
        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        block_diag_attn=model_args.block_diag_attn,
@@ -65,11 +70,6 @@ def run_sft(
    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
    training_args.remove_unused_columns = False  # important for multimodal dataset
-    effective_token_num = 0.0
-    if finetuning_args.include_effective_tokens_per_second:
-        for data in dataset_module["train_dataset"]:
-            effective_token_num += len(data["input_ids"])
    # Metric utils
    metric_module = {}
    if training_args.predict_with_generate:
@@ -91,7 +91,7 @@ def run_sft(
    )
    # Keyword arguments for `model.generate`
-    gen_kwargs = generating_args.to_dict()
+    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()
@@ -99,12 +99,12 @@ def run_sft(
    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
-        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
-            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
-                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="sft"
            )
+        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
@@ -117,19 +117,16 @@ def run_sft(
    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
-        if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
-            metrics.pop("eval_loss", None)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
    # Predict
    if training_args.do_predict:
+        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
-        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
-            predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
@@ -17,7 +17,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
@@ -30,20 +32,29 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available
+from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
if is_galore_available():
-    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
+if is_apollo_available():
+    from apollo_torch import APOLLOAdamW  # type: ignore
+if is_ray_available():
+    from ray.train import RunConfig, ScalingConfig
+    from ray.train.torch import TorchTrainer
if TYPE_CHECKING:
-    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
+    from transformers import PreTrainedModel, TrainerCallback
    from trl import AutoModelForCausalLMWithValueHead
-    from ..hparams import DataArguments
+    from ..hparams import DataArguments, RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
@@ -51,7 +62,7 @@ logger = logging.get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer):
    r"""
-    A dummy optimizer used for the GaLore algorithm.
+    A dummy optimizer used for the GaLore or APOLLO algorithm.
    """
    def __init__(
@@ -74,7 +85,7 @@ def create_modelcard_and_push(
    trainer: "Trainer",
    model_args: "ModelArguments",
    data_args: "DataArguments",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> None:
    kwargs = {
@@ -187,7 +198,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
def _create_galore_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
@@ -231,9 +242,10 @@ def _create_galore_optimizer(
    elif training_args.optim == "adafactor":
        optim_class = GaLoreAdafactor
    else:
-        raise NotImplementedError(f"Unknow optim: {training_args.optim}")
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
    if finetuning_args.galore_layerwise:
+        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
        if training_args.gradient_accumulation_steps != 1:
            raise ValueError("Per-layer GaLore does not support gradient accumulation.")
@@ -265,13 +277,100 @@ def _create_galore_optimizer(
        ]
        optimizer = optim_class(param_groups, **optim_kwargs)
-    logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+    logger.info_rank0(
+        f"Using GaLore optimizer with args: {galore_kwargs}. "
+        "It may cause hanging at the start of training, wait patiently."
+    )
+    return optimizer
+def _create_apollo_optimizer(
+    model: "PreTrainedModel",
+    training_args: "TrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
+        apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
+    else:
+        apollo_targets = finetuning_args.apollo_target
+    apollo_params: List["torch.nn.Parameter"] = []
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
+            for param in module.parameters():
+                if param.requires_grad and len(param.shape) > 1:
+                    apollo_params.append(param)
+    apollo_kwargs = {
+        "rank": finetuning_args.apollo_rank,
+        "proj": finetuning_args.apollo_proj,
+        "proj_type": finetuning_args.apollo_proj_type,
+        "update_proj_gap": finetuning_args.apollo_update_interval,
+        "scale": finetuning_args.apollo_scale,
+        "scale_type": finetuning_args.apollo_scale_type,
+        "scale_front": finetuning_args.apollo_scale_front,
+    }
+    id_apollo_params = {id(param) for param in apollo_params}
+    decay_params, nodecay_params = [], []  # they are non-apollo parameters
+    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
+    decay_param_names = _get_decay_parameter_names(model)
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            trainable_params.append(param)
+            if id(param) not in id_apollo_params:
+                if name in decay_param_names:
+                    decay_params.append(param)
+                else:
+                    nodecay_params.append(param)
+    _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    if training_args.optim == "adamw_torch":
+        optim_class = APOLLOAdamW
+    else:
+        raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
+    if finetuning_args.apollo_layerwise:
+        logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
+        if training_args.gradient_accumulation_steps != 1:
+            raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
+        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        for param in nodecay_params:
+            param_groups = [dict(params=[param], weight_decay=0.0)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+        for param in decay_params:
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+        for param in apollo_params:  # apollo params have weight decay
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+        def optimizer_hook(param: "torch.nn.Parameter"):
+            if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
+        for param in trainable_params:
+            param.register_post_accumulate_grad_hook(optimizer_hook)
+        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
+    else:
+        param_groups = [
+            dict(params=nodecay_params, weight_decay=0.0),
+            dict(params=decay_params, weight_decay=training_args.weight_decay),
+            dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
+        ]
+        optimizer = optim_class(param_groups, **optim_kwargs)
+    logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
    return optimizer
def _create_loraplus_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    default_lr = training_args.learning_rate
@@ -311,7 +410,7 @@ def _create_loraplus_optimizer(
def _create_badam_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    decay_params, nodecay_params = [], []
@@ -330,7 +429,7 @@ def _create_badam_optimizer(
    ]
    if finetuning_args.badam_mode == "layer":
-        from badam import BlockOptimizer
+        from badam import BlockOptimizer  # type: ignore
        base_optimizer = optim_class(param_groups, **optim_kwargs)
        optimizer = BlockOptimizer(
@@ -350,7 +449,7 @@ def _create_badam_optimizer(
        )
    elif finetuning_args.badam_mode == "ratio":
-        from badam import BlockOptimizerRatio
+        from badam import BlockOptimizerRatio  # type: ignore
        assert finetuning_args.badam_update_ratio > 1e-6
        optimizer = BlockOptimizerRatio(
@@ -372,9 +471,9 @@ def _create_badam_optimizer(
def _create_adam_mini_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
-    from adam_mini import Adam_mini
+    from adam_mini import Adam_mini  # type: ignore
    hidden_size = getattr(model.config, "hidden_size", None)
    num_q_head = getattr(model.config, "num_attention_heads", None)
@@ -397,12 +496,15 @@ def _create_adam_mini_optimizer(
def create_custom_optimizer(
    model: "PreTrainedModel",
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
    if finetuning_args.use_galore:
        return _create_galore_optimizer(model, training_args, finetuning_args)
+    if finetuning_args.use_apollo:
+        return _create_apollo_optimizer(model, training_args, finetuning_args)
    if finetuning_args.loraplus_lr_ratio is not None:
        return _create_loraplus_optimizer(model, training_args, finetuning_args)
@@ -414,7 +516,7 @@ def create_custom_optimizer(
def create_custom_scheduler(
-    training_args: "Seq2SeqTrainingArguments",
+    training_args: "TrainingArguments",
    num_training_steps: int,
    optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
@@ -457,3 +559,69 @@ def get_batch_logps(
    labels[labels == label_pad_token_id] = 0  # dummy token
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+def nested_detach(
+    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    clone: bool = False,
+):
+    r"""
+    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
+    """
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
+    if isinstance(tensors, torch.Tensor):
+        if clone:
+            return tensors.detach().clone()
+        else:
+            return tensors.detach()
+    else:
+        return tensors
+def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
+    r"""
+    Gets the callback for logging to SwanLab.
+    """
+    import swanlab  # type: ignore
+    from swanlab.integration.transformers import SwanLabCallback  # type: ignore
+    if finetuning_args.swanlab_api_key is not None:
+        swanlab.login(api_key=finetuning_args.swanlab_api_key)
+    swanlab_callback = SwanLabCallback(
+        project=finetuning_args.swanlab_project,
+        workspace=finetuning_args.swanlab_workspace,
+        experiment_name=finetuning_args.swanlab_run_name,
+        mode=finetuning_args.swanlab_mode,
+        config={"Framework": "🦙LlamaFactory"},
+    )
+    return swanlab_callback
+def get_ray_trainer(
+    training_function: Callable,
+    train_loop_config: Dict[str, Any],
+    ray_args: "RayArguments",
+) -> "TorchTrainer":
+    if not ray_args.use_ray:
+        raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
+    trainer = TorchTrainer(
+        training_function,
+        train_loop_config=train_loop_config,
+        scaling_config=ScalingConfig(
+            num_workers=ray_args.ray_num_workers,
+            resources_per_worker=ray_args.resources_per_worker,
+            placement_strategy=ray_args.placement_strategy,
+            use_gpu=True,
+        ),
+        run_config=RunConfig(
+            name=ray_args.ray_run_name,
+            storage_path=Path("./saves").absolute().as_posix(),
+        ),
+    )
+    return trainer
@@ -22,15 +22,21 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..hparams import get_infer_args, get_train_args
+from ..extras.packages import is_ray_available
+from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
-from .callbacks import LogCallback
+from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
+from .trainer_utils import get_ray_trainer, get_swanlab_callback
+if is_ray_available():
+    from ray.train.huggingface.transformers import RayTrainReportCallback
if TYPE_CHECKING:
@@ -40,10 +46,20 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
-def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
-    callbacks.append(LogCallback())
+def _training_function(config: Dict[str, Any]) -> None:
+    args = config.get("args")
+    callbacks: List[Any] = config.get("callbacks")
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
+    callbacks.append(LogCallback())
+    if finetuning_args.pissa_convert:
+        callbacks.append(PissaConvertCallback())
+    if finetuning_args.use_swanlab:
+        callbacks.append(get_swanlab_callback(finetuning_args))
+    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "sft":
@@ -60,6 +76,22 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
+    args = read_args(args)
+    ray_args = get_ray_args(args)
+    callbacks = callbacks or []
+    if ray_args.use_ray:
+        callbacks.append(RayTrainReportCallback())
+        trainer = get_ray_trainer(
+            training_function=_training_function,
+            train_loop_config={"args": args, "callbacks": callbacks},
+            ray_args=ray_args,
+        )
+        trainer.fit()
+    else:
+        _training_function(config={"args": args, "callbacks": callbacks})
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
    model_args, data_args, finetuning_args, _ = get_infer_args(args)
...
@@ -91,6 +91,7 @@ class WebChatModel(ChatModel):
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            infer_backend=get("infer.infer_backend"),
            infer_dtype=get("infer.infer_dtype"),
+            trust_remote_code=True,
        )
        if checkpoint_path:
@@ -157,7 +158,7 @@ class WebChatModel(ChatModel):
            result = response
        if isinstance(result, list):
-            tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
+            tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
            tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
            output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
            bot_text = "```json\n" + tool_calls + "\n```"
...
@@ -84,6 +84,7 @@ def save_model(
        export_quantization_dataset=export_quantization_dataset,
        export_device=export_device,
        export_legacy_format=export_legacy_format,
+        trust_remote_code=True,
    )
    if checkpoint_path:
...
@@ -234,8 +234,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
        with gr.Row():
            use_galore = gr.Checkbox()
            galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
-            galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
-            galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
+            galore_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
+            galore_scale = gr.Slider(minimum=0, maximum=100, value=2.0, step=0.1)
            galore_target = gr.Textbox(value="all")
    input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
@@ -250,6 +250,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
        )
    )
+    with gr.Accordion(open=False) as apollo_tab:
+        with gr.Row():
+            use_apollo = gr.Checkbox()
+            apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
+            apollo_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
+            apollo_scale = gr.Slider(minimum=0, maximum=100, value=32.0, step=0.1)
+            apollo_target = gr.Textbox(value="all")
+    input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
+    elem_dict.update(
+        dict(
+            apollo_tab=apollo_tab,
+            use_apollo=use_apollo,
+            apollo_rank=apollo_rank,
+            apollo_update_interval=apollo_update_interval,
+            apollo_scale=apollo_scale,
+            apollo_target=apollo_target,
+        )
+    )
    with gr.Accordion(open=False) as badam_tab:
        with gr.Row():
            use_badam = gr.Checkbox()
@@ -270,6 +290,30 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
        )
    )
+    with gr.Accordion(open=False) as swanlab_tab:
+        with gr.Row():
+            use_swanlab = gr.Checkbox()
+            swanlab_project = gr.Textbox(value="llamafactory")
+            swanlab_run_name = gr.Textbox()
+            swanlab_workspace = gr.Textbox()
+            swanlab_api_key = gr.Textbox()
+            swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
+    input_elems.update(
+        {use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
+    )
+    elem_dict.update(
+        dict(
+            swanlab_tab=swanlab_tab,
+            use_swanlab=use_swanlab,
+            swanlab_project=swanlab_project,
+            swanlab_run_name=swanlab_run_name,
+            swanlab_workspace=swanlab_workspace,
+            swanlab_api_key=swanlab_api_key,
+            swanlab_mode=swanlab_mode,
+        )
+    )
    with gr.Row():
        cmd_preview_btn = gr.Button()
        arg_save_btn = gr.Button()
...
@@ -13,6 +13,7 @@
# limitations under the License.
import os
+import platform
from ..extras.packages import is_gradio_available
from .common import save_config
@@ -34,8 +35,9 @@ if is_gradio_available():
def create_ui(demo_mode: bool = False) -> "gr.Blocks":
    engine = Engine(demo_mode=demo_mode, pure_chat=False)
+    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]
-    with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
        if demo_mode:
            gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
            gr.HTML(
...
...@@ -30,15 +30,19 @@ LOCALES = { ...@@ -30,15 +30,19 @@ LOCALES = {
"model_name": { "model_name": {
"en": { "en": {
"label": "Model name", "label": "Model name",
"info": "Input the name prefix to search for the model.",
}, },
"ru": { "ru": {
"label": "Название модели", "label": "Название модели",
"info": "Введите префикс имени для поиска модели.",
}, },
"zh": { "zh": {
"label": "模型名称", "label": "模型名称",
"info": "输入首单词以检索模型。",
}, },
"ko": { "ko": {
"label": "모델 이름", "label": "모델 이름",
"info": "모델을 검색하기 위해 이름 접두어를 입력하세요.",
}, },
}, },
"model_path": { "model_path": {
...@@ -464,7 +468,7 @@ LOCALES = { ...@@ -464,7 +468,7 @@ LOCALES = {
"val_size": { "val_size": {
"en": { "en": {
"label": "Val size", "label": "Val size",
"info": "Proportion of data in the dev set.", "info": "Percentage of validation set from the entire dataset.",
}, },
"ru": { "ru": {
"label": "Размер валидации", "label": "Размер валидации",
...@@ -1115,7 +1119,7 @@ LOCALES = { ...@@ -1115,7 +1119,7 @@ LOCALES = {
"info": "Нормализация оценок в тренировке PPO.", "info": "Нормализация оценок в тренировке PPO.",
}, },
"zh": { "zh": {
"label": "奖励模型", "label": "归一化分数",
"info": "PPO 训练中归一化奖励分数。", "info": "PPO 训练中归一化奖励分数。",
}, },
"ko": { "ko": {
...@@ -1158,19 +1162,19 @@ LOCALES = { ...@@ -1158,19 +1162,19 @@ LOCALES = {
"use_galore": { "use_galore": {
"en": { "en": {
"label": "Use GaLore", "label": "Use GaLore",
"info": "Enable gradient low-Rank projection.", "info": "Use GaLore optimizer.",
}, },
"ru": { "ru": {
"label": "Использовать GaLore", "label": "Использовать GaLore",
"info": "Включить проекцию градиента на низкоранговое пространство.", "info": "Используйте оптимизатор GaLore.",
}, },
"zh": { "zh": {
"label": "使用 GaLore", "label": "使用 GaLore",
"info": "使用梯度低秩投影。", "info": "使用 GaLore 优化器。",
}, },
"ko": { "ko": {
"label": "GaLore 사용", "label": "GaLore 사용",
"info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.", "info": "GaLore 최적화를 사용하세요.",
}, },
}, },
"galore_rank": { "galore_rank": {
...@@ -1245,6 +1249,110 @@ LOCALES = { ...@@ -1245,6 +1249,110 @@ LOCALES = {
"info": "GaLore를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.", "info": "GaLore를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
}, },
}, },
"apollo_tab": {
"en": {
"label": "APOLLO configurations",
},
"ru": {
"label": "Конфигурации APOLLO",
},
"zh": {
"label": "APOLLO 参数设置",
},
"ko": {
"label": "APOLLO 구성",
},
},
"use_apollo": {
"en": {
"label": "Use APOLLO",
"info": "Use APOLLO optimizer.",
},
"ru": {
"label": "Использовать APOLLO",
"info": "Используйте оптимизатор APOLLO.",
},
"zh": {
"label": "使用 APOLLO",
"info": "使用 APOLLO 优化器。",
},
"ko": {
"label": "APOLLO 사용",
"info": "APOLLO 최적화를 사용하세요.",
},
},
"apollo_rank": {
"en": {
"label": "APOLLO rank",
"info": "The rank of APOLLO gradients.",
},
"ru": {
"label": "Ранг APOLLO",
"info": "Ранг градиентов APOLLO.",
},
"zh": {
"label": "APOLLO 秩",
"info": "APOLLO 梯度的秩大小。",
},
"ko": {
"label": "APOLLO 랭크",
"info": "APOLLO 그레디언트의 랭크.",
},
},
"apollo_update_interval": {
"en": {
"label": "Update interval",
"info": "Number of steps to update the APOLLO projection.",
},
"ru": {
"label": "Интервал обновления",
"info": "Количество шагов для обновления проекции APOLLO.",
},
"zh": {
"label": "更新间隔",
"info": "相邻两次投影更新的步数。",
},
"ko": {
"label": "업데이트 간격",
"info": "APOLLO 프로젝션을 업데이트할 간격의 스텝 수.",
},
},
"apollo_scale": {
"en": {
"label": "APOLLO scale",
"info": "APOLLO scaling coefficient.",
},
"ru": {
"label": "LoRA Alpha",
"info": "Коэффициент масштабирования APOLLO.",
},
"zh": {
"label": "APOLLO 缩放系数",
"info": "APOLLO 缩放系数大小。",
},
"ko": {
"label": "APOLLO 스케일",
"info": "APOLLO 스케일링 계수.",
},
},
"apollo_target": {
"en": {
"label": "APOLLO modules",
"info": "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules.",
},
"ru": {
"label": "Модули APOLLO",
"info": "Имена модулей для применения APOLLO. Используйте запятые для разделения нескольких модулей.",
},
"zh": {
"label": "APOLLO 作用模块",
"info": "应用 APOLLO 的模块名称。使用英文逗号分隔多个名称。",
},
"ko": {
"label": "APOLLO 모듈",
"info": "APOLLO를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
},
},
"badam_tab": { "badam_tab": {
"en": { "en": {
"label": "BAdam configurations", "label": "BAdam configurations",
...@@ -1349,6 +1457,120 @@ LOCALES = { ...@@ -1349,6 +1457,120 @@ LOCALES = {
"info": "비율-BAdam의 업데이트 비율.", "info": "비율-BAdam의 업데이트 비율.",
}, },
}, },
"swanlab_tab": {
"en": {
"label": "SwanLab configurations",
},
"ru": {
"label": "Конфигурации SwanLab",
},
"zh": {
"label": "SwanLab 参数设置",
},
"ko": {
"label": "SwanLab 설정",
},
},
"use_swanlab": {
"en": {
"label": "Use SwanLab",
"info": "Enable SwanLab for experiment tracking and visualization.",
},
"ru": {
"label": "Использовать SwanLab",
"info": "Включить SwanLab для отслеживания и визуализации экспериментов.",
},
"zh": {
"label": "使用 SwanLab",
"info": "启用 SwanLab 进行实验跟踪和可视化。",
},
"ko": {
"label": "SwanLab 사용",
"info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.",
},
},
"swanlab_project": {
"en": {
"label": "SwanLab project",
},
"ru": {
"label": "SwanLab Проект",
},
"zh": {
"label": "SwanLab 项目名",
},
"ko": {
"label": "SwanLab 프로젝트",
},
},
"swanlab_run_name": {
"en": {
"label": "SwanLab experiment name (optional)",
},
"ru": {
"label": "SwanLab Имя эксперимента (опционально)",
},
"zh": {
"label": "SwanLab 实验名(非必填)",
},
"ko": {
"label": "SwanLab 실험 이름 (선택 사항)",
},
},
"swanlab_workspace": {
"en": {
"label": "SwanLab workspace (optional)",
"info": "Workspace for SwanLab. Defaults to the personal workspace.",
},
"ru": {
"label": "SwanLab Рабочая область (опционально)",
"info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.",
},
"zh": {
"label": "SwanLab 工作区(非必填)",
"info": "SwanLab 的工作区,默认在个人工作区下。",
},
"ko": {
"label": "SwanLab 작업 영역 (선택 사항)",
"info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.",
},
},
"swanlab_api_key": {
"en": {
"label": "SwanLab API key (optional)",
"info": "API key for SwanLab.",
},
"ru": {
"label": "SwanLab API ключ (опционально)",
"info": "API ключ для SwanLab.",
},
"zh": {
"label": "SwanLab API密钥(非必填)",
"info": "用于在编程环境登录 SwanLab,已登录则无需填写。",
},
"ko": {
"label": "SwanLab API 키 (선택 사항)",
"info": "SwanLab의 API 키.",
},
},
"swanlab_mode": {
"en": {
"label": "SwanLab mode",
"info": "Cloud or offline version.",
},
"ru": {
"label": "SwanLab Режим",
"info": "Версия в облаке или локальная версия.",
},
"zh": {
"label": "SwanLab 模式",
"info": "使用云端版或离线版 SwanLab。",
},
"ko": {
"label": "SwanLab 모드",
"info": "클라우드 버전 또는 오프라인 버전.",
},
},
"cmd_preview_btn": { "cmd_preview_btn": {
"en": { "en": {
"value": "Preview command", "value": "Preview command",
......
...@@ -19,9 +19,10 @@ from subprocess import Popen, TimeoutExpired ...@@ -19,9 +19,10 @@ from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
from transformers.trainer import TRAINING_ARGS_NAME from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_npu_available
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .locales import ALERTS, LOCALES from .locales import ALERTS, LOCALES
...@@ -146,12 +147,15 @@ class Runner: ...@@ -146,12 +147,15 @@ class Runner:
shift_attn=get("train.shift_attn"), shift_attn=get("train.shift_attn"),
report_to="all" if get("train.report_to") else "none", report_to="all" if get("train.report_to") else "none",
use_galore=get("train.use_galore"), use_galore=get("train.use_galore"),
use_apollo=get("train.use_apollo"),
use_badam=get("train.use_badam"), use_badam=get("train.use_badam"),
use_swanlab=get("train.use_swanlab"),
output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")), output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"), fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"), bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"), pure_bf16=(get("train.compute_type") == "pure_bf16"),
plot_loss=True, plot_loss=True,
trust_remote_code=True,
ddp_timeout=180000000, ddp_timeout=180000000,
include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True, # FIXME include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True, # FIXME
) )
...@@ -170,6 +174,7 @@ class Runner: ...@@ -170,6 +174,7 @@ class Runner:
if get("top.quantization_bit") in QUANTIZATION_BITS: if get("top.quantization_bit") in QUANTIZATION_BITS:
args["quantization_bit"] = int(get("top.quantization_bit")) args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method") args["quantization_method"] = get("top.quantization_method")
args["double_quantization"] = not is_torch_npu_available()
# freeze config # freeze config
if args["finetuning_type"] == "freeze": if args["finetuning_type"] == "freeze":
...@@ -220,6 +225,13 @@ class Runner: ...@@ -220,6 +225,13 @@ class Runner:
args["galore_scale"] = get("train.galore_scale") args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target") args["galore_target"] = get("train.galore_target")
# apollo config
if args["use_apollo"]:
args["apollo_rank"] = get("train.apollo_rank")
args["apollo_update_interval"] = get("train.apollo_update_interval")
args["apollo_scale"] = get("train.apollo_scale")
args["apollo_target"] = get("train.apollo_target")
# badam config # badam config
if args["use_badam"]: if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode") args["badam_mode"] = get("train.badam_mode")
...@@ -227,6 +239,14 @@ class Runner: ...@@ -227,6 +239,14 @@ class Runner:
args["badam_switch_interval"] = get("train.badam_switch_interval") args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio") args["badam_update_ratio"] = get("train.badam_update_ratio")
# swanlab config
if get("train.use_swanlab"):
args["swanlab_project"] = get("train.swanlab_project")
args["swanlab_run_name"] = get("train.swanlab_run_name")
args["swanlab_workspace"] = get("train.swanlab_workspace")
args["swanlab_api_key"] = get("train.swanlab_api_key")
args["swanlab_mode"] = get("train.swanlab_mode")
# eval config # eval config
if get("train.val_size") > 1e-6 and args["stage"] != "ppo": if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size") args["val_size"] = get("train.val_size")
...@@ -268,6 +288,7 @@ class Runner: ...@@ -268,6 +288,7 @@ class Runner:
top_p=get("eval.top_p"), top_p=get("eval.top_p"),
temperature=get("eval.temperature"), temperature=get("eval.temperature"),
output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")), output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
trust_remote_code=True,
) )
if get("eval.predict"): if get("eval.predict"):
...@@ -383,12 +404,12 @@ class Runner: ...@@ -383,12 +404,12 @@ class Runner:
continue continue
if self.do_train: if self.do_train:
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)): if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
finish_info = ALERTS["info_finished"][lang] finish_info = ALERTS["info_finished"][lang]
else: else:
finish_info = ALERTS["err_failed"][lang] finish_info = ALERTS["err_failed"][lang]
else: else:
if os.path.exists(os.path.join(output_path, "all_results.json")): if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
finish_info = get_eval_results(os.path.join(output_path, "all_results.json")) finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
else: else:
finish_info = ALERTS["err_failed"][lang] finish_info = ALERTS["err_failed"][lang]
......
...@@ -12,9 +12,105 @@ ...@@ -12,9 +12,105 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import torch import torch
from PIL import Image
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
from llamafactory.data.collator import prepare_4d_attention_mask def test_base_collator():
model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
p = tokenizer_module["tokenizer"].pad_token_id
q = IGNORE_INDEX
features = [
{
"input_ids": [0, 1, 2, 3, 4, 5],
"attention_mask": [1, 1, 1, 1, 1, 1],
"labels": [q, q, 2, 3, 4, 5],
},
{
"input_ids": [6, 7],
"attention_mask": [1, 1],
"labels": [q, 7],
},
]
batch_input = data_collator(features)
expected_input = {
"input_ids": [
[0, 1, 2, 3, 4, 5, p, p],
[6, 7, p, p, p, p, p, p],
],
"attention_mask": [
[1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0],
],
"labels": [
[q, q, 2, 3, 4, 5, q, q],
[q, 7, q, q, q, q, q, q],
],
}
for k in batch_input.keys():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
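# Worked note (illustration, not part of this diff): with pad_to_multiple_of=8 the longest
# sequence in the batch above (6 tokens) is padded up to the next multiple of 8, which is
# why every row in expected_input has length 8.
import math
assert math.ceil(6 / 8) * 8 == 8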
def test_multimodal_collator():
model_args, data_args, *_ = get_infer_args(
{"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
)
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template,
pad_to_multiple_of=4,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
p = tokenizer_module["tokenizer"].pad_token_id
q = IGNORE_INDEX
s = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_start|>")
e = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_end|>")
m = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|image_pad|>")
fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
features = [
{
"input_ids": [0, 1, 2, 3],
"attention_mask": [1, 1, 1, 1],
"labels": [0, 1, 2, 3],
},
]
batch_input = data_collator(features)
expected_input = {
"input_ids": [
[0, 1, 2, 3, s, m, m, m, m, e, p, p],
],
"attention_mask": [
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
],
"labels": [
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
],
**tokenizer_module["processor"].image_processor(fake_image),
}
for k in batch_input.keys():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
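# Note (illustration, not part of this diff): the input features above contain no images,
# yet the expected batch includes <|vision_start|>/<|image_pad|>/<|vision_end|> tokens plus
# the pixel values of a 64x64 white image. The test thus relies on the multimodal collator
# inserting a placeholder image when a batch has none; the placeholder positions get
# attention_mask 0 and IGNORE_INDEX labels, so they do not contribute to the loss.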
def test_4d_attention_mask(): def test_4d_attention_mask():
......
...@@ -13,10 +13,29 @@ ...@@ -13,10 +13,29 @@
# limitations under the License. # limitations under the License.
import json import json
from datetime import datetime
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
TOOLS = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
def test_empty_formatter(): def test_empty_formatter():
formatter = EmptyFormatter(slots=["\n"]) formatter = EmptyFormatter(slots=["\n"])
assert formatter.apply() == ["\n"] assert formatter.apply() == ["\n"]
...@@ -28,39 +47,27 @@ def test_string_formatter(): ...@@ -28,39 +47,27 @@ def test_string_formatter():
def test_function_formatter(): def test_function_formatter():
formatter = FunctionFormatter(slots=[], tool_format="default") formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""" """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"</s>",
] ]
def test_multi_function_formatter(): def test_multi_function_formatter():
formatter = FunctionFormatter(slots=[], tool_format="default") formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2) tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""", """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""", """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"</s>",
] ]
def test_default_tool_formatter(): def test_default_tool_formatter():
formatter = ToolFormatter(tool_format="default") formatter = ToolFormatter(tool_format="default")
tools = [ assert formatter.apply(content=json.dumps(TOOLS)) == [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
assert formatter.apply(content=json.dumps(tools)) == [
"You have access to the following tools:\n" "You have access to the following tools:\n"
"> Tool Name: test_tool\n" "> Tool Name: test_tool\n"
"Tool Description: tool_desc\n" "Tool Description: tool_desc\n"
...@@ -94,26 +101,18 @@ def test_default_multi_tool_extractor(): ...@@ -94,26 +101,18 @@ def test_default_multi_tool_extractor():
] ]
def test_glm4_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
def test_glm4_tool_formatter(): def test_glm4_tool_formatter():
formatter = ToolFormatter(tool_format="glm4") formatter = ToolFormatter(tool_format="glm4")
tools = [ assert formatter.apply(content=json.dumps(TOOLS)) == [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
assert formatter.apply(content=json.dumps(tools)) == [
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n" "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
"## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4)) f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
] ]
...@@ -121,3 +120,127 @@ def test_glm4_tool_extractor(): ...@@ -121,3 +120,127 @@ def test_glm4_tool_extractor():
formatter = ToolFormatter(tool_format="glm4") formatter = ToolFormatter(tool_format="glm4")
result = """test_tool\n{"foo": "bar", "size": 10}\n""" result = """test_tool\n{"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
"<|eot_id|>",
]
def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3")
date = datetime.now().strftime("%d %b %Y")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
]
def test_llama3_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_tool_formatter():
formatter = ToolFormatter(tool_format="mistral")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
"[AVAILABLE_TOOLS] " + json.dumps([wrapped_tool], ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
]
def test_mistral_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_multi_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = (
"""[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
]
def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
]
def test_qwen_tool_formatter():
formatter = ToolFormatter(tool_format="qwen")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
""""arguments": <args-json-object>}\n</tool_call><|im_end|>\n"""
]
def test_qwen_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_qwen_multi_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = (
"""<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Sequence
import pytest import pytest
import torch import torch
from PIL import Image from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.hparams import ModelArguments from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer from llamafactory.model import load_tokenizer
...@@ -29,6 +29,7 @@ if TYPE_CHECKING: ...@@ -29,6 +29,7 @@ if TYPE_CHECKING:
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
from llamafactory.data.mm_plugin import BasePlugin from llamafactory.data.mm_plugin import BasePlugin
from llamafactory.model.loader import TokenizerModule
HF_TOKEN = os.getenv("HF_TOKEN") HF_TOKEN = os.getenv("HF_TOKEN")
...@@ -82,10 +83,9 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None: ...@@ -82,10 +83,9 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
assert batch_a[key] == batch_b[key] assert batch_a[key] == batch_b[key]
def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]: def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
model_args = ModelArguments(model_name_or_path=model_name_or_path) model_args, *_ = get_infer_args({"model_name_or_path": model_name_or_path, "template": "default"})
tokenizer_module = load_tokenizer(model_args) return load_tokenizer(model_args)
return tokenizer_module["tokenizer"], tokenizer_module["processor"]
def _check_plugin( def _check_plugin(
...@@ -121,73 +121,75 @@ def _check_plugin( ...@@ -121,73 +121,75 @@ def _check_plugin(
def test_base_plugin(): def test_base_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA) tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
base_plugin = get_mm_plugin(name="base", image_token="<image>") base_plugin = get_mm_plugin(name="base", image_token="<image>")
check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor} check_inputs = {"plugin": base_plugin, **tokenizer_module}
_check_plugin(**check_inputs) _check_plugin(**check_inputs)
def test_llava_plugin(): def test_llava_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
image_seqlen = 576 image_seqlen = 576
check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor} tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
check_inputs = {"plugin": llava_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [ check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()} {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs) _check_plugin(**check_inputs)
def test_llava_next_plugin(): def test_llava_next_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
image_seqlen = 1176 image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
check_inputs = {"plugin": llava_next_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [ check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()} {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs) _check_plugin(**check_inputs)
def test_llava_next_video_plugin(): def test_llava_next_video_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
image_seqlen = 1176 image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": llava_next_video_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [ check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()} {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs) _check_plugin(**check_inputs)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin(): def test_paligemma_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
image_seqlen = 256 image_seqlen = 256
check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor} tokenizer_module = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
check_inputs = {"plugin": paligemma_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [ check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
] ]
check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS check_inputs["expected_input_ids"] = [
tokenizer_module["tokenizer"].convert_tokens_to_ids(paligemma_plugin.image_token)
] * image_seqlen + INPUT_IDS
check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)] check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]} check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
_check_plugin(**check_inputs) _check_plugin(**check_inputs)
def test_pixtral_plugin(): def test_pixtral_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
image_slice_height, image_slice_width = 2, 2 image_slice_height, image_slice_width = 2, 2
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor} tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
check_inputs = {"plugin": pixtral_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [ check_inputs["expected_mm_messages"] = [
{ {
key: value.replace( key: value.replace(
...@@ -199,17 +201,17 @@ def test_pixtral_plugin(): ...@@ -199,17 +201,17 @@ def test_pixtral_plugin():
} }
for message in MM_MESSAGES for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
check_inputs["expected_mm_inputs"].pop("image_sizes") check_inputs["expected_mm_inputs"].pop("image_sizes")
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0] check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
_check_plugin(**check_inputs) _check_plugin(**check_inputs)
def test_qwen2_vl_plugin(): def test_qwen2_vl_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
image_seqlen = 4 image_seqlen = 4
check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor} tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
check_inputs = {"plugin": qwen2_vl_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [ check_inputs["expected_mm_messages"] = [
{ {
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen)) key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
...@@ -217,18 +219,18 @@ def test_qwen2_vl_plugin(): ...@@ -217,18 +219,18 @@ def test_qwen2_vl_plugin():
} }
for message in MM_MESSAGES for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs) _check_plugin(**check_inputs)
def test_video_llava_plugin(): def test_video_llava_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
image_seqlen = 256 image_seqlen = 256
tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": video_llava_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [ check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()} {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs) _check_plugin(**check_inputs)