Commit 27a7ad86 authored by luopl

update to v0.9.1

parent 731cf9b8
@@ -28,17 +28,22 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
     r"""
    Finds all available modules to apply lora or galore.
     """
+    model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
-    if model.config.model_type == "chatglm":
+    if model_type == "chatglm":
         forbidden_modules.add("output_layer")
-    elif model.config.model_type == "internlm2":
+    elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model.config.model_type in ["llava", "paligemma"]:
+    elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
         forbidden_modules.add("multi_modal_projector")
+    elif model_type == "qwen2_vl":
+        forbidden_modules.add("merger")

     if freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
+        if model_type == "qwen2_vl":
+            forbidden_modules.add("visual")
+        else:
+            forbidden_modules.add("vision_tower")

     module_names = set()
     for name, module in model.named_modules():
...
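The loop body that consumes `forbidden_modules` is collapsed in the hunk above. For readers outside the repo, a minimal sketch of the usual pattern; the matching logic below is an illustration, not the committed code:

```python
import torch

def collect_lora_targets(model, forbidden_modules):
    # Illustrative sketch: keep the last path segment of every Linear
    # layer whose full name avoids the forbidden keywords.
    module_names = set()
    for name, module in model.named_modules():
        if any(keyword in name for keyword in forbidden_modules):
            continue  # e.g. lm_head, vision_tower, visual, merger
        if isinstance(module, torch.nn.Linear):
            module_names.add(name.split(".")[-1])  # e.g. "q_proj"
    return list(module_names)
```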
@@ -39,42 +39,44 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     if not is_deepspeed_zero3_enabled():
         return

-    if getattr(model.config, "model_type", None) == "dbrx":
+    model_type = getattr(model.config, "model_type", None)
+    if model_type == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN

         _set_z3_leaf_modules(model, [DbrxFFN])

-    if getattr(model.config, "model_type", None) == "jamba":
+    if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

         _set_z3_leaf_modules(model, [JambaSparseMoeBlock])

-    if getattr(model.config, "model_type", None) == "jetmoe":
+    if model_type == "jetmoe":
         from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE

         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])

-    if getattr(model.config, "model_type", None) == "mixtral":
+    if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

         _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

-    if getattr(model.config, "model_type", None) == "qwen2moe":
+    if model_type == "qwen2moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])


 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    model_type = getattr(config, "model_type", None)
     if model_args.moe_aux_loss_coef is not None:
-        if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
+        if model_type in ["jamba", "mixtral", "qwen2_moe"]:
             setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)

-        elif getattr(config, "model_type", None) == "deepseek":
+        elif model_type == "deepseek":
             setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)

-        elif getattr(config, "model_type", None) == "jetmoe":
+        elif model_type == "jetmoe":
             setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)

-    if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
+    if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
         setattr(config, "output_router_logits", is_trainable)
@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:

 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+    require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
     if is_transformers_version_greater_than_4_43():
         import transformers.modeling_flash_attention_utils
...
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union

 import torch
 import transformers.models
@@ -28,7 +28,7 @@ from ...extras.logging import get_logger
 if TYPE_CHECKING:
     from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

-    from ...hparams import ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)
...@@ -80,24 +80,120 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL): ...@@ -80,24 +80,120 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
self.act = ACT2FN[projector_hidden_act] self.act = ACT2FN[projector_hidden_act]
def autocast_projector_dtype( def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector" r"""
) -> None: Casts projector output to half precision for fine-tuning quantized VLMs.
"""
def _mm_projector_forward_post_hook( def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor": ) -> "torch.Tensor":
return output.to(model_args.compute_dtype) return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None): if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None)
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif model_type == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
else:
return
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook) mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
def configure_visual_model(config: "PretrainedConfig") -> None: def configure_visual_model(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models r"""
Patches VLMs before loading them.
"""
model_type = getattr(config, "model_type", None)
if model_type in [
"llava",
"llava_next",
"llava_next_video",
"paligemma",
"video_llava",
]: # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None): if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.") logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
r"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
elif model_type == "qwen2_vl":
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("visual")
if finetuning_args.train_mm_proj_only:
raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
return forbidden_modules
def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""
Computes the number of special tokens per image.
"""
model_type = getattr(config, "model_type", None)
if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token
image_seqlen += 1
elif model_type == "paligemma":
image_seqlen = config.vision_config.num_image_tokens
else:
image_seqlen = -1
return image_seqlen
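Worked example of the llava branch of `get_image_seqlen`: with a typical CLIP ViT-L/14 tower at 336px (illustrative numbers, not read from this commit), an image occupies (336 // 14) ** 2 = 576 token positions, plus one for [CLS] under the "full" strategy:

```python
image_size, patch_size = 336, 14  # illustrative CLIP ViT-L/14-336 values
vision_feature_select_strategy = "full"

image_seqlen = (image_size // patch_size) ** 2  # 24 * 24 = 576 patches
if vision_feature_select_strategy == "full":  # "full" keeps the [CLS] token
    image_seqlen += 1

print(image_seqlen)  # 577
```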
+
+
+def get_patch_size(config: "PretrainedConfig") -> int:
+    r"""
+    Computes the patch size of the vit.
+    """
+    patch_size = getattr(config.vision_config, "patch_size", -1)
+    return patch_size
+
+
+def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
+    r"""
+    Get the vision_feature_select_strategy.
+    """
+    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
+    return vision_feature_select_strategy
+
+
+def patch_target_modules(
+    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+) -> Union[str, List[str]]:
+    r"""
+    Freezes vision tower for VLM LoRA tuning.
+    """
+    model_type = getattr(config, "model_type", None)
+    if finetuning_args.freeze_vision_tower:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+            return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "qwen2_vl":
+            return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+        else:
+            return target_modules
+    else:
+        if model_type == "qwen2_vl":
+            return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
+        else:
+            return target_modules
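The returned strings are regex patterns for PEFT's `target_modules`: the negative lookahead drops every module under the frozen tower while still matching the requested projections. A quick sanity check with hypothetical module names:

```python
import re

target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))

assert re.match(pattern, "model.layers.0.self_attn.q_proj")        # trainable
assert re.match(pattern, "visual.blocks.3.attn.q_proj") is None    # frozen tower
```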
@@ -33,11 +33,17 @@ from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
-from .model_utils.visual import autocast_projector_dtype, configure_visual_model
+from .model_utils.visual import (
+    autocast_projector_dtype,
+    configure_visual_model,
+    get_image_seqlen,
+    get_patch_size,
+    get_vision_feature_select_strategy,
+)


 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
     from trl import AutoModelForCausalLMWithValueHead

     from ..hparams import ModelArguments
@@ -51,6 +57,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)


+def patch_processor(
+    processor: "ProcessorMixin",
+    config: "PretrainedConfig",
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+) -> None:
+    setattr(processor, "tokenizer", tokenizer)
+    setattr(processor, "image_seqlen", get_image_seqlen(config))
+    setattr(processor, "image_resolution", model_args.image_resolution)
+    setattr(processor, "patch_size", get_patch_size(config))
+    setattr(processor, "video_resolution", model_args.video_resolution)
+    setattr(processor, "video_fps", model_args.video_fps)
+    setattr(processor, "video_maxlen", model_args.video_maxlen)
+    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
+
+
 def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
@@ -88,6 +110,9 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

+    if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
+        raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
+
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
@@ -129,11 +154,9 @@ def patch_model(
     if model_args.resize_vocab:
         resize_embedding_layer(model, tokenizer)

-    if model_args.visual_inputs:
-        autocast_projector_dtype(model, model_args)
-
     if is_trainable:
         prepare_model_for_training(model, model_args)
+        autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)

     if not model_args.use_unsloth:
...
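For reference, `autocast_projector_dtype` (now invoked only on the trainable path, per the `patch_model` hunk) relies on a standard PyTorch forward post-hook that replaces the module's output. A self-contained sketch of the same mechanism on a toy module:

```python
import torch

projector = torch.nn.Linear(4, 4).float()

def cast_output_hook(module, args, output):
    # Mirrors _mm_projector_forward_post_hook: recast the projector's
    # output to the training compute dtype (bfloat16 here for illustration).
    return output.to(torch.bfloat16)

projector.register_forward_hook(cast_output_hook)
print(projector(torch.randn(2, 4)).dtype)  # torch.bfloat16
```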
@@ -32,9 +32,11 @@ from transformers.utils import (
     WEIGHTS_NAME,
     is_safetensors_available,
 )
+from typing_extensions import override

 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import LoggerHandler, get_logger
+from ..extras.misc import get_peak_memory


 if is_safetensors_available():
@@ -73,8 +75,8 @@ def fix_valuehead_checkpoint(
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
         state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

-    decoder_state_dict = {}
-    v_head_state_dict = {}
+    os.remove(path_to_checkpoint)
+    decoder_state_dict, v_head_state_dict = {}, {}
     for name, param in state_dict.items():
         if name.startswith("v_head."):
             v_head_state_dict[name] = param
@@ -90,43 +92,52 @@ def fix_valuehead_checkpoint(
     else:
         torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

-    os.remove(path_to_checkpoint)
     logger.info("Value head model saved at: {}".format(output_dir))


 class FixValueHeadModelCallback(TrainerCallback):
+    r"""
+    A callback for fixing the checkpoint for valuehead models.
+    """
+
+    @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a checkpoint save.
         """
         if args.should_save:
+            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
             fix_valuehead_checkpoint(
-                model=kwargs.pop("model"),
-                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
-                safe_serialization=args.save_safetensors,
+                model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
             )


 class SaveProcessorCallback(TrainerCallback):
+    r"""
+    A callback for saving the processor.
+    """
+
     def __init__(self, processor: "ProcessorMixin") -> None:
-        r"""
-        Initializes a callback for saving the processor.
-        """
         self.processor = processor

+    @override
+    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        if args.should_save:
+            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            getattr(self.processor, "image_processor").save_pretrained(output_dir)
+
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
         if args.should_save:
             getattr(self.processor, "image_processor").save_pretrained(args.output_dir)


 class PissaConvertCallback(TrainerCallback):
     r"""
-    Initializes a callback for converting the PiSSA adapter to a normal one.
+    A callback for converting the PiSSA adapter to a normal one.
     """

+    @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
@@ -141,10 +152,8 @@ class PissaConvertCallback(TrainerCallback):
             model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
@@ -172,21 +181,22 @@ class PissaConvertCallback(TrainerCallback):


 class LogCallback(TrainerCallback):
+    r"""
+    A callback for logging training and evaluation status.
+    """
+
     def __init__(self) -> None:
-        r"""
-        Initializes a callback for logging training and evaluation status.
-        """
-        """ Progress """
+        # Progress
         self.start_time = 0
         self.cur_steps = 0
         self.max_steps = 0
         self.elapsed_time = ""
         self.remaining_time = ""
         self.thread_pool: Optional["ThreadPoolExecutor"] = None
-        """ Status """
+        # Status
         self.aborted = False
         self.do_train = False
-        """ Web UI """
+        # Web UI
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
         if self.webui_mode:
             signal.signal(signal.SIGABRT, self._set_abort)
@@ -226,10 +236,8 @@ class LogCallback(TrainerCallback):
             self.thread_pool.shutdown(wait=True)
             self.thread_pool = None

+    @override
     def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of the initialization of the `Trainer`.
-        """
         if (
             args.should_save
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
@@ -238,55 +246,41 @@ class LogCallback(TrainerCallback):
             logger.warning("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))

+    @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the beginning of training.
-        """
         if args.should_save:
             self.do_train = True
             self._reset(max_steps=state.max_steps)
             self._create_thread_pool(output_dir=args.output_dir)

+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
         self._close_thread_pool()

+    @override
     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of an substep during gradient accumulation.
-        """
         if self.aborted:
             control.should_epoch_stop = True
             control.should_training_stop = True

+    @override
     def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of a training step.
-        """
         if self.aborted:
             control.should_epoch_stop = True
             control.should_training_stop = True

+    @override
     def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after an evaluation phase.
-        """
         if not self.do_train:
             self._close_thread_pool()

+    @override
     def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after a successful prediction.
-        """
         if not self.do_train:
             self._close_thread_pool()

+    @override
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called after logging the last logs.
-        """
         if not args.should_save:
             return
@@ -304,26 +298,31 @@ class LogCallback(TrainerCallback):
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
             elapsed_time=self.elapsed_time,
             remaining_time=self.remaining_time,
-            throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
-            total_tokens=state.num_input_tokens_seen,
         )
+        if state.num_input_tokens_seen:
+            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
+            logs["total_tokens"] = state.num_input_tokens_seen
+
+        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
+            vram_allocated, vram_reserved = get_peak_memory()
+            logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
+            logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
+
         logs = {k: v for k, v in logs.items() if v is not None}
         if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
             logger.info(
                 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
-                    logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
+                    logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
                 )
             )

         if self.thread_pool is not None:
             self.thread_pool.submit(self._write_log, args.output_dir, logs)

+    @override
     def on_prediction_step(
         self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
     ):
-        r"""
-        Event called after a prediction step.
-        """
         if self.do_train:
             return
...
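The new `RECORD_VRAM` fields log peak usage in GiB. Assuming `get_peak_memory` returns (allocated, reserved) byte counts (it lives in `extras.misc`, which is outside this diff), the conversion is a straight bytes-to-GiB division:

```python
# Hypothetical peak readings in bytes; real values come from get_peak_memory().
vram_allocated, vram_reserved = 17_179_869_184, 21_474_836_480

logs = {}
logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)  # 16.0
logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)    # 20.0
```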
@@ -26,6 +26,7 @@ import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):

         return losses, chosen_rewards, rejected_rewards

+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -176,7 +180,6 @@ class CustomDPOTrainer(DPOTrainer):
         batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error

         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
-
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
@@ -187,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_length, _ = valid_length.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length

+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
@@ -208,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):

         return reference_chosen_logps, reference_rejected_logps

+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
...
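The `@override` decorator added throughout this commit comes from `typing_extensions` (PEP 698). It is a no-op at runtime; its value is that a type checker errors out if the decorated method no longer matches a base-class method, e.g. after an upstream trl rename. Minimal illustration:

```python
from typing_extensions import override

class Base:
    def create_optimizer(self) -> None: ...

class Custom(Base):
    @override  # a type checker flags this if Base.create_optimizer is renamed
    def create_optimizer(self) -> None: ...
```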
@@ -17,7 +17,7 @@

 from typing import TYPE_CHECKING, List, Optional

-from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
@@ -41,13 +41,15 @@ def run_dpo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = PairwiseDataCollatorWithPadding(
-        tokenizer=tokenizer,
+        template=template,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+        **tokenizer_module,
     )

     # Create reference model
@@ -60,7 +62,7 @@ def run_dpo(
         ref_model = None

     # Update arguments
-    training_args.remove_unused_columns = False  # important for pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

     # Initialize our Trainer
     trainer = CustomDPOTrainer(
...
@@ -25,6 +25,7 @@ import torch
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import SaveProcessorCallback
@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
     def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
         r"""
         Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
         """
         return Trainer._get_train_sampler(self)

+    @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
     ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -127,17 +132,20 @@ class CustomKTOTrainer(KTOTrainer):
             "input_ids": batch["{}input_ids".format(prefix)],
             "attention_mask": batch["{}attention_mask".format(prefix)],
         }
+        if "{}token_type_ids".format(prefix) in batch:
+            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
+
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]

-        if "{}token_type_ids".format(prefix) in batch:
-            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
+        if "image_grid_thw" in batch:
+            model_inputs["image_grid_thw"] = batch["image_grid_thw"]

         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)

         logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
         return logps, logps / valid_length

+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -153,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
         chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
         return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg

+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -173,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):

         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps

+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
...
@@ -17,7 +17,7 @@

 from typing import TYPE_CHECKING, List, Optional

-from ...data import KTODataCollatorWithPadding, get_dataset
+from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
@@ -41,13 +41,15 @@ def run_kto(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="kto", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = KTODataCollatorWithPadding(
-        tokenizer=tokenizer,
+        template=template,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+        **tokenizer_module,
     )

     # Create reference model
@@ -57,7 +59,7 @@ def run_kto(
         ref_model = create_ref_model(model_args, finetuning_args)

     # Update arguments
-    training_args.remove_unused_columns = False  # important for pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

     # Initialize our Trainer
     trainer = CustomKTOTrainer(
...
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead


-def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
     r"""
     Gets reward scores from the API server.
     """
@@ -66,7 +66,7 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
         v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)


-def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
+def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
     r"""
     Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
     """
@@ -79,7 +79,7 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
     return layer_norm_params


-def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
     r"""
     Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
     """
...
@@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 from trl.models.utils import unwrap_model_for_generation
+from typing_extensions import override

 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
@@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         self.callback_handler.on_train_end(self.args, self.state, self.control)

+    @override
     def create_optimizer(
         self,
         model: "AutoModelForCausalLMWithValueHead",
@@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         return optimizer

+    @override
     def create_scheduler(
         self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -389,7 +392,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         """
         if self.finetuning_args.reward_model_type == "api":
             token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
-            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
             return get_rewards_from_server(self.reward_model, messages)

         batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
@@ -402,7 +405,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             reward_model = self.reward_model

         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
-            _, _, values = reward_model(**batch, return_dict=True, use_cache=False)
+            values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]

         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
@@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return rewards.float().detach()  # use fp32 type

+    @override
     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,
@@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             torch.cat(all_masks)[:, :-1],
         )

+    @override
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
         Saves model checkpoint.
...
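The reward computation now indexes the model output with `[-1]` instead of unpacking exactly three values. trl's value-head models return a tuple whose last element is the per-token value tensor, so this stays robust if the tuple's arity changes. A toy reproduction of the gather step (stand-in forward, illustrative shapes):

```python
import torch

# Toy stand-in: trl's AutoModelForCausalLMWithValueHead returns
# (lm_logits, loss, values); [-1] picks the value tensor.
def fake_valuehead_forward(batch_size=2, seq_len=5, vocab_size=10):
    return torch.randn(batch_size, seq_len, vocab_size), None, torch.randn(batch_size, seq_len)

values = fake_valuehead_forward()[-1]
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

# reward = value at the last non-padded position, as in the hunk above
rewards = values.gather(dim=-1, index=attention_mask.sum(dim=-1, keepdim=True) - 1)
print(rewards.shape)  # torch.Size([2, 1])
```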
@@ -17,9 +17,7 @@

 from typing import TYPE_CHECKING, List, Optional

-from transformers import DataCollatorWithPadding
-
-from ...data import get_dataset
+from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint
@@ -43,11 +41,12 @@ def run_ppo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
-    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+    data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)

     # Create reference model and reward model
     ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
...
@@ -16,6 +16,7 @@ from types import MethodType
 from typing import TYPE_CHECKING, Optional

 from transformers import Trainer
+from typing_extensions import override

 from ...extras.logging import get_logger
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
...
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional

 from transformers import DataCollatorForLanguageModeling

-from ...data import get_dataset
+from ...data import get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
@@ -42,7 +42,8 @@ def run_pt(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
...
@@ -26,6 +26,10 @@ if TYPE_CHECKING:

 @dataclass
 class ComputeAccuracy:
+    r"""
+    Computes reward accuracy and supports `batch_eval_metrics`.
+    """
+
     def _dump(self) -> Optional[Dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
...
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import Trainer
+from typing_extensions import override

 from ...extras.logging import get_logger
 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
@@ -63,20 +64,23 @@ class PairwiseTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
...
@@ -17,7 +17,7 @@

 from typing import TYPE_CHECKING, List, Optional

-from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint
@@ -41,12 +41,13 @@ def run_rm(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
-    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+    data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)

     # Update arguments
-    training_args.remove_unused_columns = False  # important for pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

     # Initialize our Trainer
     trainer = PairwiseTrainer(
...
@@ -45,6 +45,9 @@ if is_rouge_available():


 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+    r"""
+    Computes the token with the largest likelihood to reduce memory footprint.
+    """
     if isinstance(logits, (list, tuple)):
         if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
             logits = logits[0]
@@ -59,6 +62,10 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":

 @dataclass
 class ComputeAccuracy:
+    r"""
+    Computes accuracy and supports `batch_eval_metrics`.
+    """
+
     def _dump(self) -> Optional[Dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
@@ -84,6 +91,8 @@ class ComputeAccuracy:
 @dataclass
 class ComputeSimilarity:
     r"""
+    Computes text similarity scores and supports `batch_eval_metrics`.
+
     Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
     """
...
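The new docstring on `eval_logit_processor` is worth unpacking: keeping full `(batch, seq_len, vocab)` logits across an eval epoch is what blows up memory, so the processor reduces them to predicted token ids on the fly. The core reduction (a sketch with illustrative shapes, not the full function) is just an argmax:

```python
import torch

logits = torch.randn(2, 4, 32000)         # (batch_size, seq_len, vocab_size)
token_ids = torch.argmax(logits, dim=-1)  # (batch_size, seq_len), vocab_size times smaller
```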
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from transformers import Seq2SeqTrainer
+from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
@@ -64,32 +65,36 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
     def prediction_step(
         self,
         model: "torch.nn.Module",
-        inputs: Dict[str, Union[torch.Tensor, Any]],
+        inputs: Dict[str, Union["torch.Tensor", Any]],
         prediction_loss_only: bool,
         ignore_keys: Optional[List[str]] = None,
-    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Removes the prompt part in the generated tokens.

         Subclass and override to inject custom behavior.
         """
-        labels = inputs["labels"].detach().clone() if "labels" in inputs else None  # backup labels
+        labels = inputs["labels"] if "labels" in inputs else None
         if self.args.predict_with_generate:
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            labels = labels.detach().clone() if labels is not None else None  # backup labels
             prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
             if prompt_len > label_len:
                 inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
@@ -105,7 +110,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

         return loss, generated_tokens, labels

-    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+    def _pad_tensors_to_target_len(self, src_tensor: "torch.Tensor", tgt_tensor: "torch.Tensor") -> "torch.Tensor":
         r"""
         Pads the tensor to the same length as the target tensor.
         """
...
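Context for the `prediction_step` change: labels are now cloned only on the generate path, right where `_pad_tensors_to_target_len` left-pads them to match the (left-padded) prompt length. A toy version of that padding step (shapes, values, and the IGNORE_INDEX fill are illustrative, in the spirit of the method above rather than its exact body):

```python
import torch

IGNORE_INDEX = -100
labels = torch.tensor([[5, 6, 7]])               # (1, label_len)
input_ids = torch.zeros(1, 5, dtype=torch.long)  # (1, prompt_len), prompt_len > label_len

# left-pad labels with IGNORE_INDEX up to the prompt length
padded = torch.full(input_ids.shape, IGNORE_INDEX, dtype=torch.long)
padded[:, -labels.size(-1):] = labels
print(padded)  # tensor([[-100, -100,    5,    6,    7]])
```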