Commit 581d366d authored by chenych

Support GLM-4/GLM-4-0414/GLM-Z1

parent 428c5813
@@ -24,7 +24,6 @@ import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
-from ..extras.misc import get_current_device
 from ..extras.packages import is_pillow_available
@@ -65,30 +64,19 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     where `o` equals to `0.0`, `x` equals to `min_dtype`.
     """
     _, seq_len = attention_mask_with_indices.size()
-    # Move to compute device if the source is CPU.
-    source_device = attention_mask_with_indices.device
-    compute_device = get_current_device() if source_device.type == "cpu" else source_device
-    if compute_device != source_device:
-        attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
     min_dtype = torch.finfo(dtype).min
-    zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+    zero_tensor = torch.tensor(0, dtype=dtype)
     # Create a non-padding mask.
-    non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
     # Create indices for comparison.
     indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
     indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
     # Create a lower triangular mask.
-    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
-    attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
+    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
+    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
     # Invert the attention mask.
     attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
-    # Move back to original device if needed.
-    if compute_device != source_device:
-        attention_mask_4d = attention_mask_4d.to(source_device)
     return attention_mask_4d
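For context, a minimal sketch (not part of the commit) of what `prepare_4d_attention_mask` produces after this change; the input packs two sequences (indices 1 and 2) plus one padding position (0):

    import torch

    # shape [1, 5]: sequence index per position, 0 marks padding
    mask_with_indices = torch.tensor([[1, 1, 2, 2, 0]])
    mask_4d = prepare_4d_attention_mask(mask_with_indices, torch.float32)
    # mask_4d has shape [1, 1, 5, 5]: 0.0 where position j may attend to position i
    # (same packed sequence and i <= j), torch.finfo(torch.float32).min everywhere else.
    print(mask_4d[0, 0])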
@@ -196,6 +184,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
             if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+                rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
                 if feature_attention_mask is not None:
                     audio_feature_lengths = torch.sum(
@@ -309,8 +298,9 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
         batch["kl_input_ids"] = kl_batch["input_ids"]
         batch["kl_attention_mask"] = kl_batch["attention_mask"]
         batch["kl_labels"] = kl_batch["labels"]
-        if "cross_attention_mask" in kl_batch:  # for mllama inputs.
+        if "cross_attention_mask" in kl_batch:  # for mllama inputs
             batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
         if "token_type_ids" in kl_batch:
             batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
...
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Optional, TypedDict, Union
+import fsspec
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
 from ..extras import logging
@@ -138,3 +140,50 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModu
         dataset_module["train_dataset"] = dataset
     return dataset_module
+def setup_fs(path, anon=False):
+    """Set up a filesystem object based on the path protocol."""
+    storage_options = {"anon": anon} if anon else {}
+    if path.startswith("s3://"):
+        fs = fsspec.filesystem("s3", **storage_options)
+    elif path.startswith(("gs://", "gcs://")):
+        fs = fsspec.filesystem("gcs", **storage_options)
+    else:
+        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'")
+    return fs
+
+
+def read_cloud_json(cloud_path):
+    """Read a JSON/JSONL file from cloud storage (S3 or GCS).
+
+    Args:
+        cloud_path: Cloud path in the format:
+            - 's3://bucket-name/file.json' for AWS S3
+            - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
+
+    The file is parsed as JSON Lines (one JSON object per line) when the path ends with ".jsonl",
+    otherwise as a regular JSON document.
+    """
+    try:
+        # Try with anonymous access first
+        fs = setup_fs(cloud_path, anon=True)
+        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+    except Exception:
+        # Try again with credentials
+        fs = setup_fs(cloud_path)
+        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+
+
+def _read_json_with_fs(fs, path, lines=True):
+    """Helper function to read JSON/JSONL files using fsspec."""
+    with fs.open(path, "r") as f:
+        if lines:
+            # Read JSONL (JSON Lines) format - one JSON object per line
+            data = [json.loads(line) for line in f if line.strip()]
+        else:
+            # Read regular JSON format
+            data = json.load(f)
+    return data
@@ -16,13 +16,13 @@ import os
 from typing import TYPE_CHECKING, Literal, Optional, Union
 import numpy as np
-from datasets import load_dataset, load_from_disk
+from datasets import Dataset, load_dataset, load_from_disk
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
 from ..extras.misc import check_version, has_tokenized_data
 from .converter import align_dataset
-from .data_utils import get_dataset_module, merge_dataset, split_dataset
+from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset
 from .parser import get_dataset_list
 from .processor import (
     FeedbackDatasetProcessor,
@@ -67,6 +67,9 @@ def _load_single_dataset(
         data_name = dataset_attr.subset
         data_dir = dataset_attr.folder
+    elif dataset_attr.load_from == "cloud_file":
+        data_path = dataset_attr.dataset_name
+
     elif dataset_attr.load_from == "file":
         data_files = []
         local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
@@ -122,6 +125,8 @@ def _load_single_dataset(
             token=model_args.om_hub_token,
             streaming=data_args.streaming,
         )
+    elif dataset_attr.load_from == "cloud_file":
+        dataset = Dataset.from_list(read_cloud_json(data_path), split=dataset_attr.split)
     else:
         dataset = load_dataset(
             path=data_path,
...
@@ -466,6 +466,41 @@ class Gemma3Plugin(BasePlugin):
         return mm_inputs
+
+@dataclass
+class KimiVLPlugin(BasePlugin):
+    @override
+    def process_messages(self, messages, images, videos, audios, processor):
+        self._validate_input(processor, images, videos, audios)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_hws = mm_inputs.get("image_grid_hws", [])
+
+        num_image_tokens = 0
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        merge_length = math.prod(image_processor.merge_kernel_size)
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_grid_hws):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
+                image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
+                    1,
+                )
+                num_image_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
 @dataclass
 class Llama4Plugin(BasePlugin):
     @override
@@ -493,8 +528,8 @@ class Llama4Plugin(BasePlugin):
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
-            placeholder_count = content.count(IMAGE_PLACEHOLDER)
             if self.expand_mm_tokens:
+                placeholder_count = content.count(IMAGE_PLACEHOLDER)
                 prompt_splits = content.split(IMAGE_PLACEHOLDER)
                 new_content = []
                 for local_image_index, split_part in enumerate(prompt_splits):
@@ -507,6 +542,8 @@ class Llama4Plugin(BasePlugin):
                     new_content.append(tokens_for_this_image)
                 content = "".join(new_content)
+            else:
+                content = content.replace(IMAGE_PLACEHOLDER, self.image_token)
             message["content"] = content
@@ -1376,6 +1413,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
         else:
             mm_inputs = {}
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
         num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
         use_audio_in_video = getattr(processor, "use_audio_in_video", False)
@@ -1396,16 +1434,16 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             if audio_lengths is None:
                 raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
-            if not mm_inputs.get("video_grid_thw", None):
+            if mm_inputs.get("video_grid_thw", None) is None:
                 raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
             positions_list = []
-            for i, message in enumerate(messages):  # get multimodal index when use_audio
+            for message in messages:  # get multimodal index when use_audio
                 positions = []
                 for special_token in [self.audio_token, self.image_token, self.video_token]:
                     start = 0
                     while True:
-                        pos = message[i].find(special_token, start)
+                        pos = message["content"].find(special_token, start)
                         if pos == -1:
                             break
                        positions.append((pos, special_token))
@@ -1417,6 +1455,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             content = message["content"]
             # separate with audio-video
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_grid_thw):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
@@ -1427,6 +1468,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             if not use_audio_in_video:
                 while AUDIO_PLACEHOLDER in content:
+                    if num_audio_tokens >= len(audio_lengths):
+                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
                     audio_token_replace_length = audio_lengths[num_audio_tokens]
                     content = content.replace(
                         AUDIO_PLACEHOLDER,
@@ -1437,6 +1481,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                 # TODO handle video_input and use_audio_in_video
                 while VIDEO_PLACEHOLDER in content:
+                    if num_video_tokens >= len(video_grid_thw):
+                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                     video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
                     content = content.replace(
                         VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
@@ -1445,14 +1492,17 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             else:  # if use the audio of video # deal video token and audio token togather
                 while VIDEO_PLACEHOLDER in content:
+                    if num_video_tokens >= len(video_grid_thw):
+                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                     audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
                     video_t_index = (
                         torch.arange(video_grid_thw[num_video_tokens][0])
                         .view(-1, 1, 1)
                         .expand(
                             -1,
-                            video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
-                            video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][1] // image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][2] // image_processor.merge_size,
                         )
                         .flatten()
                         * mm_inputs["video_second_per_grid"][num_video_tokens]
@@ -1460,18 +1510,19 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                     ).long()
                     t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
                     video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
-                    audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
+                    audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                     placeholder_string = ""
+                    placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
                     for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                         video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
                         audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
-                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
                         if video_chunk_index is not None:
                             placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
                         if audio_chunk_index is not None:
                             placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
-                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
+                    placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                     content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                     content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                     num_audio_tokens += 1
@@ -1552,6 +1603,7 @@ class VideoLlavaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "kimi_vl": KimiVLPlugin,
     "llama4": Llama4Plugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
...
@@ -141,6 +141,8 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
             dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
         elif "script_url" in dataset_info[name]:
             dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
+        elif "cloud_file_name" in dataset_info[name]:
+            dataset_attr = DatasetAttr("cloud_file", dataset_name=dataset_info[name]["cloud_file_name"])
         else:
             dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
...
@@ -164,28 +164,28 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
         model_inputs = defaultdict(list)
         knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
         for knapsack in knapsacks:
-            packed_input_ids, packed_attention_masks, packed_labels = [], [], []
-            packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
+            packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
+            packed_images, packed_videos, packed_audios = [], [], []
             for i, length in enumerate(knapsack):
                 index = length2indexes[length].pop()
                 packed_input_ids += batch_input_ids[index]
+                packed_position_ids += list(range(len(batch_input_ids[index])))  # NOTE: pad_to_multiple_of ignores this
                 packed_labels += batch_labels[index]
                 packed_images += batch_images[index]
                 packed_videos += batch_videos[index]
                 packed_audios += batch_audios[index]
                 if self.data_args.neat_packing:
                     packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
-                    packed_position_ids += list(range(len(batch_input_ids[index])))
                 else:
                     packed_attention_masks += [1] * len(batch_input_ids[index])
             if len(packed_input_ids) < self.data_args.cutoff_len + 1:  # avoid flash_attn drops attn mask
                 pad_length = self.data_args.cutoff_len - len(packed_input_ids) + 1
                 packed_input_ids += [self.tokenizer.pad_token_id] * pad_length
+                packed_position_ids += [0] * pad_length
                 packed_labels += [IGNORE_INDEX] * pad_length
                 if self.data_args.neat_packing:
                     packed_attention_masks += [0] * pad_length
-                    packed_position_ids += [0] * pad_length
                 else:
                     packed_attention_masks += [1] * pad_length  # more efficient flash_attn
@@ -194,10 +194,10 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
             model_inputs["input_ids"].append(packed_input_ids)
             model_inputs["attention_mask"].append(packed_attention_masks)
+            model_inputs["position_ids"].append(packed_position_ids)
             model_inputs["labels"].append(packed_labels)
             model_inputs["images"].append(packed_images or None)
             model_inputs["videos"].append(packed_videos or None)
             model_inputs["audios"].append(packed_audios or None)
-            model_inputs["position_ids"].append(packed_position_ids or None)
         return model_inputs
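A worked illustration (not part of the commit) of what the packing code above produces with neat_packing enabled and cutoff_len=6, for two packed sequences of lengths 3 and 2:

    # padded to cutoff_len + 1 = 7 positions
    packed_attention_masks = [1, 1, 1, 2, 2, 0, 0]  # sequence index per position, 0 = padding
    packed_position_ids    = [0, 1, 2, 0, 1, 0, 0]  # restart from 0 for every packed sequence

The collator's `prepare_4d_attention_mask` (shown at the top of this diff) turns such an index mask into a block-diagonal causal mask; the padding to cutoff_len + 1 is presumably also why `DataArguments` now lowers `cutoff_len` by one when packing is enabled (see the hparams change further down).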
@@ -923,6 +923,20 @@ register_template(
 )
+
+register_template(
+    name="kimi_vl",
+    format_user=StringFormatter(
+        slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
+    format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
+    default_system="You are a helpful assistant",
+    stop_words=["<|im_end|>"],
+    thought_words=("◁think▷", "◁/think▷"),
+    mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
+)
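With these slots, a single-turn prompt renders roughly as follows (an illustration derived from the formatters above, not an official example):

    <|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|><|im_user|>user<|im_middle|>Hi<|im_end|><|im_assistant|>assistant<|im_middle|>

and the assistant completion is appended as `{reply}<|im_end|>`; image placeholders are expanded by the KimiVLPlugin into `<|media_start|>image<|media_content|>...<|media_end|>` spans filled with `<|media_pad|>` tokens.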
 register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -1370,7 +1384,7 @@ register_template(
         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
-    default_system="You are a helpful assistant.",
+    default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
     stop_words=["<|im_end|>"],
 )
...
@@ -14,7 +14,7 @@
 import os
 from collections import OrderedDict, defaultdict
-from enum import Enum
+from enum import Enum, unique
 from typing import Optional
 from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
@@ -115,6 +115,19 @@ class DownloadSource(str, Enum):
     OPENMIND = "om"
+
+@unique
+class QuantizationMethod(str, Enum):
+    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
+
+    BNB = "bnb"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    AQLM = "aqlm"
+    QUANTO = "quanto"
+    EETQ = "eetq"
+    HQQ = "hqq"
+
 class RopeScaling(str, Enum):
     LINEAR = "linear"
     DYNAMIC = "dynamic"
@@ -133,6 +146,7 @@ def register_model_group(
             any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal
         ):
             DEFAULT_TEMPLATE[name] = template
+
         if multimodal:
             MULTIMODAL_SUPPORTED_MODELS.add(name)
@@ -711,6 +725,26 @@ register_model_group(
             DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
             DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
         },
+        "GLM-4-9B-0414-Chat": {
+            DownloadSource.DEFAULT: "THUDM/GLM-4-9B-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-9B-0414",
+        },
+        "GLM-4-32B-0414": {
+            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-Base-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-Base-0414",
+        },
+        "GLM-4-32B-0414-Chat": {
+            DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
+        },
+        "GLM-Z1-9B-0414-Chat": {
+            DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
+        },
+        "GLM-Z1-32B-0414-Chat": {
+            DownloadSource.DEFAULT: "THUDM/GLM-Z1-32B-0414",
+            DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
+        },
     },
     template="glm4",
 )
@@ -941,6 +975,22 @@ register_model_group(
 )
+
+register_model_group(
+    models={
+        "Kimi-VL-A3B-Instruct": {
+            DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Instruct",
+            DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Instruct",
+        },
+        "Kimi-VL-A3B-Thinking": {
+            DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Thinking",
+            DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Thinking",
+        },
+    },
+    template="kimi_vl",
+    multimodal=True,
+)
+
 register_model_group(
     models={
         "LingoWhale-8B": {
...
@@ -89,10 +89,10 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.4.1")
-    check_version("accelerate>=0.34.0,<=1.5.2")
-    check_version("peft>=0.14.0,<=0.15.0")
+    check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("datasets>=2.16.0,<=3.5.0")
+    check_version("accelerate>=0.34.0,<=1.6.0")
+    check_version("peft>=0.14.0,<=0.15.1")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
@@ -177,6 +177,8 @@ def get_peak_memory() -> tuple[int, int]:
     r"""Get the peak memory usage for the current device (in Bytes)."""
     if is_torch_npu_available():
         return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
+    elif is_torch_xpu_available():
+        return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
@@ -200,7 +202,7 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
 def is_gpu_or_npu_available() -> bool:
     r"""Check if the GPU or NPU is available."""
-    return is_torch_npu_available() or is_torch_cuda_available()
+    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
 def is_env_enabled(env_var: str, default: str = "0") -> bool:
...
@@ -160,5 +160,11 @@ class DataArguments:
         if self.mask_history and self.train_on_prompt:
             raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
+        if self.neat_packing:
+            self.packing = True
+
+        if self.packing:
+            self.cutoff_len -= 1  # avoid pad_to_multiple_of, needs improvement
     def to_dict(self) -> dict[str, Any]:
         return asdict(self)
@@ -23,7 +23,7 @@ import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
-from ..extras.constants import AttentionFunction, EngineName, RopeScaling
+from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 @dataclass
@@ -184,8 +184,8 @@ class BaseModelArguments:
 class QuantizationArguments:
     r"""Arguments pertaining to the quantization method."""
-    quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
-        default="bitsandbytes",
+    quantization_method: QuantizationMethod = field(
+        default=QuantizationMethod.BNB,
         metadata={"help": "Quantization method to use for on-the-fly quantization."},
     )
     quantization_bit: Optional[int] = field(
...
@@ -135,7 +135,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.2")
+        check_version("vllm>=0.4.3,<=0.8.4")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.4")
@@ -285,10 +285,6 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
-    if data_args.neat_packing and not data_args.packing:
-        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
-        data_args.packing = True
-
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
@@ -394,8 +390,10 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
 def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
+    # Setup logging
     _set_transformers_logging()
+    # Check arguments
     if model_args.infer_backend == "vllm":
         if finetuning_args.stage != "sft":
             raise ValueError("vLLM engine only supports auto-regressive models.")
@@ -412,6 +410,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
+    # Post-process model arguments
     if model_args.export_dir is not None and model_args.export_device == "cpu":
         model_args.device_map = {"": torch.device("cpu")}
     if data_args.cutoff_len != DataArguments().cutoff_len:  # override cutoff_len if it is not default
@@ -425,8 +424,10 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
 def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
+    # Setup logging
     _set_transformers_logging()
+    # Check arguments
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
...
@@ -46,6 +46,10 @@ class RayArguments:
         default="PACK",
         metadata={"help": "The placement strategy for Ray training. Default is PACK."},
     )
+    ray_init_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
+    )
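A hedged illustration of the new field (not from the repository docs); the values below are arbitrary, and the dict is forwarded to `ray.init` according to the help text:

    # e.g. what a user-supplied value could look like
    ray_init_kwargs = {"num_cpus": 16, "runtime_env": {"env_vars": {"TOKENIZERS_PARALLELISM": "false"}}}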
     def __post_init__(self):
         self.use_ray = use_ray()
...
@@ -97,12 +97,13 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, tokenizer, model_args)
     except Exception as e:
-        logger.debug(f"Processor was not found: {e}.")
+        logger.debug(f"Failed to load processor: {e}.")
         processor = None
     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
     if processor is not None and "Processor" not in processor.__class__.__name__:
+        logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
         processor = None
     return {"tokenizer": tokenizer, "processor": processor}
...
@@ -45,7 +45,7 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
     elif model_type == "gemma3_text":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    if model_type == "paligemma":
+    elif model_type == "paligemma":
         from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
     elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
...
@@ -54,6 +54,12 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
+    if model_type in ["kimi_vl", "deepseek_v3"]:
+        check_version("transformers>=4.51.1")
+        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+
+        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+
     if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
...
@@ -18,7 +18,6 @@
 import os
 import random
-from enum import Enum, unique
 from typing import TYPE_CHECKING, Any
 import torch
@@ -28,7 +27,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 from ...extras import logging
-from ...extras.constants import FILEEXT2TYPE
+from ...extras.constants import FILEEXT2TYPE, QuantizationMethod
 from ...extras.misc import check_version, get_current_device
@@ -41,19 +40,6 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
-
-@unique
-class QuantizationMethod(str, Enum):
-    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
-
-    BITS_AND_BYTES = "bitsandbytes"
-    GPTQ = "gptq"
-    AWQ = "awq"
-    AQLM = "aqlm"
-    QUANTO = "quanto"
-    EETQ = "eetq"
-    HQQ = "hqq"
-
 def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
     r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
     if os.path.isfile(model_args.export_quantization_dataset):
@@ -145,7 +131,7 @@ def configure_quantization(
         logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
     elif model_args.quantization_bit is not None:  # on-the-fly
-        if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
+        if model_args.quantization_method == QuantizationMethod.BNB:
            if model_args.quantization_bit == 8:
                check_version("bitsandbytes>=0.37.0", mandatory=True)
                init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
@@ -173,7 +159,7 @@ def configure_quantization(
             init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
-        elif model_args.quantization_method == QuantizationMethod.HQQ.value:
+        elif model_args.quantization_method == QuantizationMethod.HQQ:
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@@ -185,7 +171,7 @@ def configure_quantization(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
-        elif model_args.quantization_method == QuantizationMethod.EETQ.value:
+        elif model_args.quantization_method == QuantizationMethod.EETQ:
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
...
@@ -79,6 +79,7 @@ def patch_processor(
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
     setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
+    setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
 def patch_config(
@@ -95,7 +96,8 @@ def patch_config(
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
     if is_torch_npu_available():
-        torch.npu.set_compile_mode(jit_compile=is_env_enabled("JIT_COMPILE"))
+        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
@@ -115,6 +117,10 @@ def patch_config(
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)
+    # replace the top-k gating method
+    if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
+        setattr(config.text_config, "topk_method", "greedy")
+
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
...
@@ -91,7 +91,13 @@ def run_dpo(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/accuracies"])
+            keys = ["loss", "rewards/accuracies"]
+            if isinstance(dataset_module.get("eval_dataset"), dict):
+                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+            else:
+                keys += ["eval_loss"]
+
+            plot_loss(training_args.output_dir, keys=keys)
     # Evaluation
     if training_args.do_eval:
...
@@ -147,6 +147,9 @@ class CustomKTOTrainer(KTOTrainer):
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]
+        if "image_sizes" in batch:
+            model_inputs["image_sizes"] = batch["image_sizes"]
+
         if "image_grid_thw" in batch:
             model_inputs["image_grid_thw"] = batch["image_grid_thw"]
...