Commit 317a82e2 authored by chenych

Add QWQ-32B

parent 37b0ad9f
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -201,7 +201,7 @@ def _setup_lora_tuning(
     if finetuning_args.use_llama_pro:
         target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)

-    target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
+    target_modules = patch_target_modules(model, finetuning_args, target_modules)

     if (
         finetuning_args.use_dora
...
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +16,14 @@ import os
 from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

 import torch
-from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForVision2Seq,
+    AutoProcessor,
+    AutoTokenizer,
+)
 from trl import AutoModelForCausalLMWithValueHead

 from ..extras import logging
@@ -86,20 +93,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     except Exception as e:
         raise OSError("Failed to load tokenizer.") from e

-    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
-        tokenizer.model_max_length = model_args.model_max_length
-
-    if model_args.new_special_tokens is not None:
-        num_added_tokens = tokenizer.add_special_tokens(
-            dict(additional_special_tokens=model_args.new_special_tokens),
-            replace_additional_special_tokens=False,
-        )
-        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
-        if num_added_tokens > 0 and not model_args.resize_vocab:
-            model_args.resize_vocab = True
-            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
-
-    patch_tokenizer(tokenizer)
+    patch_tokenizer(tokenizer, model_args)

     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, config, tokenizer, model_args)
@@ -155,6 +149,8 @@ def load_model(
     else:
         if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
             load_class = AutoModelForVision2Seq
+        elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
+            load_class = AutoModelForSeq2SeqLM
         else:
             load_class = AutoModelForCausalLM
...
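For context, a minimal sketch (not part of the diff) of the auto-class fallback that load_model now applies; the checkpoint id is only an illustration:

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForVision2Seq

# QwQ-32B reuses the Qwen2 architecture, so it falls through to the causal-LM branch;
# a T5-style config would now be caught by the new AutoModelForSeq2SeqLM branch instead.
config = AutoConfig.from_pretrained("Qwen/QwQ-32B")
if type(config) in AutoModelForVision2Seq._model_mapping.keys():
    load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
    load_class = AutoModelForSeq2SeqLM
else:
    load_class = AutoModelForCausalLM
print(load_class.__name__)  # AutoModelForCausalLM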
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -85,12 +85,18 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
+        has_grad = False
         if any(param.requires_grad for param in module.parameters()):
+            has_grad = True
             for arg in args:
                 if torch.is_tensor(arg) and torch.is_floating_point(arg):
                     arg.requires_grad_(True)
+                    break  # assume the first tensor is always the hidden states

-        return gradient_checkpointing_func(func, *args, **kwargs)
+        if has_grad:
+            return gradient_checkpointing_func(func, *args, **kwargs)
+        else:
+            return func(*args, **kwargs)

     return custom_gradient_checkpointing_func
...
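A self-contained sketch of what the new has_grad guard changes: modules whose parameters are all frozen now run their forward pass directly instead of being recomputed through torch.utils.checkpoint. The module and tensor names below are illustrative.

import torch
from torch.utils.checkpoint import checkpoint


def run_block(module: torch.nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    # Mirrors the patched wrapper: only checkpoint blocks that actually have trainable params.
    if any(param.requires_grad for param in module.parameters()):
        hidden_states.requires_grad_(True)  # let gradients flow into the checkpointed segment
        return checkpoint(module, hidden_states, use_reentrant=True)
    else:
        return module(hidden_states)  # frozen block: no recomputation needed


frozen = torch.nn.Linear(8, 8).requires_grad_(False)
trainable = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)
run_block(trainable, run_block(frozen, x)).sum().backward()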
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,6 +27,39 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)

+def apply_liger_kernel_to_qwen2_5_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    from liger_kernel.transformers import LigerCrossEntropyLoss, LigerRMSNorm, LigerSwiGLUMLP
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
+    from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
+    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+
+    def get_dtype(self: "modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel"):
+        return self.dtype
+
+    modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel.get_dtype = get_dtype
+
+    if rope:
+        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+
+    if rms_norm:
+        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+
+    if swiglu:
+        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
+
+
 def apply_liger_kernel(
     config: "PretrainedConfig",
     model_args: "ModelArguments",
@@ -47,19 +80,23 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
     elif model_type == "mixtral":
         from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
+    elif model_type == "mllama":
+        from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
     elif model_type == "phi3":
         from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
     elif model_type == "qwen2":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
     elif model_type == "qwen2_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
+    elif model_type == "qwen2_5_vl":
+        apply_liger_kernel = apply_liger_kernel_to_qwen2_5_vl
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return

     if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
         logger.info_rank0("Current training stage does not support chunked cross entropy.")
-        kwargs = {"fused_linear_cross_entropy": False}
+        kwargs = {"fused_linear_cross_entropy": False, "cross_entropy": True}
     else:
         kwargs = {}
...
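A hedged usage sketch: the new patch swaps symbols on modeling_qwen2_5_vl at module level, so it has to run before the model class is instantiated. The import path assumes the file keeps its current location at llamafactory/model/model_utils/liger_kernel.py, the checkpoint id is only an example, and in normal training this is driven by the enable_liger_kernel argument rather than called by hand.

from transformers import AutoModelForVision2Seq

from llamafactory.model.model_utils.liger_kernel import apply_liger_kernel_to_qwen2_5_vl

# Requires a liger-kernel build that ships the Qwen2-VL helpers and a transformers
# release that includes Qwen2.5-VL.
apply_liger_kernel_to_qwen2_5_vl(rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True)
model = AutoModelForVision2Seq.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto")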
@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    check_version("transformers>=4.41.2,<=4.46.1")
+    check_version("transformers>=4.41.2,<4.48.0")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
...
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -77,7 +77,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
         ):
             module_names.append(name)

-    logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    logger.info_rank0("Apply lora to layers: {}.".format(",".join(map(str, trainable_layer_ids))))
     return module_names
...
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -61,7 +61,7 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

-    if model_type == "qwen2moe":
+    if model_type == "qwen2_moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
...
@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return

-    check_version("transformers>=4.43.0,<=4.46.1")
+    check_version("transformers>=4.43.0")
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
@@ -39,6 +39,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return

+    rope_kwargs = {}
     if model_args.model_max_length is not None:
         if is_trainable and model_args.rope_scaling == "dynamic":
             logger.warning_rank0(
@@ -50,14 +51,21 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         if current_max_length and model_args.model_max_length > current_max_length:
             logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
             setattr(config, "max_position_embeddings", model_args.model_max_length)
-            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+            rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
         else:
             logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
-            scaling_factor = 1.0
+            rope_kwargs["factor"] = 1.0
+
+        if model_args.rope_scaling == "dynamic":
+            rope_kwargs["original_max_position_embeddings"] = current_max_length
+        elif model_args.rope_scaling == "llama3":
+            rope_kwargs["original_max_position_embeddings"] = current_max_length
+            rope_kwargs["low_freq_factor"] = 1.0
+            rope_kwargs["high_freq_factor"] = 4.0
     else:
-        scaling_factor = 2.0
+        rope_kwargs["factor"] = 2.0

-    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+    setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs})
     logger.info_rank0(
-        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
+        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
     )
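As a concrete illustration of what this branch now writes into config.rope_scaling; the context lengths below are made-up example values, assuming llama3-style scaling:

import math

current_max_length = 8192    # config.max_position_embeddings
model_max_length = 16384     # model_args.model_max_length
rope_scaling = "llama3"      # model_args.rope_scaling

rope_kwargs = {"factor": float(math.ceil(model_max_length / current_max_length))}
if rope_scaling == "dynamic":
    rope_kwargs["original_max_position_embeddings"] = current_max_length
elif rope_scaling == "llama3":
    rope_kwargs["original_max_position_embeddings"] = current_max_length
    rope_kwargs["low_freq_factor"] = 1.0
    rope_kwargs["high_freq_factor"] = 4.0

print({"rope_type": rope_scaling, **rope_kwargs})
# {'rope_type': 'llama3', 'factor': 2.0, 'original_max_position_embeddings': 8192,
#  'low_freq_factor': 1.0, 'high_freq_factor': 4.0}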
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -16,7 +16,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple

 import torch
 import transformers
@@ -42,6 +42,7 @@ class CompositeModel:
     projector_key: str
     vision_model_keys: List[str]
     language_model_keys: List[str]
+    lora_conflict_keys: List[str]

     def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
         for key in self.projector_key.split("."):
@@ -58,15 +59,14 @@ def _register_composite_model(
     projector_key: Optional[str] = None,
     vision_model_keys: Optional[List[str]] = None,
     language_model_keys: Optional[List[str]] = None,
+    lora_conflict_keys: Optional[List[str]] = None,
 ):
-    projector_key = projector_key or "multi_modal_projector"
-    vision_model_keys = vision_model_keys or ["vision_tower"]
-    language_model_keys = language_model_keys or ["language_model"]
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
-        projector_key=projector_key,
-        vision_model_keys=vision_model_keys,
-        language_model_keys=language_model_keys,
+        projector_key=projector_key or "multi_modal_projector",
+        vision_model_keys=vision_model_keys or ["vision_tower"],
+        language_model_keys=language_model_keys or ["language_model"],
+        lora_conflict_keys=lora_conflict_keys or [],
     )
@@ -210,29 +210,25 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P
 def patch_target_modules(
-    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
-) -> Union[str, List[str]]:
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+) -> List[str]:
     r"""
     Freezes vision tower for VLM LoRA tuning.
     """
-    model_type = getattr(config, "model_type", None)
-    vit_model_type = getattr(getattr(config, "vision_config", None), "model_type", None)
-    if finetuning_args.freeze_vision_tower:
-        if model_type in COMPOSITE_MODELS:
-            vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys
-            logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.")
-            vision_model_keys = "|".join(vision_model_keys)
-            target_modules = "|".join(target_modules)
-            return f"^(?!.*{vision_model_keys}).*(?:{target_modules}).*"
-        else:
-            return target_modules
+    model_type = getattr(model.config, "model_type", None)
+    if model_type in COMPOSITE_MODELS:
+        forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
+        forbidden_modules.update(COMPOSITE_MODELS[model_type].lora_conflict_keys)
+        module_names = []
+        for name, _ in model.named_modules():
+            if any(target_module in name for target_module in target_modules) and not any(
+                forbidden_module in name for forbidden_module in forbidden_modules
+            ):
+                module_names.append(name)
+
+        return module_names
     else:
-        if model_type == "qwen2_vl":  # avoid attaching lora to Conv3D layer
-            return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
-        elif vit_model_type == "pixtral":
-            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
-        else:
-            return target_modules
+        return target_modules


 _register_composite_model(
@@ -252,6 +248,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="minicpmv",
+    projector_key="resampler",
     vision_model_keys=["vpm"],
     language_model_keys=["llm"],
 )
@@ -259,8 +256,10 @@ _register_composite_model(
 _register_composite_model(
     model_type="minicpmo",
-    vision_model_keys=["vpm", "apm", "resampler", "tts"],
+    projector_key="resampler",
+    vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
     language_model_keys=["llm"],
+    lora_conflict_keys=["audio_projection_layer"],
 )
@@ -280,9 +279,25 @@ _register_composite_model(
 )


+_register_composite_model(
+    model_type="qwen2_audio",
+    vision_model_keys=["audio_tower"],
+)
+
+
 _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
+_register_composite_model(
+    model_type="qwen2_5_vl",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
 )
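A self-contained sketch of the new name-based filtering in patch_target_modules for composite models; the toy names below stand in for model.named_modules() on a Qwen2-VL-style model, and forbidden_modules plays the role of get_forbidden_modules plus the new lora_conflict_keys:

# Toy module names standing in for model.named_modules().
module_names = [
    "visual.patch_embed.proj",           # Conv3D patch embedding
    "visual.blocks.0.attn.qkv",          # vision tower
    "model.layers.0.self_attn.q_proj",   # language model
    "model.layers.0.mlp.up_proj",        # language model
]
target_modules = ["q_proj", "up_proj", "qkv", "proj"]
# With freeze_vision_tower=True the vision tower keys are forbidden, and
# lora_conflict_keys adds "patch_embed" so LoRA never lands on the Conv3D layer.
forbidden_modules = {"visual.patch_embed", "visual.blocks", "patch_embed"}

selected = [
    name
    for name in module_names
    if any(target in name for target in target_modules)
    and not any(forbidden in name for forbidden in forbidden_modules)
]
print(selected)  # ['model.layers.0.self_attn.q_proj', 'model.layers.0.mlp.up_proj']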
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict
@@ -23,7 +22,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

 from ..extras import logging
-from ..extras.misc import infer_optim_dtype
+from ..extras.misc import infer_optim_dtype, is_env_enabled
 from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
@@ -53,10 +52,23 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)

-def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
+def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
+
+    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length
+
+    if model_args.new_special_tokens is not None:
+        num_added_tokens = tokenizer.add_special_tokens(
+            dict(additional_special_tokens=model_args.new_special_tokens),
+            replace_additional_special_tokens=False,
+        )
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        if num_added_tokens > 0 and not model_args.resize_vocab:
+            model_args.resize_vocab = True
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")


 def patch_processor(
     processor: "ProcessorMixin",
@@ -65,13 +77,16 @@ def patch_processor(
     model_args: "ModelArguments",
 ) -> None:
     setattr(processor, "tokenizer", tokenizer)
-    setattr(processor, "image_seqlen", get_image_seqlen(config))
-    setattr(processor, "image_resolution", model_args.image_resolution)
-    setattr(processor, "patch_size", get_patch_size(config, processor))
-    setattr(processor, "video_resolution", model_args.video_resolution)
-    setattr(processor, "video_fps", model_args.video_fps)
-    setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
+    if getattr(config, "vision_config", None) is not None:  # visual models
+        setattr(processor, "image_seqlen", get_image_seqlen(config))
+        setattr(processor, "patch_size", get_patch_size(config, processor))
+        setattr(processor, "image_max_pixels", model_args.image_max_pixels)
+        setattr(processor, "image_min_pixels", model_args.image_min_pixels)
+        setattr(processor, "video_max_pixels", model_args.video_max_pixels)
+        setattr(processor, "video_min_pixels", model_args.video_min_pixels)
+        setattr(processor, "video_fps", model_args.video_fps)
+        setattr(processor, "video_maxlen", model_args.video_maxlen)
+        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))


 def patch_config(
@@ -88,8 +103,7 @@ def patch_config(
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

     if is_torch_npu_available():
-        use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
-        torch.npu.set_compile_mode(jit_compile=use_jit_compile)
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("JIT_COMPILE"))

     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
@@ -112,7 +126,7 @@ def patch_config(
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

     if getattr(config, "model_type", None) == "minicpmo":
-        setattr(config, "init_audio", False)
+        setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)

     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
...
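A hedged sketch of what the relocated special-token handling in patch_tokenizer amounts to; the tokenizer id and token strings are purely illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
num_added_tokens = tokenizer.add_special_tokens(
    dict(additional_special_tokens=["<extra_tok_0>", "<extra_tok_1>"]),
    replace_additional_special_tokens=False,
)
# If any token is genuinely new, load_model must resize the embedding matrix later,
# which is why patch_tokenizer flips model_args.resize_vocab to True in that case.
print(num_added_tokens, len(tokenizer))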
from typing import TYPE_CHECKING

from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
    if model_args.flash_attn == "auto":
        return

    elif model_args.flash_attn == "off":
        requested_attn_implementation = "eager"

    elif model_args.flash_attn == "sdpa":
        if not is_sdpa_available():
            logger.warning("torch>=2.1.1 is required for SDPA attention.")
            return

        requested_attn_implementation = "sdpa"

    elif model_args.flash_attn == "fa2":
        if not is_flash_attn2_available():
            logger.warning("FlashAttention-2 is not installed.")
            return

        requested_attn_implementation = "flash_attention_2"

    else:
        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))

    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        setattr(config, "attn_implementation", requested_attn_implementation)
    else:
        setattr(config, "_attn_implementation", requested_attn_implementation)


def print_attn_implementation(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        attn_implementation = getattr(config, "attn_implementation", None)
    else:
        attn_implementation = getattr(config, "_attn_implementation", None)

    if attn_implementation == "flash_attention_2":
        logger.info("Using FlashAttention-2 for faster training and inference.")
    elif attn_implementation == "sdpa":
        logger.info("Using torch SDPA for faster training and inference.")
    else:
        logger.info("Using vanilla attention implementation.")
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch

from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def _gradient_checkpointing_enable(
    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
    r"""
    Activates gradient checkpointing for the current model.

    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
    """
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module: "torch.nn.Module" = func.__self__
        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        return gradient_checkpointing_func(func, *args, **kwargs)

    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
        self.apply(partial(self._set_gradient_checkpointing, value=True))
        self.enable_input_require_grads()
        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
    else:  # have already enabled input require gradients
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)


def _fp32_forward_post_hook(
    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
    return output.to(torch.float32)


def prepare_model_for_training(
    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
    r"""
    Includes:
        (1) cast the layernorm in fp32
        (2) make output embedding layer require grads
        (3) add the upcasting of the lm_head in fp32
    Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
    """
    if model_args.upcast_layernorm:
        logger.info("Upcasting layernorm weights in float32.")
        for name, param in model.named_parameters():
            if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                param.data = param.data.to(torch.float32)

    if not model_args.disable_gradient_checkpointing:
        if not getattr(model, "supports_gradient_checkpointing", False):
            logger.warning("Current model does not support gradient checkpointing.")
        else:
            # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
            # According to: https://github.com/huggingface/transformers/issues/28339
            model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
            setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
            logger.info("Gradient checkpointing enabled.")

    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
        logger.info("Upcasting lm_head outputs in float32.")
        output_layer = getattr(model, output_layer_name)
        if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
            output_layer.register_forward_hook(_fp32_forward_post_hook)
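A hedged usage sketch of prepare_model_for_training; the namespace stands in for LLaMA-Factory's ModelArguments (only the fields read above are set), the import path assumes the current repo layout, and the checkpoint id is illustrative:

from types import SimpleNamespace

from transformers import AutoModelForCausalLM

from llamafactory.model.model_utils.checkpointing import prepare_model_for_training

model_args = SimpleNamespace(
    upcast_layernorm=True,
    disable_gradient_checkpointing=False,
    upcast_lmhead_output=True,
)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
prepare_model_for_training(model, model_args)
assert model.config.use_cache is False  # disabled once gradient checkpointing is on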
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING

import torch
from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer


logger = get_logger(__name__)


def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int) -> None:
    embedding_dim = embed_weight.size(1)
    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
    noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
    embed_weight[-num_new_tokens:] = avg_weight + noise_weight


def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
    r"""
    Resize token embeddings.
    """
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        params = [model.get_input_embeddings().weight]
        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
            params.append(model.get_output_embeddings().weight)

        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
    else:
        context_maybe_zero3 = nullcontext()

    with context_maybe_zero3:
        current_embedding_size = model.get_input_embeddings().weight.size(0)

    if len(tokenizer) > current_embedding_size:
        if getattr(model, "quantization_method", None):
            raise ValueError("Cannot resize embedding layers of a quantized model.")

        if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
            raise ValueError("Current model does not support resizing embedding layers.")

        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
        with context_maybe_zero3:
            new_embedding_size = model.get_input_embeddings().weight.size(0)
            num_new_tokens = new_embedding_size - current_embedding_size
            _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
            _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

        logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))