Commit 581d366d authored by chenych

Support GLM-4/GLM-4-0414/GLM-Z1

parent 428c5813
......@@ -24,7 +24,6 @@ import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.misc import get_current_device
from ..extras.packages import is_pillow_available
......@@ -65,30 +64,19 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
where `o` equals `0.0` and `x` equals `min_dtype`.
"""
_, seq_len = attention_mask_with_indices.size()
# Move to compute device if the source is CPU.
source_device = attention_mask_with_indices.device
compute_device = get_current_device() if source_device.type == "cpu" else source_device
if compute_device != source_device:
attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
min_dtype = torch.finfo(dtype).min
zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
zero_tensor = torch.tensor(0, dtype=dtype)
# Create a non-padding mask.
non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
# Create indices for comparison.
indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, seq_len]
indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3) # [bsz, 1, seq_len, 1]
# Create a lower triangular mask.
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
# Invert the attention mask.
attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
# Move back to original device if needed.
if compute_device != source_device:
attention_mask_4d = attention_mask_4d.to(source_device)
return attention_mask_4d
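
For reference, a minimal sketch (not part of the diff) of what the simplified construction produces for one packed row; the tensor values are illustrative only.

import torch

mask_with_indices = torch.tensor([[1, 1, 2, 2, 0]])  # two packed samples plus one pad position
dtype = torch.float16
min_dtype = torch.finfo(dtype).min
_, seq_len = mask_with_indices.size()
non_padding_mask = (mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
indices = mask_with_indices.unsqueeze(1).unsqueeze(2)    # [bsz, 1, 1, seq_len]
indices_t = mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
mask_4d = torch.where(mask_4d, torch.tensor(0, dtype=dtype), min_dtype)
# mask_4d[0, 0] is block-diagonal and causal: positions 0-1 attend only to each other,
# positions 2-3 attend only to each other, and the padding column stays at min_dtype.
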
......@@ -196,6 +184,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(
......@@ -309,8 +298,9 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
if "cross_attention_mask" in kl_batch: # for mllama inputs.
if "cross_attention_mask" in kl_batch: # for mllama inputs
batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
if "token_type_ids" in kl_batch:
batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
......
......@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from enum import Enum, unique
from typing import TYPE_CHECKING, Optional, TypedDict, Union
import fsspec
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
from ..extras import logging
......@@ -138,3 +140,50 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModu
dataset_module["train_dataset"] = dataset
return dataset_module
def setup_fs(path, anon=False):
"""Set up a filesystem object based on the path protocol."""
storage_options = {"anon": anon} if anon else {}
if path.startswith("s3://"):
fs = fsspec.filesystem("s3", **storage_options)
elif path.startswith(("gs://", "gcs://")):
fs = fsspec.filesystem("gcs", **storage_options)
else:
raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'")
return fs
def read_cloud_json(cloud_path):
"""Read a JSON/JSONL file from cloud storage (S3 or GCS).
Args:
cloud_path : str
Cloud path in the format:
- 's3://bucket-name/file.json' for AWS S3
- 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
Note:
Files ending in '.jsonl' are read as JSON Lines (one JSON object per line); all other files are read as regular JSON.
"""
try:
# Try with anonymous access first
fs = setup_fs(cloud_path, anon=True)
return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
except Exception:
# Try again with credentials
fs = setup_fs(cloud_path)
return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
def _read_json_with_fs(fs, path, lines=True):
"""Helper function to read JSON/JSONL files using fsspec."""
with fs.open(path, "r") as f:
if lines:
# Read JSONL (JSON Lines) format - one JSON object per line
data = [json.loads(line) for line in f if line.strip()]
else:
# Read regular JSON format
data = json.load(f)
return data
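
A hypothetical usage sketch for the new helpers; the bucket and file names are placeholders, and the import path assumes the helpers stay in this data_utils module.

from llamafactory.data.data_utils import read_cloud_json  # assumed module path

records = read_cloud_json("s3://my-bucket/sft_data.jsonl")  # anonymous access first, then credentials
print(len(records), type(records[0]))  # one dict per JSONL line
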
......@@ -16,13 +16,13 @@ import os
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from datasets import load_dataset, load_from_disk
from datasets import Dataset, load_dataset, load_from_disk
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import check_version, has_tokenized_data
from .converter import align_dataset
from .data_utils import get_dataset_module, merge_dataset, split_dataset
from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset
from .parser import get_dataset_list
from .processor import (
FeedbackDatasetProcessor,
......@@ -67,6 +67,9 @@ def _load_single_dataset(
data_name = dataset_attr.subset
data_dir = dataset_attr.folder
elif dataset_attr.load_from == "cloud_file":
data_path = dataset_attr.dataset_name
elif dataset_attr.load_from == "file":
data_files = []
local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
......@@ -122,6 +125,8 @@ def _load_single_dataset(
token=model_args.om_hub_token,
streaming=data_args.streaming,
)
elif dataset_attr.load_from == "cloud_file":
dataset = Dataset.from_list(read_cloud_json(data_path), split=dataset_attr.split)
else:
dataset = load_dataset(
path=data_path,
......
......@@ -466,6 +466,41 @@ class Gemma3Plugin(BasePlugin):
return mm_inputs
@dataclass
class KimiVLPlugin(BasePlugin):
@override
def process_messages(self, messages, images, videos, audios, processor):
self._validate_input(processor, images, videos, audios)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_hws = mm_inputs.get("image_grid_hws", [])
num_image_tokens = 0
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length = math.prod(image_processor.merge_kernel_size)
messages = deepcopy(messages)
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_hws):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER,
f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
1,
)
num_image_tokens += 1
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
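
A toy calculation (numbers are illustrative only) of how many <|media_pad|> tokens one image expands to under the loop above, assuming a 32x32 patch grid and a 2x2 merge kernel.

import math
import torch

image_grid_hws = [torch.tensor([32, 32])]  # hypothetical grid for a single image
merge_length = math.prod((2, 2))           # hypothetical merge_kernel_size
image_seqlen = int(image_grid_hws[0].prod()) // merge_length
print(image_seqlen)  # 256 image tokens wrapped in <|media_start|>image<|media_content|>...<|media_end|>
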
@dataclass
class Llama4Plugin(BasePlugin):
@override
......@@ -493,8 +528,8 @@ class Llama4Plugin(BasePlugin):
messages = deepcopy(messages)
for message in messages:
content = message["content"]
placeholder_count = content.count(IMAGE_PLACEHOLDER)
if self.expand_mm_tokens:
placeholder_count = content.count(IMAGE_PLACEHOLDER)
prompt_splits = content.split(IMAGE_PLACEHOLDER)
new_content = []
for local_image_index, split_part in enumerate(prompt_splits):
......@@ -507,6 +542,8 @@ class Llama4Plugin(BasePlugin):
new_content.append(tokens_for_this_image)
content = "".join(new_content)
else:
content = content.replace(IMAGE_PLACEHOLDER, self.image_token)
message["content"] = content
......@@ -1376,6 +1413,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
else:
mm_inputs = {}
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
......@@ -1396,16 +1434,16 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if audio_lengths is None:
raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
if not mm_inputs.get("video_grid_thw", None):
if mm_inputs.get("video_grid_thw", None) is None:
raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
positions_list = []
for i, message in enumerate(messages): # get multimodal index when use_audio
for message in messages: # get multimodal index when use_audio
positions = []
for special_token in [self.audio_token, self.image_token, self.video_token]:
start = 0
while True:
pos = message[i].find(special_token, start)
pos = message["content"].find(special_token, start)
if pos == -1:
break
positions.append((pos, special_token))
......@@ -1417,6 +1455,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
content = message["content"]
# separate with audio-video
while IMAGE_PLACEHOLDER in content:
if num_image_tokens >= len(image_grid_thw):
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
content = content.replace(
IMAGE_PLACEHOLDER,
......@@ -1427,6 +1468,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
if not use_audio_in_video:
while AUDIO_PLACEHOLDER in content:
if num_audio_tokens >= len(audio_lengths):
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
audio_token_replace_length = audio_lengths[num_audio_tokens]
content = content.replace(
AUDIO_PLACEHOLDER,
......@@ -1437,6 +1481,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
# TODO handle video_input and use_audio_in_video
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
content = content.replace(
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
......@@ -1445,14 +1492,17 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
else: # using the audio of the video: handle video and audio tokens together
while VIDEO_PLACEHOLDER in content:
if num_video_tokens >= len(video_grid_thw):
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
video_t_index = (
torch.arange(video_grid_thw[num_video_tokens][0])
.view(-1, 1, 1)
.expand(
-1,
video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
)
.flatten()
* mm_inputs["video_second_per_grid"][num_video_tokens]
......@@ -1460,18 +1510,19 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
).long()
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
placeholder_string = ""
placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
if video_chunk_index is not None:
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
if audio_chunk_index is not None:
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
num_audio_tokens += 1
......@@ -1552,6 +1603,7 @@ class VideoLlavaPlugin(BasePlugin):
PLUGINS = {
"base": BasePlugin,
"gemma3": Gemma3Plugin,
"kimi_vl": KimiVLPlugin,
"llama4": Llama4Plugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
......
......@@ -141,6 +141,8 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
elif "cloud_file_name" in dataset_info[name]:
dataset_attr = DatasetAttr("cloud_file", dataset_name=dataset_info[name]["cloud_file_name"])
else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
......
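
A hedged sketch of how a dataset_info entry with the new key is resolved; the entry name and bucket path are placeholders mirroring the elif branch above.

dataset_info = {"my_cloud_data": {"cloud_file_name": "s3://my-bucket/sft_data.jsonl"}}  # placeholder entry
name = "my_cloud_data"
if "cloud_file_name" in dataset_info[name]:
    load_from, dataset_name = "cloud_file", dataset_info[name]["cloud_file_name"]
else:
    load_from, dataset_name = "file", dataset_info[name].get("file_name")
print(load_from, dataset_name)  # cloud_file s3://my-bucket/sft_data.jsonl
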
......@@ -164,28 +164,28 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
packed_images, packed_videos, packed_audios = [], [], []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
packed_position_ids += list(range(len(batch_input_ids[index]))) # NOTE: pad_to_multiple_of ignore this
packed_labels += batch_labels[index]
packed_images += batch_images[index]
packed_videos += batch_videos[index]
packed_audios += batch_audios[index]
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
packed_position_ids += list(range(len(batch_input_ids[index])))
else:
packed_attention_masks += [1] * len(batch_input_ids[index])
if len(packed_input_ids) < self.data_args.cutoff_len + 1: # avoid flash_attn drops attn mask
pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
packed_position_ids += [0] * pad_length
packed_labels += [IGNORE_INDEX] * pad_length
if self.data_args.neat_packing:
packed_attention_masks += [0] * pad_length
packed_position_ids += [0] * pad_length
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
......@@ -194,10 +194,10 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
model_inputs["input_ids"].append(packed_input_ids)
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["position_ids"].append(packed_position_ids)
model_inputs["labels"].append(packed_labels)
model_inputs["images"].append(packed_images or None)
model_inputs["videos"].append(packed_videos or None)
model_inputs["audios"].append(packed_audios or None)
model_inputs["position_ids"].append(packed_position_ids or None)
return model_inputs
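
Illustrative values only: with neat_packing each packed sample gets its own 1-based mask index while position ids restart per sample, matching the loop above.

sample_lengths = [3, 2]  # toy lengths for two samples packed into one row
packed_attention_masks, packed_position_ids = [], []
for i, n in enumerate(sample_lengths):
    packed_attention_masks += [i + 1] * n  # neat_packing -> [1, 1, 1, 2, 2]
    packed_position_ids += list(range(n))  # -> [0, 1, 2, 0, 1]
print(packed_attention_masks, packed_position_ids)
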
......@@ -923,6 +923,20 @@ register_template(
)
register_template(
name="kimi_vl",
format_user=StringFormatter(
slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
default_system="You are a helpful assistant",
stop_words=["<|im_end|>"],
thought_words=("◁think▷", "◁/think▷"),
mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
)
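
Illustrative only: how one system/user turn renders with the slots registered above; the plugin has already expanded any image placeholder into <|media_pad|> tokens.

system_slot = "<|im_system|>system<|im_middle|>{content}<|im_end|>"
user_slot = "<|im_user|>user<|im_middle|>{content}<|im_end|><|im_assistant|>assistant<|im_middle|>"
prompt = system_slot.format(content="You are a helpful assistant") + user_slot.format(
    content="What is in this image? <|media_pad|>"
)
print(prompt)
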
register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
......@@ -1370,7 +1384,7 @@ register_template(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
stop_words=["<|im_end|>"],
)
......
......@@ -14,7 +14,7 @@
import os
from collections import OrderedDict, defaultdict
from enum import Enum
from enum import Enum, unique
from typing import Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
......@@ -115,6 +115,19 @@ class DownloadSource(str, Enum):
OPENMIND = "om"
@unique
class QuantizationMethod(str, Enum):
r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
BNB = "bnb"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
class RopeScaling(str, Enum):
LINEAR = "linear"
DYNAMIC = "dynamic"
......@@ -133,6 +146,7 @@ def register_model_group(
any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal
):
DEFAULT_TEMPLATE[name] = template
if multimodal:
MULTIMODAL_SUPPORTED_MODELS.add(name)
......@@ -711,6 +725,26 @@ register_model_group(
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
},
"GLM-4-9B-0414-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-4-9B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-9B-0414",
},
"GLM-4-32B-0414": {
DownloadSource.DEFAULT: "THUDM/GLM-4-32B-Base-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-Base-0414",
},
"GLM-4-32B-0414-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
},
"GLM-Z1-9B-0414-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
},
"GLM-Z1-32B-0414-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-Z1-32B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
},
},
template="glm4",
)
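
A hedged sketch of loading one of the newly registered checkpoints directly with transformers; the repo id comes from the DEFAULT download source above, and a transformers version within the range bumped later in this commit is assumed.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "THUDM/GLM-4-9B-0414"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
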
......@@ -941,6 +975,22 @@ register_model_group(
)
register_model_group(
models={
"Kimi-VL-A3B-Instruct": {
DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Instruct",
DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Instruct",
},
"Kimi-VL-A3B-Thinking": {
DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Thinking",
DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Thinking",
},
},
template="kimi_vl",
multimodal=True,
)
register_model_group(
models={
"LingoWhale-8B": {
......
......@@ -89,10 +89,10 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.4.1")
check_version("accelerate>=0.34.0,<=1.5.2")
check_version("peft>=0.14.0,<=0.15.0")
check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.5.0")
check_version("accelerate>=0.34.0,<=1.6.0")
check_version("peft>=0.14.0,<=0.15.1")
check_version("trl>=0.8.6,<=0.9.6")
if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
......@@ -177,6 +177,8 @@ def get_peak_memory() -> tuple[int, int]:
r"""Get the peak memory usage for the current device (in Bytes)."""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_xpu_available():
return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
else:
......@@ -200,7 +202,7 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
def is_gpu_or_npu_available() -> bool:
r"""Check if the GPU or NPU is available."""
return is_torch_npu_available() or is_torch_cuda_available()
return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
def is_env_enabled(env_var: str, default: str = "0") -> bool:
......
......@@ -160,5 +160,11 @@ class DataArguments:
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
if self.neat_packing:
self.packing = True
if self.packing:
self.cutoff_len -= 1 # avoid pad_to_multiple_of; needs improvement
def to_dict(self) -> dict[str, Any]:
return asdict(self)
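
A toy walk-through (assumed settings) of the new post-init behavior: neat_packing forces packing on, and packing trims one token off cutoff_len, which the packed processor above pads back to cutoff_len + 1.

cutoff_len, packing, neat_packing = 2048, False, True  # assumed user settings
if neat_packing:
    packing = True
if packing:
    cutoff_len -= 1
print(packing, cutoff_len)  # True 2047; packed rows are padded to cutoff_len + 1 = 2048
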
......@@ -23,7 +23,7 @@ import torch
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
from ..extras.constants import AttentionFunction, EngineName, RopeScaling
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
@dataclass
......@@ -184,8 +184,8 @@ class BaseModelArguments:
class QuantizationArguments:
r"""Arguments pertaining to the quantization method."""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
quantization_method: QuantizationMethod = field(
default=QuantizationMethod.BNB,
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
......
......@@ -135,7 +135,7 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.8.2")
check_version("vllm>=0.4.3,<=0.8.4")
check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.4")
......@@ -285,10 +285,6 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and not data_args.packing:
logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
data_args.packing = True
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
......@@ -394,8 +390,10 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
# Setup logging
_set_transformers_logging()
# Check arguments
if model_args.infer_backend == "vllm":
if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
......@@ -412,6 +410,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
# Post-process model arguments
if model_args.export_dir is not None and model_args.export_device == "cpu":
model_args.device_map = {"": torch.device("cpu")}
if data_args.cutoff_len != DataArguments().cutoff_len: # override cutoff_len if it is not default
......@@ -425,8 +424,10 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
# Setup logging
_set_transformers_logging()
# Check arguments
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
......
......@@ -46,6 +46,10 @@ class RayArguments:
default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
)
ray_init_kwargs: Optional[dict] = field(
default=None,
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
)
def __post_init__(self):
self.use_ray = use_ray()
......
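
A hedged sketch of how such a field would typically be forwarded to Ray; the exact wiring inside the trainer is not part of this diff, and the kwarg values are placeholders.

import ray

ray_init_kwargs = {"num_cpus": 8, "ignore_reinit_error": True}  # placeholder values
if not ray.is_initialized():
    ray.init(**(ray_init_kwargs or {}))
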
......@@ -97,12 +97,13 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, tokenizer, model_args)
except Exception as e:
logger.debug(f"Processor was not found: {e}.")
logger.debug(f"Failed to load processor: {e}.")
processor = None
# Avoid keeping a bare tokenizer as the processor, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if processor is not None and "Processor" not in processor.__class__.__name__:
logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
processor = None
return {"tokenizer": tokenizer, "processor": processor}
......
......@@ -45,7 +45,7 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
if model_type == "paligemma":
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
......
......@@ -54,6 +54,12 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
if model_type in ["kimi_vl", "deepseek_v3"]:
check_version("transformers>=4.51.1")
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
_set_z3_leaf_modules(model, [DeepseekV3MoE])
if model_type == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
......
......@@ -18,7 +18,6 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any
import torch
......@@ -28,7 +27,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ...extras import logging
from ...extras.constants import FILEEXT2TYPE
from ...extras.constants import FILEEXT2TYPE, QuantizationMethod
from ...extras.misc import check_version, get_current_device
......@@ -41,19 +40,6 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
if os.path.isfile(model_args.export_quantization_dataset):
......@@ -145,7 +131,7 @@ def configure_quantization(
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
if model_args.quantization_method == QuantizationMethod.BNB:
if model_args.quantization_bit == 8:
check_version("bitsandbytes>=0.37.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
......@@ -173,7 +159,7 @@ def configure_quantization(
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
elif model_args.quantization_method == QuantizationMethod.HQQ:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
......@@ -185,7 +171,7 @@ def configure_quantization(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
elif model_args.quantization_method == QuantizationMethod.EETQ:
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")
......
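
An aside on why the `.value` suffixes can be dropped above: the relocated QuantizationMethod subclasses str, so its members compare equal to raw strings either way. A minimal demonstration:

from enum import Enum, unique

@unique
class QuantizationMethod(str, Enum):
    BNB = "bnb"
    HQQ = "hqq"

assert QuantizationMethod.BNB == "bnb"  # str-backed enum members compare by value
assert QuantizationMethod("hqq") is QuantizationMethod.HQQ
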
......@@ -79,6 +79,7 @@ def patch_processor(
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
def patch_config(
......@@ -95,7 +96,8 @@ def patch_config(
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if is_torch_npu_available():
torch.npu.set_compile_mode(jit_compile=is_env_enabled("JIT_COMPILE"))
# avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
configure_attn_implementation(config, model_args, is_trainable)
configure_rope(config, model_args, is_trainable)
......@@ -115,6 +117,10 @@ def patch_config(
setattr(config, "init_audio", True)
setattr(config, "init_tts", False)
# replace the top-k gating method
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
setattr(config.text_config, "topk_method", "greedy")
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
......
......@@ -91,7 +91,13 @@ def run_dpo(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
keys = ["loss", "rewards/accuracies"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
# Evaluation
if training_args.do_eval:
......
......@@ -147,6 +147,9 @@ class CustomKTOTrainer(KTOTrainer):
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
if "image_sizes" in batch:
model_inputs["image_sizes"] = batch["image_sizes"]
if "image_grid_thw" in batch:
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
......