Commit c3a3c678 authored by chenych

Update to v0.9.3

parent 1bc2def5
@@ -86,10 +86,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             padding_side="right",
             **init_kwargs,
         )
-    except ValueError:  # try the fast one
+    except ValueError:  # try another one
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
-            use_fast=True,
+            use_fast=not model_args.use_fast_tokenizer,
             padding_side="right",
             **init_kwargs,
         )
@@ -97,12 +97,23 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         raise OSError("Failed to load tokenizer.") from e

     patch_tokenizer(tokenizer, model_args)
     try:
-        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        patch_processor(processor, tokenizer, model_args)
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=model_args.use_fast_tokenizer,
+            **init_kwargs,
+        )
+    except ValueError:  # try another one
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=not model_args.use_fast_tokenizer,
+            **init_kwargs,
+        )
     except Exception as e:
-        logger.info_rank0(f"Failed to load processor: {e}.")
-        processor = None
+        raise OSError("Failed to load processor.") from e
+
+    patch_processor(processor, tokenizer, model_args)

     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
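Note: the change above applies the same fast/slow fallback to both the tokenizer and the processor — the user's `use_fast_tokenizer` preference is tried first, a `ValueError` triggers a retry with the opposite implementation, and any other failure now raises instead of silently setting `processor = None`. A minimal standalone sketch of that pattern (the checkpoint id and `prefer_fast` flag are illustrative placeholders, not the loader's actual arguments):

```python
from transformers import AutoProcessor, AutoTokenizer

model_id = "Qwen/Qwen2-VL-2B-Instruct"  # placeholder checkpoint
prefer_fast = True  # stands in for model_args.use_fast_tokenizer

try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=prefer_fast, padding_side="right")
except ValueError:  # preferred implementation unavailable, try the other one
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=not prefer_fast, padding_side="right")

try:
    processor = AutoProcessor.from_pretrained(model_id, use_fast=prefer_fast)
except ValueError:  # same fallback for the processor
    processor = AutoProcessor.from_pretrained(model_id, use_fast=not prefer_fast)
except Exception as e:  # any other failure is now fatal instead of yielding processor = None
    raise OSError("Failed to load processor.") from e
```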
@@ -138,7 +149,7 @@ def load_model(
         if model_args.adapter_name_or_path is not None:
             lazy_load = True
         elif is_trainable:
-            model = load_unsloth_pretrained_model(config, model_args)
+            model = load_unsloth_pretrained_model(config, model_args, finetuning_args)

     if model is None and not lazy_load:
         init_kwargs["config"] = config
...
@@ -73,6 +73,8 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
     elif model_type == "qwen3":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
+    elif model_type == "qwen3_moe":
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return
...
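For reference, the new branch simply maps `model_type == "qwen3_moe"` to the corresponding Liger entry point, which can also be invoked directly; this assumes a liger-kernel release that ships `apply_liger_kernel_to_qwen3_moe`, as the import above implies:

```python
# Hedged sketch: patch the Qwen3-MoE modeling classes in place with Liger's fused kernels.
# Should run before the model is instantiated so the patched classes are the ones built.
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe

apply_liger_kernel_to_qwen3_moe()  # default arguments enable the fused kernels the release supports
```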
@@ -21,14 +21,17 @@ from ...extras.misc import get_current_device
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel

-    from ...hparams import ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments


 logger = logging.get_logger(__name__)


 def _get_unsloth_kwargs(
-    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
+    config: "PretrainedConfig",
+    model_name_or_path: str,
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
 ) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
@@ -36,6 +39,7 @@ def _get_unsloth_kwargs(
         "dtype": model_args.compute_dtype,
         "load_in_4bit": model_args.quantization_bit == 4,
         "token": model_args.hf_hub_token,
+        "full_finetuning": finetuning_args.finetuning_type == "full",
         "device_map": {"": get_current_device()},
         "rope_scaling": getattr(config, "rope_scaling", None),
         "fix_tokenizer": False,
@@ -45,12 +49,12 @@ def _get_unsloth_kwargs(
 def load_unsloth_pretrained_model(
-    config: "PretrainedConfig", model_args: "ModelArguments"
+    config: "PretrainedConfig", model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
 ) -> Optional["PreTrainedModel"]:
     r"""Optionally load pretrained model with unsloth. Used in training."""
     from unsloth import FastLanguageModel  # type: ignore

-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args, finetuning_args)
     try:
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
...
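Net effect: `_get_unsloth_kwargs` now also tells unsloth whether this is a full fine-tune. A rough illustration of the resulting call, with placeholder values rather than the loader's real defaults:

```python
from unsloth import FastLanguageModel  # type: ignore

unsloth_kwargs = {
    "model_name": "unsloth/Llama-3.2-1B-Instruct",  # placeholder checkpoint
    "max_seq_length": 4096,                          # placeholder; the loader derives this from model_args
    "dtype": None,                                   # let unsloth choose bf16/fp16
    "load_in_4bit": False,                           # True when quantization_bit == 4
    "full_finetuning": True,                         # new: finetuning_type == "full"
    "device_map": {"": 0},
    "fix_tokenizer": False,
}
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
```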
@@ -76,7 +76,7 @@ def _register_composite_model(
         model_type=model_type,
         projector_key=projector_key or "multi_modal_projector",
         vision_model_keys=vision_model_keys or ["vision_tower"],
-        language_model_keys=language_model_keys or ["language_model"],
+        language_model_keys=language_model_keys or ["language_model", "lm_head"],
         lora_conflict_keys=lora_conflict_keys or [],
     )
@@ -200,12 +200,12 @@ def patch_target_modules(
 _register_composite_model(
-    model_type="internvl",
+    model_type="gemma3",
 )


 _register_composite_model(
-    model_type="gemma3",
+    model_type="internvl",
 )
@@ -246,20 +246,19 @@ _register_composite_model(
     lora_conflict_keys=["audio_projection_layer"],
 )


 _register_composite_model(
-    model_type="paligemma",
+    model_type="mistral3",
 )


 _register_composite_model(
-    model_type="video_llava",
+    model_type="mllama",
+    vision_model_keys=["vision_model"],
 )


 _register_composite_model(
-    model_type="mllama",
-    vision_model_keys=["vision_model"],
+    model_type="paligemma",
 )
@@ -282,7 +281,9 @@ _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"]
+    if is_transformers_version_greater_than("4.52.0")
+    else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
@@ -291,6 +292,13 @@ _register_composite_model(
     model_type="qwen2_5_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"]
+    if is_transformers_version_greater_than("4.52.0")
+    else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
+
+
+_register_composite_model(
+    model_type="video_llava",
+)
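Context for the reordering and the new `lm_head` entry: each `_register_composite_model` call records which submodule prefixes belong to the projector, the vision encoder, and the language model, so later code can freeze or LoRA-target them by name; adding `lm_head` to the default `language_model_keys` appears intended to keep the LM head grouped with the language model. A hedged, self-contained sketch of that registry pattern (the dataclass and helper below are illustrative, not the project's exact definitions):

```python
from dataclasses import dataclass, field


@dataclass
class CompositeModel:  # illustrative stand-in for the project's registry entry
    model_type: str
    projector_key: str = "multi_modal_projector"
    vision_model_keys: list[str] = field(default_factory=lambda: ["vision_tower"])
    language_model_keys: list[str] = field(default_factory=lambda: ["language_model", "lm_head"])
    lora_conflict_keys: list[str] = field(default_factory=list)


COMPOSITE_MODELS: dict[str, CompositeModel] = {}


def register(spec: CompositeModel) -> None:
    COMPOSITE_MODELS[spec.model_type] = spec


register(
    CompositeModel(
        model_type="qwen2_vl",
        projector_key="visual.merger",
        vision_model_keys=["visual.patch_embed", "visual.blocks"],
        lora_conflict_keys=["patch_embed"],
    )
)
```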
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any
 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import GenerationMixin, PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
@@ -169,7 +169,7 @@ def patch_model(
     if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
         model.generate.__func__
     ):
-        model.generate = MethodType(PreTrainedModel.generate, model)
+        model.generate = MethodType(GenerationMixin.generate, model)

     if add_valuehead:
         prepare_valuehead_model(model)
...
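A small sketch of the patch in isolation: remote-code models sometimes override `generate()` with a non-standard implementation, and recent transformers releases expose the canonical implementation on `GenerationMixin` rather than guaranteeing it on `PreTrainedModel`, so the stock method is rebound onto the instance (the helper name below is made up for illustration):

```python
from types import MethodType

from transformers import GenerationMixin


def ensure_standard_generate(model) -> None:
    # Rebind the stock generate() when the model's own generate() is not GenerationMixin-based.
    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(GenerationMixin.generate, model)
```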
@@ -80,6 +80,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.ftx_gamma = finetuning_args.pref_ftx
         self.label_smoothing = finetuning_args.dpo_label_smoothing
         self.simpo_gamma = finetuning_args.simpo_gamma
+        self.ld_alpha = finetuning_args.ld_alpha

         Trainer.__init__(self, model=model, **kwargs)
         self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
@@ -177,7 +178,7 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def concatenated_forward(
-        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
     ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
@@ -187,7 +188,9 @@ class CustomDPOTrainer(DPOTrainer):
         batch = nested_detach(batch, clone=True)  # avoid error
         all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
-        all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
+        all_logps, valid_length = get_batch_logps(
+            logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None)
+        )
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
@@ -217,7 +220,9 @@ class CustomDPOTrainer(DPOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
+                ref_model, batch, is_ref_model=True
+            )

         return reference_chosen_logps, reference_rejected_logps
...
@@ -585,7 +585,10 @@ def create_custom_scheduler(
 def get_batch_logps(
-    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
+    logits: "torch.Tensor",
+    labels: "torch.Tensor",
+    label_pad_token_id: int = IGNORE_INDEX,
+    ld_alpha: Optional[float] = None,
 ) -> tuple["torch.Tensor", "torch.Tensor"]:
     r"""Compute the log probabilities of the given labels under the given logits.
@@ -602,7 +605,30 @@ def get_batch_logps(
     loss_mask = labels != label_pad_token_id
     labels[labels == label_pad_token_id] = 0  # dummy token
     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
-    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+
+    valid_length = loss_mask.sum(-1)
+    if ld_alpha is not None:
+        num_examples = labels.shape[0] // 2
+        chosen_lengths = valid_length[:num_examples]
+        rejected_lengths = valid_length[num_examples:]
+        min_lengths = torch.min(chosen_lengths, rejected_lengths)
+        start_positions = torch.argmax(loss_mask.int(), dim=1)
+        public_lengths = start_positions + torch.cat([min_lengths, min_lengths], dim=0)
+
+        seq_len = labels.shape[-1]
+        position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
+
+        ld_mask = position_ids < public_lengths.unsqueeze(1)
+        front_mask = (ld_mask * loss_mask).float()
+        rear_mask = (~ld_mask * loss_mask).float()
+
+        front_logps = (per_token_logps * front_mask).sum(-1)
+        rear_logps = (per_token_logps * rear_mask).sum(-1)
+        logps = front_logps + ld_alpha * rear_logps
+    else:
+        logps = (per_token_logps * loss_mask).sum(-1)
+
+    return logps, valid_length


 def nested_detach(
...
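A quick sanity check of the new `ld_alpha` branch with toy numbers: tokens up to the shared ("public") response length keep full weight, and the excess tokens of the longer response are down-weighted by `ld_alpha`, which the DPO trainer passes in for the policy model only. All values below are made up for illustration:

```python
import torch

# one chosen / one rejected sequence, 4 positions each
per_token_logps = torch.tensor([[-1.0, -1.0, -1.0, -1.0],   # chosen: 4 valid tokens
                                [-2.0, -2.0,  0.0,  0.0]])  # rejected: 2 valid tokens
loss_mask = torch.tensor([[1, 1, 1, 1],
                          [1, 1, 0, 0]], dtype=torch.bool)
ld_alpha = 0.5

valid_length = loss_mask.sum(-1)                                   # tensor([4, 2])
min_lengths = torch.min(valid_length[:1], valid_length[1:])        # shared length = 2
start_positions = torch.argmax(loss_mask.int(), dim=1)             # both responses start at position 0
public_lengths = start_positions + torch.cat([min_lengths, min_lengths])  # tensor([2, 2])

position_ids = torch.arange(loss_mask.shape[-1]).expand_as(per_token_logps)
ld_mask = position_ids < public_lengths.unsqueeze(1)
front_logps = (per_token_logps * (ld_mask & loss_mask)).sum(-1)    # full-weight prefix
rear_logps = (per_token_logps * (~ld_mask & loss_mask)).sum(-1)    # down-weighted tail

print(front_logps + ld_alpha * rear_logps)  # tensor([-3., -4.]): chosen -2 + 0.5*(-2), rejected -4 + 0.5*0
```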
@@ -16,6 +16,7 @@ import os
 import torch
 from PIL import Image
+from transformers import AutoConfig, AutoModelForVision2Seq

 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
@@ -72,12 +73,17 @@ def test_base_collator():
 def test_multimodal_collator():
     model_args, data_args, *_ = get_infer_args(
-        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
+        {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
     )
     tokenizer_module = load_tokenizer(model_args)
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
     data_collator = MultiModalDataCollatorForSeq2Seq(
         template=template,
+        model=model,
         pad_to_multiple_of=4,
         label_pad_token_id=IGNORE_INDEX,
         **tokenizer_module,
@@ -107,8 +113,15 @@ def test_multimodal_collator():
         "labels": [
             [0, 1, 2, 3, q, q, q, q, q, q, q, q],
         ],
+        "position_ids": [
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+        ],
+        "rope_deltas": [[-8]],
         **tokenizer_module["processor"].image_processor(fake_image),
     }
+    assert batch_input.keys() == expected_input.keys()
     for k in batch_input.keys():
         assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
@@ -150,3 +163,7 @@ def test_4d_attention_mask():
     )
     assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
     assert torch.all(attention_mask_computed == attention_mask_expected)
+
+
+if __name__ == "__main__":
+    test_multimodal_collator()
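One noteworthy detail in the updated test: the collator now receives a model, which lets it produce Qwen2-VL's multimodal `position_ids` and `rope_deltas` as part of the expected batch, and the model skeleton is built on the `meta` device so no weights are materialized or downloaded. That trick is reusable on its own:

```python
import torch
from transformers import AutoConfig, AutoModelForVision2Seq

config = AutoConfig.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")  # config download only, no weights
with torch.device("meta"):
    model = AutoModelForVision2Seq.from_config(config)

print(next(model.parameters()).device)  # meta -- parameters carry shapes and dtypes but no storage
```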
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+
 import pytest
 import torch
 from transformers import AutoConfig, AutoModelForVision2Seq
@@ -76,3 +78,25 @@ def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
     assert (visual_param_name in trainable_params) != freeze_vision_tower
     assert (language_param_name in trainable_params) != freeze_language_model
     assert (merger_param_name in trainable_params) is False
+
+
+def test_visual_model_save_load():
+    # check VLM's state dict: https://github.com/huggingface/transformers/pull/38385
+    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
+    finetuning_args = FinetuningArguments(finetuning_type="full")
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
+    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=False)
+    loaded_model_weight = dict(model.named_parameters())
+    model.save_pretrained(os.path.join("output", "qwen2_vl"), max_shard_size="10GB", safe_serialization=False)
+    saved_model_weight = torch.load(os.path.join("output", "qwen2_vl", "pytorch_model.bin"), weights_only=False)
+
+    if is_transformers_version_greater_than("4.52.0"):
+        assert "model.language_model.layers.0.self_attn.q_proj.weight" in loaded_model_weight
+    else:
+        assert "model.layers.0.self_attn.q_proj.weight" in loaded_model_weight
+
+    assert "model.layers.0.self_attn.q_proj.weight" in saved_model_weight
 # change if test fails or cache is outdated
-0.9.3.107
+0.9.3.108