Commit 8293100a authored by luopl

update to 0.9.2.dev0

parent 2778a3d0
......@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens
from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
......@@ -48,6 +48,7 @@ def run_dpo(
data_collator = PairwiseDataCollatorWithPadding(
template=template,
model=model,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
......@@ -65,12 +66,6 @@ def run_dpo(
# Update arguments
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
effective_token_num = 0.0
if finetuning_args.include_effective_tokens_per_second:
for data in dataset_module["train_dataset"]:
effective_token_num += len(data["chosen_input_ids"])
effective_token_num += len(data["rejected_input_ids"])
# Initialize our Trainer
trainer = CustomDPOTrainer(
model=model,
......@@ -86,13 +81,12 @@ def run_dpo(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="rm"
)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
......
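The removed per-example counting loop is now folded into the `calculate_tps` helper imported from `...extras.misc`, whose implementation is not part of this diff. A minimal sketch of the assumed behavior for the pairwise ("rm") case used here, where both the chosen and rejected sequences count as effective tokens:

# Sketch only: assumed behavior of calculate_tps; the real helper lives in
# src/llamafactory/extras/misc.py and may differ in details (e.g. distributed scaling).
def calculate_tps_sketch(dataset, metrics, stage="rm"):
    effective_token_num = 0
    for data in dataset:
        if stage == "rm":  # pairwise data: count chosen and rejected tokens
            effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
        else:  # e.g. stage == "sft": count prompt + response tokens
            effective_token_num += len(data["input_ids"])

    return effective_token_num * metrics["epoch"] / metrics["train_runtime"]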
......@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
from transformers import Trainer
......@@ -28,9 +28,9 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
if TYPE_CHECKING:
......@@ -50,6 +50,9 @@ class CustomKTOTrainer(KTOTrainer):
disable_dropout: bool = True,
**kwargs,
):
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None:
......@@ -77,6 +80,7 @@ class CustomKTOTrainer(KTOTrainer):
self.ftx_gamma = finetuning_args.pref_ftx
Trainer.__init__(self, model=model, **kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
......@@ -119,6 +123,9 @@ class CustomKTOTrainer(KTOTrainer):
r"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return Trainer._get_train_sampler(self)
@override
......@@ -135,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
r"""
Runs forward pass and computes the log probabilities.
"""
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
batch = nested_detach(batch, clone=True) # avoid error
model_inputs = {
"input_ids": batch[f"{prefix}input_ids"],
"attention_mask": batch[f"{prefix}attention_mask"],
......@@ -245,17 +252,18 @@ class CustomKTOTrainer(KTOTrainer):
return losses, metrics
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
loss = loss / self.args.gradient_accumulation_steps
return loss
......
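For context on the `compute_loss` changes above: on transformers 4.46.0-4.46.1 the Trainer passes `num_items_in_batch` but, for models that do not accept loss kwargs, no longer divides the returned loss by the gradient accumulation steps, so the logged value is inflated. A tiny illustration of the correction with made-up numbers:

# Illustration only; the numbers are made up.
gradient_accumulation_steps = 4
reported_loss = 2.0  # what super().compute_loss() returns on transformers 4.46.0
corrected_loss = reported_loss / gradient_accumulation_steps
assert corrected_loss == 0.5  # the per-step mean that other transformers versions report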
......@@ -47,6 +47,7 @@ def run_kto(
data_collator = KTODataCollatorWithPadding(
template=template,
model=model,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
......
......@@ -46,7 +46,7 @@ def run_ppo(
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, model=model, **tokenizer_module)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
......
......@@ -13,19 +13,19 @@
# limitations under the License.
from types import MethodType
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
from typing_extensions import override
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
import torch
from transformers import ProcessorMixin
from transformers import PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments
......@@ -38,15 +38,15 @@ class CustomTrainer(Trainer):
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
......@@ -67,17 +67,26 @@ class CustomTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
# other models should not scale the loss
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
loss = loss / self.args.gradient_accumulation_steps
return loss
......@@ -25,8 +25,8 @@ from transformers import Trainer
from typing_extensions import override
from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
......@@ -48,7 +48,11 @@ class PairwiseTrainer(Trainer):
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
super().__init__(**kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
self.finetuning_args = finetuning_args
self.can_return_loss = True # override property to return eval_loss
self.add_callback(FixValueHeadModelCallback)
......@@ -56,9 +60,6 @@ class PairwiseTrainer(Trainer):
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
......@@ -78,6 +79,13 @@ class PairwiseTrainer(Trainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
......@@ -100,8 +108,8 @@ class PairwiseTrainer(Trainer):
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0
if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0-4.46.1
if return_outputs:
return loss, (loss, chosen_scores, rejected_scores)
......
......@@ -44,7 +44,9 @@ def run_rm(
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)
data_collator = PairwiseDataCollatorWithPadding(
template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
)
# Update arguments
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
......
......@@ -27,14 +27,14 @@ from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import ProcessorMixin
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments
......@@ -51,15 +51,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
......@@ -80,18 +82,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
# other models should not scale the loss
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
loss = loss / self.args.gradient_accumulation_steps
return loss
......@@ -102,41 +113,30 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
inputs: Dict[str, Union["torch.Tensor", Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
**gen_kwargs,
) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Removes the prompt part in the generated tokens.
Subclass and override to inject custom behavior.
"""
labels = inputs["labels"] if "labels" in inputs else None
if self.args.predict_with_generate:
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
labels = labels.detach().clone() if labels is not None else None # backup labels
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
if prompt_len > label_len:
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
inputs["labels"] = inputs["labels"][:, :prompt_len]
loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
if self.args.predict_with_generate: # do not pass labels to model when generate
labels = inputs.pop("labels", None)
else:
labels = inputs.get("labels")
loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
)
if generated_tokens is not None and self.args.predict_with_generate:
generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
generated_tokens = generated_tokens.contiguous()
return loss, generated_tokens, labels
def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
r"""
Pads the tensor to the same length as the target tensor.
"""
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding
return padded_tensor.contiguous() # in contiguous memory
def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
def save_predictions(
self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
) -> None:
r"""
Saves model predictions to `output_dir`.
......@@ -149,24 +149,23 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
)
preds = np.where(
predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id
predict_results.predictions != IGNORE_INDEX,
predict_results.predictions,
self.processing_class.pad_token_id,
)
for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]
if len(pad_len): # move pad token to last
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
writer.write("\n".join(res))
with open(output_prediction_file, "w", encoding="utf-8") as f:
for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
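With this change `save_predictions` emits JSON Lines: one `{"prompt", "predict", "label"}` object per example instead of a single joined string. A hedged sketch for consuming the file; the name `generated_predictions.jsonl` is an assumption (the actual name comes from `output_prediction_file` above) and the output directory is a placeholder:

# Hedged sketch: parse the JSON Lines prediction file written by save_predictions.
import json
import os

output_dir = "saves/my-run"  # placeholder; use the trainer's output_dir
with open(os.path.join(output_dir, "generated_predictions.jsonl"), encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

print(records[0]["prompt"], records[0]["predict"], records[0]["label"])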
......@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens, get_logits_processor
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps, get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
......@@ -33,6 +34,9 @@ if TYPE_CHECKING:
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
......@@ -52,6 +56,7 @@ def run_sft(
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
model=model if not training_args.predict_with_generate else None,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
......@@ -65,11 +70,6 @@ def run_sft(
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False # important for multimodal dataset
effective_token_num = 0.0
if finetuning_args.include_effective_tokens_per_second:
for data in dataset_module["train_dataset"]:
effective_token_num += len(data["input_ids"])
# Metric utils
metric_module = {}
if training_args.predict_with_generate:
......@@ -91,7 +91,7 @@ def run_sft(
)
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict()
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()
......@@ -99,12 +99,12 @@ def run_sft(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="sft"
)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
......@@ -117,19 +117,16 @@ def run_sft(
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
......@@ -17,7 +17,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
......@@ -30,20 +32,29 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import IGNORE_INDEX
from ..extras.packages import is_galore_available
from ..extras.packages import is_apollo_available, is_galore_available, is_ray_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
if is_galore_available():
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore
if is_apollo_available():
from apollo_torch import APOLLOAdamW # type: ignore
if is_ray_available():
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
if TYPE_CHECKING:
from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from transformers import PreTrainedModel, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments
from ..hparams import DataArguments, RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
......@@ -51,7 +62,7 @@ logger = logging.get_logger(__name__)
class DummyOptimizer(torch.optim.Optimizer):
r"""
A dummy optimizer used for the GaLore algorithm.
A dummy optimizer used for the GaLore or APOLLO algorithm.
"""
def __init__(
......@@ -74,7 +85,7 @@ def create_modelcard_and_push(
trainer: "Trainer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> None:
kwargs = {
......@@ -187,7 +198,7 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
def _create_galore_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
......@@ -231,9 +242,10 @@ def _create_galore_optimizer(
elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor
else:
raise NotImplementedError(f"Unknow optim: {training_args.optim}")
raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
if finetuning_args.galore_layerwise:
logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise GaLore.")
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
......@@ -265,13 +277,100 @@ def _create_galore_optimizer(
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
logger.info_rank0(
f"Using GaLore optimizer with args: {galore_kwargs}. "
"It may cause hanging at the start of training, wait patiently."
)
return optimizer
def _create_apollo_optimizer(
model: "PreTrainedModel",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
apollo_targets = finetuning_args.apollo_target
apollo_params: List["torch.nn.Parameter"] = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
for param in module.parameters():
if param.requires_grad and len(param.shape) > 1:
apollo_params.append(param)
apollo_kwargs = {
"rank": finetuning_args.apollo_rank,
"proj": finetuning_args.apollo_proj,
"proj_type": finetuning_args.apollo_proj_type,
"update_proj_gap": finetuning_args.apollo_update_interval,
"scale": finetuning_args.apollo_scale,
"scale_type": finetuning_args.apollo_scale_type,
"scale_front": finetuning_args.apollo_scale_front,
}
id_apollo_params = {id(param) for param in apollo_params}
decay_params, nodecay_params = [], [] # they are non-apollo parameters
trainable_params: List["torch.nn.Parameter"] = [] # apollo_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
trainable_params.append(param)
if id(param) not in id_apollo_params:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
_, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
if training_args.optim == "adamw_torch":
optim_class = APOLLOAdamW
else:
raise NotImplementedError(f"Unknown optim: {training_args.optim}.")
if finetuning_args.apollo_layerwise:
logger.warning_rank0("The displayed gradient norm will be all zeros in layerwise APOLLO.")
if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer APOLLO does not support gradient accumulation.")
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
for param in nodecay_params:
param_groups = [dict(params=[param], weight_decay=0.0)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in decay_params:
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
for param in apollo_params: # apollo params have weight decay
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
def optimizer_hook(param: "torch.nn.Parameter"):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
for param in trainable_params:
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
else:
param_groups = [
dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
]
optimizer = optim_class(param_groups, **optim_kwargs)
logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
return optimizer
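A hedged example of the finetuning arguments that route into `_create_apollo_optimizer`; the key names mirror the `finetuning_args` fields read above and exposed in the web UI, and the values are purely illustrative:

# Illustrative values only; field names are taken from the FinetuningArguments usage above.
apollo_example_args = {
    "use_apollo": True,
    "apollo_target": ["all"],       # or explicit module names such as ["q_proj", "v_proj"]
    "apollo_rank": 16,
    "apollo_update_interval": 200,
    "apollo_scale": 32.0,
    "optim": "adamw_torch",         # the only optimizer class wired up for APOLLO here
}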
def _create_loraplus_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
default_lr = training_args.learning_rate
......@@ -311,7 +410,7 @@ def _create_loraplus_optimizer(
def _create_badam_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
decay_params, nodecay_params = [], []
......@@ -330,7 +429,7 @@ def _create_badam_optimizer(
]
if finetuning_args.badam_mode == "layer":
from badam import BlockOptimizer
from badam import BlockOptimizer # type: ignore
base_optimizer = optim_class(param_groups, **optim_kwargs)
optimizer = BlockOptimizer(
......@@ -350,7 +449,7 @@ def _create_badam_optimizer(
)
elif finetuning_args.badam_mode == "ratio":
from badam import BlockOptimizerRatio
from badam import BlockOptimizerRatio # type: ignore
assert finetuning_args.badam_update_ratio > 1e-6
optimizer = BlockOptimizerRatio(
......@@ -372,9 +471,9 @@ def _create_badam_optimizer(
def _create_adam_mini_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
) -> "torch.optim.Optimizer":
from adam_mini import Adam_mini
from adam_mini import Adam_mini # type: ignore
hidden_size = getattr(model.config, "hidden_size", None)
num_q_head = getattr(model.config, "num_attention_heads", None)
......@@ -397,12 +496,15 @@ def _create_adam_mini_optimizer(
def create_custom_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_galore:
return _create_galore_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_apollo:
return _create_apollo_optimizer(model, training_args, finetuning_args)
if finetuning_args.loraplus_lr_ratio is not None:
return _create_loraplus_optimizer(model, training_args, finetuning_args)
......@@ -414,7 +516,7 @@ def create_custom_optimizer(
def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments",
training_args: "TrainingArguments",
num_training_steps: int,
optimizer: Optional["torch.optim.Optimizer"] = None,
) -> None:
......@@ -457,3 +559,69 @@ def get_batch_logps(
labels[labels == label_pad_token_id] = 0 # dummy token
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
def nested_detach(
tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
clone: bool = False,
):
r"""
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
"""
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
elif isinstance(tensors, Mapping):
return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
if isinstance(tensors, torch.Tensor):
if clone:
return tensors.detach().clone()
else:
return tensors.detach()
else:
return tensors
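A quick usage sketch of `nested_detach`, mirroring how the KTO trainer now uses it instead of the per-key dict comprehension:

# Usage sketch: detach (and clone) every tensor inside a nested batch structure.
import torch

batch = {
    "logps": torch.randn(2, requires_grad=True),
    "masks": [torch.ones(2), torch.zeros(2)],
}
detached = nested_detach(batch, clone=True)
assert detached["logps"].requires_grad is False
assert isinstance(detached["masks"], list)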
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
r"""
Gets the callback for logging to SwanLab.
"""
import swanlab # type: ignore
from swanlab.integration.transformers import SwanLabCallback # type: ignore
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)
swanlab_callback = SwanLabCallback(
project=finetuning_args.swanlab_project,
workspace=finetuning_args.swanlab_workspace,
experiment_name=finetuning_args.swanlab_run_name,
mode=finetuning_args.swanlab_mode,
config={"Framework": "🦙LlamaFactory"},
)
return swanlab_callback
def get_ray_trainer(
training_function: Callable,
train_loop_config: Dict[str, Any],
ray_args: "RayArguments",
) -> "TorchTrainer":
if not ray_args.use_ray:
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
trainer = TorchTrainer(
training_function,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=ray_args.ray_num_workers,
resources_per_worker=ray_args.resources_per_worker,
placement_strategy=ray_args.placement_strategy,
use_gpu=True,
),
run_config=RunConfig(
name=ray_args.ray_run_name,
storage_path=Path("./saves").absolute().as_posix(),
),
)
return trainer
......@@ -22,15 +22,21 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..hparams import get_infer_args, get_train_args
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_ray_trainer, get_swanlab_callback
if is_ray_available():
from ray.train.huggingface.transformers import RayTrainReportCallback
if TYPE_CHECKING:
......@@ -40,10 +46,20 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
callbacks.append(LogCallback())
def _training_function(config: Dict[str, Any]) -> None:
args = config.get("args")
callbacks: List[Any] = config.get("callbacks")
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
callbacks.append(LogCallback())
if finetuning_args.pissa_convert:
callbacks.append(PissaConvertCallback())
if finetuning_args.use_swanlab:
callbacks.append(get_swanlab_callback(finetuning_args))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
......@@ -60,6 +76,22 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
raise ValueError(f"Unknown task: {finetuning_args.stage}.")
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
args = read_args(args)
ray_args = get_ray_args(args)
callbacks = callbacks or []
if ray_args.use_ray:
callbacks.append(RayTrainReportCallback())
trainer = get_ray_trainer(
training_function=_training_function,
train_loop_config={"args": args, "callbacks": callbacks},
ray_args=ray_args,
)
trainer.fit()
else:
_training_function(config={"args": args, "callbacks": callbacks})
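A hedged usage sketch of the new Ray path: with `USE_RAY=1` set (as the error message in `get_ray_trainer` suggests), `run_exp` wraps `_training_function` in a `TorchTrainer`; without it, training runs in-process as before. The argument keys below are ordinary LLaMA-Factory training arguments with placeholder values, and the `ray_*` keys are assumed to map onto `RayArguments`:

# Hedged sketch: launching training through run_exp with Ray enabled.
import os

os.environ["USE_RAY"] = "1"  # enables the Ray branch checked via get_ray_args()

from llamafactory.train.tuner import run_exp

run_exp(
    {
        "stage": "sft",
        "do_train": True,
        "model_name_or_path": "llamafactory/tiny-random-Llama-3",  # placeholder model
        "dataset": "identity",                                      # placeholder dataset
        "template": "default",
        "finetuning_type": "lora",
        "output_dir": "saves/ray-demo",
        "ray_run_name": "ray-demo",   # assumption: forwarded to RayArguments
        "ray_num_workers": 2,         # assumption: forwarded to RayArguments
    }
)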
def export_model(args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, _ = get_infer_args(args)
......
......@@ -91,6 +91,7 @@ class WebChatModel(ChatModel):
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),
trust_remote_code=True,
)
if checkpoint_path:
......@@ -157,7 +158,7 @@ class WebChatModel(ChatModel):
result = response
if isinstance(result, list):
tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
bot_text = "```json\n" + tool_calls + "\n```"
......
......@@ -84,6 +84,7 @@ def save_model(
export_quantization_dataset=export_quantization_dataset,
export_device=export_device,
export_legacy_format=export_legacy_format,
trust_remote_code=True,
)
if checkpoint_path:
......
......@@ -234,8 +234,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Row():
use_galore = gr.Checkbox()
galore_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
galore_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
galore_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
galore_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
galore_scale = gr.Slider(minimum=0, maximum=100, value=2.0, step=0.1)
galore_target = gr.Textbox(value="all")
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
......@@ -250,6 +250,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)
with gr.Accordion(open=False) as apollo_tab:
with gr.Row():
use_apollo = gr.Checkbox()
apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
apollo_update_interval = gr.Slider(minimum=1, maximum=2048, value=200, step=1)
apollo_scale = gr.Slider(minimum=0, maximum=100, value=32.0, step=0.1)
apollo_target = gr.Textbox(value="all")
input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
elem_dict.update(
dict(
apollo_tab=apollo_tab,
use_apollo=use_apollo,
apollo_rank=apollo_rank,
apollo_update_interval=apollo_update_interval,
apollo_scale=apollo_scale,
apollo_target=apollo_target,
)
)
with gr.Accordion(open=False) as badam_tab:
with gr.Row():
use_badam = gr.Checkbox()
......@@ -270,6 +290,30 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
)
)
with gr.Accordion(open=False) as swanlab_tab:
with gr.Row():
use_swanlab = gr.Checkbox()
swanlab_project = gr.Textbox(value="llamafactory")
swanlab_run_name = gr.Textbox()
swanlab_workspace = gr.Textbox()
swanlab_api_key = gr.Textbox()
swanlab_mode = gr.Dropdown(choices=["cloud", "local"], value="cloud")
input_elems.update(
{use_swanlab, swanlab_project, swanlab_run_name, swanlab_workspace, swanlab_api_key, swanlab_mode}
)
elem_dict.update(
dict(
swanlab_tab=swanlab_tab,
use_swanlab=use_swanlab,
swanlab_project=swanlab_project,
swanlab_run_name=swanlab_run_name,
swanlab_workspace=swanlab_workspace,
swanlab_api_key=swanlab_api_key,
swanlab_mode=swanlab_mode,
)
)
with gr.Row():
cmd_preview_btn = gr.Button()
arg_save_btn = gr.Button()
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import os
import platform
from ..extras.packages import is_gradio_available
from .common import save_config
......@@ -34,8 +35,9 @@ if is_gradio_available():
def create_ui(demo_mode: bool = False) -> "gr.Blocks":
engine = Engine(demo_mode=demo_mode, pure_chat=False)
hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
if demo_mode:
gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
gr.HTML(
......
......@@ -30,15 +30,19 @@ LOCALES = {
"model_name": {
"en": {
"label": "Model name",
"info": "Input the name prefix to search for the model.",
},
"ru": {
"label": "Название модели",
"info": "Введите префикс имени для поиска модели.",
},
"zh": {
"label": "模型名称",
"info": "输入首单词以检索模型。",
},
"ko": {
"label": "모델 이름",
"info": "모델을 검색하기 위해 이름 접두어를 입력하세요.",
},
},
"model_path": {
......@@ -464,7 +468,7 @@ LOCALES = {
"val_size": {
"en": {
"label": "Val size",
"info": "Proportion of data in the dev set.",
"info": "Percentage of validation set from the entire dataset.",
},
"ru": {
"label": "Размер валидации",
......@@ -1115,7 +1119,7 @@ LOCALES = {
"info": "Нормализация оценок в тренировке PPO.",
},
"zh": {
"label": "奖励模型",
"label": "归一化分数",
"info": "PPO 训练中归一化奖励分数。",
},
"ko": {
......@@ -1158,19 +1162,19 @@ LOCALES = {
"use_galore": {
"en": {
"label": "Use GaLore",
"info": "Enable gradient low-Rank projection.",
"info": "Use GaLore optimizer.",
},
"ru": {
"label": "Использовать GaLore",
"info": "Включить проекцию градиента на низкоранговое пространство.",
"info": "Используйте оптимизатор GaLore.",
},
"zh": {
"label": "使用 GaLore",
"info": "使用梯度低秩投影。",
"info": "使用 GaLore 优化器。",
},
"ko": {
"label": "GaLore 사용",
"info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
"info": "GaLore 최적화를 사용하세요.",
},
},
"galore_rank": {
......@@ -1245,6 +1249,110 @@ LOCALES = {
"info": "GaLore를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
},
},
"apollo_tab": {
"en": {
"label": "APOLLO configurations",
},
"ru": {
"label": "Конфигурации APOLLO",
},
"zh": {
"label": "APOLLO 参数设置",
},
"ko": {
"label": "APOLLO 구성",
},
},
"use_apollo": {
"en": {
"label": "Use APOLLO",
"info": "Use APOLLO optimizer.",
},
"ru": {
"label": "Использовать APOLLO",
"info": "Используйте оптимизатор APOLLO.",
},
"zh": {
"label": "使用 APOLLO",
"info": "使用 APOLLO 优化器。",
},
"ko": {
"label": "APOLLO 사용",
"info": "APOLLO 최적화를 사용하세요.",
},
},
"apollo_rank": {
"en": {
"label": "APOLLO rank",
"info": "The rank of APOLLO gradients.",
},
"ru": {
"label": "Ранг APOLLO",
"info": "Ранг градиентов APOLLO.",
},
"zh": {
"label": "APOLLO 秩",
"info": "APOLLO 梯度的秩大小。",
},
"ko": {
"label": "APOLLO 랭크",
"info": "APOLLO 그레디언트의 랭크.",
},
},
"apollo_update_interval": {
"en": {
"label": "Update interval",
"info": "Number of steps to update the APOLLO projection.",
},
"ru": {
"label": "Интервал обновления",
"info": "Количество шагов для обновления проекции APOLLO.",
},
"zh": {
"label": "更新间隔",
"info": "相邻两次投影更新的步数。",
},
"ko": {
"label": "업데이트 간격",
"info": "APOLLO 프로젝션을 업데이트할 간격의 스텝 수.",
},
},
"apollo_scale": {
"en": {
"label": "APOLLO scale",
"info": "APOLLO scaling coefficient.",
},
"ru": {
"label": "LoRA Alpha",
"info": "Коэффициент масштабирования APOLLO.",
},
"zh": {
"label": "APOLLO 缩放系数",
"info": "APOLLO 缩放系数大小。",
},
"ko": {
"label": "APOLLO 스케일",
"info": "APOLLO 스케일링 계수.",
},
},
"apollo_target": {
"en": {
"label": "APOLLO modules",
"info": "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules.",
},
"ru": {
"label": "Модули APOLLO",
"info": "Имена модулей для применения APOLLO. Используйте запятые для разделения нескольких модулей.",
},
"zh": {
"label": "APOLLO 作用模块",
"info": "应用 APOLLO 的模块名称。使用英文逗号分隔多个名称。",
},
"ko": {
"label": "APOLLO 모듈",
"info": "APOLLO를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
},
},
"badam_tab": {
"en": {
"label": "BAdam configurations",
......@@ -1349,6 +1457,120 @@ LOCALES = {
"info": "비율-BAdam의 업데이트 비율.",
},
},
"swanlab_tab": {
"en": {
"label": "SwanLab configurations",
},
"ru": {
"label": "Конфигурации SwanLab",
},
"zh": {
"label": "SwanLab 参数设置",
},
"ko": {
"label": "SwanLab 설정",
},
},
"use_swanlab": {
"en": {
"label": "Use SwanLab",
"info": "Enable SwanLab for experiment tracking and visualization.",
},
"ru": {
"label": "Использовать SwanLab",
"info": "Включить SwanLab для отслеживания и визуализации экспериментов.",
},
"zh": {
"label": "使用 SwanLab",
"info": "启用 SwanLab 进行实验跟踪和可视化。",
},
"ko": {
"label": "SwanLab 사용",
"info": "SwanLab를 사용하여 실험을 추적하고 시각화합니다.",
},
},
"swanlab_project": {
"en": {
"label": "SwanLab project",
},
"ru": {
"label": "SwanLab Проект",
},
"zh": {
"label": "SwanLab 项目名",
},
"ko": {
"label": "SwanLab 프로젝트",
},
},
"swanlab_run_name": {
"en": {
"label": "SwanLab experiment name (optional)",
},
"ru": {
"label": "SwanLab Имя эксперимента (опционально)",
},
"zh": {
"label": "SwanLab 实验名(非必填)",
},
"ko": {
"label": "SwanLab 실험 이름 (선택 사항)",
},
},
"swanlab_workspace": {
"en": {
"label": "SwanLab workspace (optional)",
"info": "Workspace for SwanLab. Defaults to the personal workspace.",
},
"ru": {
"label": "SwanLab Рабочая область (опционально)",
"info": "Рабочая область SwanLab, если не заполнено, то по умолчанию в личной рабочей области.",
},
"zh": {
"label": "SwanLab 工作区(非必填)",
"info": "SwanLab 的工作区,默认在个人工作区下。",
},
"ko": {
"label": "SwanLab 작업 영역 (선택 사항)",
"info": "SwanLab 조직의 작업 영역, 비어 있으면 기본적으로 개인 작업 영역에 있습니다.",
},
},
"swanlab_api_key": {
"en": {
"label": "SwanLab API key (optional)",
"info": "API key for SwanLab.",
},
"ru": {
"label": "SwanLab API ключ (опционально)",
"info": "API ключ для SwanLab.",
},
"zh": {
"label": "SwanLab API密钥(非必填)",
"info": "用于在编程环境登录 SwanLab,已登录则无需填写。",
},
"ko": {
"label": "SwanLab API 키 (선택 사항)",
"info": "SwanLab의 API 키.",
},
},
"swanlab_mode": {
"en": {
"label": "SwanLab mode",
"info": "Cloud or offline version.",
},
"ru": {
"label": "SwanLab Режим",
"info": "Версия в облаке или локальная версия.",
},
"zh": {
"label": "SwanLab 模式",
"info": "使用云端版或离线版 SwanLab。",
},
"ko": {
"label": "SwanLab 모드",
"info": "클라우드 버전 또는 오프라인 버전.",
},
},
"cmd_preview_btn": {
"en": {
"value": "Preview command",
......
......@@ -19,9 +19,10 @@ from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_npu_available
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .locales import ALERTS, LOCALES
......@@ -146,12 +147,15 @@ class Runner:
shift_attn=get("train.shift_attn"),
report_to="all" if get("train.report_to") else "none",
use_galore=get("train.use_galore"),
use_apollo=get("train.use_apollo"),
use_badam=get("train.use_badam"),
use_swanlab=get("train.use_swanlab"),
output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"),
plot_loss=True,
trust_remote_code=True,
ddp_timeout=180000000,
include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True, # FIXME
)
......@@ -170,6 +174,7 @@ class Runner:
if get("top.quantization_bit") in QUANTIZATION_BITS:
args["quantization_bit"] = int(get("top.quantization_bit"))
args["quantization_method"] = get("top.quantization_method")
args["double_quantization"] = not is_torch_npu_available()
# freeze config
if args["finetuning_type"] == "freeze":
......@@ -220,6 +225,13 @@ class Runner:
args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target")
# apollo config
if args["use_apollo"]:
args["apollo_rank"] = get("train.apollo_rank")
args["apollo_update_interval"] = get("train.apollo_update_interval")
args["apollo_scale"] = get("train.apollo_scale")
args["apollo_target"] = get("train.apollo_target")
# badam config
if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode")
......@@ -227,6 +239,14 @@ class Runner:
args["badam_switch_interval"] = get("train.badam_switch_interval")
args["badam_update_ratio"] = get("train.badam_update_ratio")
# swanlab config
if get("train.use_swanlab"):
args["swanlab_project"] = get("train.swanlab_project")
args["swanlab_run_name"] = get("train.swanlab_run_name")
args["swanlab_workspace"] = get("train.swanlab_workspace")
args["swanlab_api_key"] = get("train.swanlab_api_key")
args["swanlab_mode"] = get("train.swanlab_mode")
# eval config
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size")
......@@ -268,6 +288,7 @@ class Runner:
top_p=get("eval.top_p"),
temperature=get("eval.temperature"),
output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
trust_remote_code=True,
)
if get("eval.predict"):
......@@ -383,12 +404,12 @@ class Runner:
continue
if self.do_train:
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
else:
if os.path.exists(os.path.join(output_path, "all_results.json")):
if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
......
......@@ -12,9 +12,105 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from PIL import Image
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
from llamafactory.data.collator import prepare_4d_attention_mask
def test_base_collator():
model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
p = tokenizer_module["tokenizer"].pad_token_id
q = IGNORE_INDEX
features = [
{
"input_ids": [0, 1, 2, 3, 4, 5],
"attention_mask": [1, 1, 1, 1, 1, 1],
"labels": [q, q, 2, 3, 4, 5],
},
{
"input_ids": [6, 7],
"attention_mask": [1, 1],
"labels": [q, 7],
},
]
batch_input = data_collator(features)
expected_input = {
"input_ids": [
[0, 1, 2, 3, 4, 5, p, p],
[6, 7, p, p, p, p, p, p],
],
"attention_mask": [
[1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0],
],
"labels": [
[q, q, 2, 3, 4, 5, q, q],
[q, 7, q, q, q, q, q, q],
],
}
for k in batch_input.keys():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
def test_multimodal_collator():
model_args, data_args, *_ = get_infer_args(
{"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
)
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template,
pad_to_multiple_of=4,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
p = tokenizer_module["tokenizer"].pad_token_id
q = IGNORE_INDEX
s = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_start|>")
e = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_end|>")
m = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|image_pad|>")
fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
features = [
{
"input_ids": [0, 1, 2, 3],
"attention_mask": [1, 1, 1, 1],
"labels": [0, 1, 2, 3],
},
]
batch_input = data_collator(features)
expected_input = {
"input_ids": [
[0, 1, 2, 3, s, m, m, m, m, e, p, p],
],
"attention_mask": [
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
],
"labels": [
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
],
**tokenizer_module["processor"].image_processor(fake_image),
}
for k in batch_input.keys():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
def test_4d_attention_mask():
......
......@@ -13,10 +13,29 @@
# limitations under the License.
import json
from datetime import datetime
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
TOOLS = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
def test_empty_formatter():
formatter = EmptyFormatter(slots=["\n"])
assert formatter.apply() == ["\n"]
......@@ -28,39 +47,27 @@ def test_string_formatter():
def test_function_formatter():
formatter = FunctionFormatter(slots=[], tool_format="default")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"</s>",
]
def test_multi_function_formatter():
formatter = FunctionFormatter(slots=[], tool_format="default")
tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
"</s>",
]
def test_default_tool_formatter():
formatter = ToolFormatter(tool_format="default")
tools = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
assert formatter.apply(content=json.dumps(tools)) == [
assert formatter.apply(content=json.dumps(TOOLS)) == [
"You have access to the following tools:\n"
"> Tool Name: test_tool\n"
"Tool Description: tool_desc\n"
......@@ -94,26 +101,18 @@ def test_default_multi_tool_extractor():
]
def test_glm4_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
def test_glm4_tool_formatter():
formatter = ToolFormatter(tool_format="glm4")
tools = [
{
"name": "test_tool",
"description": "tool_desc",
"parameters": {
"type": "object",
"properties": {
"foo": {"type": "string", "description": "foo_desc"},
"bar": {"type": "number", "description": "bar_desc"},
},
"required": ["foo"],
},
}
]
assert formatter.apply(content=json.dumps(tools)) == [
assert formatter.apply(content=json.dumps(TOOLS)) == [
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
"## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4))
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
]
......@@ -121,3 +120,127 @@ def test_glm4_tool_extractor():
formatter = ToolFormatter(tool_format="glm4")
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
"<|eot_id|>",
]
def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3")
date = datetime.now().strftime("%d %b %Y")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
]
def test_llama3_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_tool_formatter():
formatter = ToolFormatter(tool_format="mistral")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
"[AVAILABLE_TOOLS] " + json.dumps([wrapped_tool], ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
]
def test_mistral_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_multi_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = (
"""[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
]
def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
]
def test_qwen_tool_formatter():
formatter = ToolFormatter(tool_format="qwen")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
""""arguments": <args-json-object>}\n</tool_call><|im_end|>\n"""
]
def test_qwen_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_qwen_multi_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = (
"""<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
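The formatter tests above all exercise the same two-way contract: FunctionFormatter.apply renders a JSON function call into the model-specific string, and ToolFormatter.extract parses such a string back into (name, arguments) tuples. A minimal, self-contained sketch of that round trip for the qwen format, based only on the expected values asserted above, would be:
import json
from llamafactory.data.formatter import FunctionFormatter, ToolFormatter
# Render a single tool call the way the qwen template expects it.
call = {"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}
rendered = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen").apply(content=json.dumps(call))
# rendered[0] is the <tool_call>...</tool_call> block, rendered[1] is "<|im_end|>".
# Parsing the rendered block back should recover the tool name and its JSON arguments.
assert ToolFormatter(tool_format="qwen").extract(rendered[0]) == [
    ("test_tool", """{"foo": "bar", "size": 10}"""),
]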
......@@ -13,14 +13,14 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
import pytest
import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.hparams import ModelArguments
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
......@@ -29,6 +29,7 @@ if TYPE_CHECKING:
from transformers.image_processing_utils import BaseImageProcessor
from llamafactory.data.mm_plugin import BasePlugin
from llamafactory.model.loader import TokenizerModule
HF_TOKEN = os.getenv("HF_TOKEN")
......@@ -82,10 +83,9 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
assert batch_a[key] == batch_b[key]
def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
model_args = ModelArguments(model_name_or_path=model_name_or_path)
tokenizer_module = load_tokenizer(model_args)
return tokenizer_module["tokenizer"], tokenizer_module["processor"]
def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
model_args, *_ = get_infer_args({"model_name_or_path": model_name_or_path, "template": "default"})
return load_tokenizer(model_args)
def _check_plugin(
......@@ -121,73 +121,75 @@ def _check_plugin(
def test_base_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
base_plugin = get_mm_plugin(name="base", image_token="<image>")
check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs = {"plugin": base_plugin, **tokenizer_module}
_check_plugin(**check_inputs)
def test_llava_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
image_seqlen = 576
check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
check_inputs = {"plugin": llava_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
def test_llava_next_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
check_inputs = {"plugin": llava_next_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
def test_llava_next_video_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": llava_next_video_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
image_seqlen = 256
check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
tokenizer_module = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
check_inputs = {"plugin": paligemma_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
]
check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
check_inputs["expected_input_ids"] = [
tokenizer_module["tokenizer"].convert_tokens_to_ids(paligemma_plugin.image_token)
] * image_seqlen + INPUT_IDS
check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
_check_plugin(**check_inputs)
def test_pixtral_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
image_slice_height, image_slice_width = 2, 2
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
check_inputs = {"plugin": pixtral_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{
key: value.replace(
......@@ -199,17 +201,17 @@ def test_pixtral_plugin():
}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
check_inputs["expected_mm_inputs"].pop("image_sizes")
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
_check_plugin(**check_inputs)
def test_qwen2_vl_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
image_seqlen = 4
check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
check_inputs = {"plugin": qwen2_vl_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
......@@ -217,18 +219,18 @@ def test_qwen2_vl_plugin():
}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
def test_video_llava_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
image_seqlen = 256
tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": video_llava_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
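All of the plugin tests above share one expansion step: every generic "<image>" placeholder in a fixture message is widened to the model-specific number of image tokens before comparison. A small standalone sketch of that step, using a hypothetical message list in place of the MM_MESSAGES fixture, would be:
# Hypothetical message, standing in for the MM_MESSAGES fixture used above.
messages = [{"role": "user", "content": "<image>What is in this image?"}]
image_seqlen = 576  # e.g. the llava plugin test repeats the image token 576 times
expected_mm_messages = [
    {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
    for message in messages
]
# expected_mm_messages[0]["content"] now begins with 576 consecutive "<image>" tokens.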