Commit 0722acf1 authored by chenych

Update 0604

parent c4ba4563
......@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
self.generating_args = generating_args.to_dict()
if model_args.adapter_name_or_path is not None:
self.lora_request = True
else:
self.lora_request = False
launch_cmd = [
"python3 -m sglang.launch_server",
......@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
f"--download-dir {model_args.cache_dir}",
"--log-level error",
]
if self.lora_request:
launch_cmd.extend(
[
"--max-loras-per-batch 1",
f"--lora-backend {model_args.sglang_lora_backend}",
f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
"--disable-radix-cache",
]
)
launch_cmd = " ".join(launch_cmd)
logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
try:
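For reference, a minimal standalone sketch of how the new LoRA flags compose into the launch command; the adapter path and backend value below are made up, not taken from the commit:

```python
# Minimal sketch, assuming a hypothetical adapter path and the default backend.
adapter_name_or_path = ["saves/llama3-8b/lora/sft"]  # stand-in for model_args.adapter_name_or_path
sglang_lora_backend = "triton"                       # stand-in for model_args.sglang_lora_backend

launch_cmd = ["python3 -m sglang.launch_server", "--log-level error"]
lora_request = adapter_name_or_path is not None
if lora_request:
    launch_cmd.extend(
        [
            "--max-loras-per-batch 1",
            f"--lora-backend {sglang_lora_backend}",
            f"--lora-paths lora0={adapter_name_or_path[0]}",
            "--disable-radix-cache",  # prefix caching is disabled when serving LoRA
        ]
    )
print(" ".join(launch_cmd))
```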
......@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
......@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
"sampling_params": sampling_params,
"stream": True,
}
if self.lora_request:
json_data["lora_request"] = ["lora0"]
response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
if response.status_code != 200:
raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
......
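A hedged sketch of the resulting streaming call against a running server; the URL and prompt are placeholders, and only the `lora_request` key comes from this commit:

```python
import requests

base_url = "http://127.0.0.1:30000"  # placeholder; the engine stores the real URL
json_data = {
    "text": "Hello, world!",                    # placeholder prompt
    "sampling_params": {"max_new_tokens": 32},  # placeholder sampling params
    "stream": True,
}
json_data["lora_request"] = ["lora0"]  # must match the name registered via --lora-paths

response = requests.post(f"{base_url}/generate", json=json_data, stream=True)
if response.status_code != 200:
    raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")

for chunk in response.iter_lines(decode_unicode=True):
    if chunk:
        print(chunk)  # each line is one server-sent event with partial output
```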
......@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
......
......@@ -73,7 +73,7 @@ def main():
"help": partial(print, USAGE),
}
command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
# launch distributed training
nnodes = os.getenv("NNODES", "1")
......
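The relaxed check was an off-by-one: `sys.argv[0]` is always the program name, so `len(sys.argv) >= 1` held even with no subcommand and `pop(1)` raised `IndexError`. A quick illustration:

```python
import sys

sys.argv = ["llamafactory-cli"]  # invoked with no subcommand
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
print(command)  # "help"; the old `>= 1` condition would have raised IndexError
```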
......@@ -51,12 +51,27 @@ class DatasetConverter:
else:
medias = medias[:]
if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
for i in range(len(medias)):
if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
medias[i] = os.path.join(self.data_args.media_dir, medias[i])
else:
logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
if self.dataset_attr.load_from in ["script", "file"]:
if isinstance(medias[0], str):
for i in range(len(medias)):
media_path = os.path.join(self.data_args.media_dir, medias[i])
if os.path.isfile(media_path):
medias[i] = media_path
else:
logger.warning_rank0_once(
f"Media {medias[i]} does not exist in `media_dir`. Use original path."
)
elif isinstance(medias[0], list): # for processed video frames
# medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
for i in range(len(medias)):
for j in range(len(medias[i])):
media_path = os.path.join(self.data_args.media_dir, medias[i][j])
if os.path.isfile(media_path):
medias[i][j] = media_path
else:
logger.warning_rank0_once(
f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
)
return medias
......
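A standalone sketch of the new nested branch for processed video frames; `media_dir` and the frame names are made up:

```python
import os

media_dir = "/data/media"  # hypothetical stand-in for data_args.media_dir
# One inner list per sample, each holding the pre-extracted frames of a video.
medias = [["frame1.jpg", "frame2.jpg"], ["frame3.jpg", "frame4.jpg"]]

for i in range(len(medias)):
    for j in range(len(medias[i])):
        media_path = os.path.join(media_dir, medias[i][j])
        if os.path.isfile(media_path):
            medias[i][j] = media_path  # resolve the frame against media_dir
        # otherwise the original path is kept, mirroring the warning branch above

print(medias)
```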
......@@ -14,7 +14,7 @@
import json
from enum import Enum, unique
from typing import TYPE_CHECKING, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union
import fsspec
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
......@@ -142,48 +142,49 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModu
return dataset_module
def setup_fs(path, anon=False):
"""Set up a filesystem object based on the path protocol."""
def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
r"""Set up a filesystem object based on the path protocol."""
storage_options = {"anon": anon} if anon else {}
if path.startswith("s3://"):
fs = fsspec.filesystem("s3", **storage_options)
elif path.startswith(("gs://", "gcs://")):
fs = fsspec.filesystem("gcs", **storage_options)
else:
raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'")
raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")
if not fs.exists(path):
raise ValueError(f"Path does not exist: {path}.")
return fs
def read_cloud_json(cloud_path):
"""Read a JSON/JSONL file from cloud storage (S3 or GCS).
def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
r"""Helper function to read JSON/JSONL files using fsspec."""
with fs.open(path, "r") as f:
if path.endswith(".jsonl"):
return [json.loads(line) for line in f if line.strip()]
else:
return json.load(f)
def read_cloud_json(cloud_path: str) -> list[Any]:
r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).
Args:
cloud_path : str
cloud_path: str
Cloud path in the format:
- 's3://bucket-name/file.json' for AWS S3
- 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
lines : bool, default=True
If True, read the file as JSON Lines format (one JSON object per line)
"""
try:
# Try with anonymous access first
fs = setup_fs(cloud_path, anon=True)
return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
fs = setup_fs(cloud_path, anon=True) # try with anonymous access first
except Exception:
# Try again with credentials
fs = setup_fs(cloud_path)
return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
fs = setup_fs(cloud_path) # try again with credentials
def _read_json_with_fs(fs, path, lines=True):
"""Helper function to read JSON/JSONL files using fsspec."""
with fs.open(path, "r") as f:
if lines:
# Read JSONL (JSON Lines) format - one JSON object per line
data = [json.loads(line) for line in f if line.strip()]
else:
# Read regular JSON format
data = json.load(f)
# filter out non-JSON files
files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
files = [file for file in files if file.endswith((".json", ".jsonl"))]  # a list, so the emptiness check below can actually trigger
if not files:
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
return data
return sum([_read_json_with_fs(fs, file) for file in files], [])
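The reader now also accepts a directory and flattens the per-file record lists with `sum(..., [])`. A self-contained sketch of that flattening, using local files in place of fsspec:

```python
import json

# Two tiny local stand-ins for cloud files (the real code goes through fsspec).
with open("a.jsonl", "w") as f:
    f.write('{"x": 1}\n{"x": 2}\n')
with open("b.json", "w") as f:
    json.dump([{"x": 3}], f)

def read_json_file(path: str) -> list:
    # Mirrors _read_json_with_fs: JSONL is parsed line by line, JSON in one shot.
    with open(path) as f:
        if path.endswith(".jsonl"):
            return [json.loads(line) for line in f if line.strip()]
        return json.load(f)

files = ["a.jsonl", "b.json"]  # stands in for the filtered directory listing
records = sum([read_json_file(file) for file in files], [])  # flatten to one list
print(records)  # [{'x': 1}, {'x': 2}, {'x': 3}]
```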
......@@ -168,7 +168,7 @@ def _get_merged_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
merge: bool = True,
return_dict: bool = False,
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
r"""Return the merged datasets in the standard format."""
if dataset_names is None:
......@@ -181,10 +181,10 @@ def _get_merged_dataset(
datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
if merge:
return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
else:
if return_dict:
return datasets
else:
return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
def _get_dataset_processor(
......@@ -300,13 +300,18 @@ def get_dataset(
raise ValueError("Turn off `streaming` when saving dataset to disk.")
# Load and preprocess dataset
with training_args.main_process_first(desc="load dataset"):
with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
eval_dataset = _get_merged_dataset(
data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
data_args.eval_dataset,
model_args,
data_args,
training_args,
stage,
return_dict=data_args.eval_on_each_dataset,
)
with training_args.main_process_first(desc="pre-process dataset"):
with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
dataset = _get_preprocessed_dataset(
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
)
......
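Taken together, `data_args.eval_on_each_dataset` now flows into `return_dict`, so evaluation can keep one dataset per name instead of a single merged set. A schematic sketch of the switch (toy values, not the real loader):

```python
# Toy illustration of the return_dict switch in _get_merged_dataset.
def get_merged(datasets: dict, return_dict: bool = False):
    if return_dict:
        return datasets  # e.g. {"alpaca_en": ds1, "identity": ds2}, one eval per name
    return "merged(" + "+".join(datasets) + ")"  # stand-in for merge_dataset(...)

datasets = {"alpaca_en": "ds1", "identity": "ds2"}
print(get_merged(datasets, return_dict=True))
print(get_merged(datasets, return_dict=False))
```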
......@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
dataset_list: list[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope():
load_from = "ms_hub"
elif use_openmind():
load_from = "om_hub"
else:
load_from = "hf_hub"
load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
......
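The collapsed conditional keeps the old precedence (ModelScope, then OpenMind, then Hugging Face); a quick standalone check with a made-up helper:

```python
def pick_hub(use_modelscope: bool, use_openmind: bool) -> str:
    # Chained conditional equivalent to the removed if/elif ladder.
    return "ms_hub" if use_modelscope else "om_hub" if use_openmind else "hf_hub"

assert pick_hub(False, False) == "hf_hub"
assert pick_hub(True, True) == "ms_hub"  # ModelScope wins when both are enabled
```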
......@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
tool_text = ""
tool_names = []
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required, enum, items = "", "", ""
......@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
return function_text
return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
@override
@staticmethod
......@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
)
......@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
......@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tool = json.loads(content.strip())
tools = json.loads(content.strip())
except json.JSONDecodeError:
return content
if "name" not in tool or "parameters" not in tool:
tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
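The rewritten Llama-3 formatter/extractor pair now supports parallel calls and round-trips cleanly; a self-contained sketch with a namedtuple standing in for the project's `FunctionCall`:

```python
import json
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])  # stand-in type

def function_formatter(functions):
    objs = [{"name": name, "parameters": json.loads(args)} for name, args in functions]
    return json.dumps(objs[0] if len(objs) == 1 else objs, ensure_ascii=False)

def tool_extractor(content):
    try:
        tools = json.loads(content.strip())
    except json.JSONDecodeError:
        return content
    tools = [tools] if not isinstance(tools, list) else tools
    try:
        return [FunctionCall(t["name"], json.dumps(t["parameters"], ensure_ascii=False)) for t in tools]
    except KeyError:
        return content

calls = [FunctionCall("get_time", '{"tz": "UTC"}'), FunctionCall("add", '{"a": 1, "b": 2}')]
text = function_formatter(calls)  # a JSON list, since there are two calls
assert tool_extractor(text) == calls  # parallel calls survive the round trip
```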
class MistralToolUtils(ToolUtils):
r"""Mistral v0.3 tool using template."""
......@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = []
for tool in tools:
wrapped_tools.append({"type": "function", "function": tool})
wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})
return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return "[" + ", ".join(function_texts) + "]"
return json.dumps(
[{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
)
@override
@staticmethod
......@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils):
except json.JSONDecodeError:
return content
if not isinstance(tools, list):
tools = [tools]
results = []
for tool in tools:
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content
class QwenToolUtils(ToolUtils):
......@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
......@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return "\n".join(function_texts)
function_texts = [
json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
for name, arguments in functions
]
return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
@override
@staticmethod
......
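Several formatters in this file now accept tools that already carry the OpenAI-style wrapper; the shared guard is idempotent, as this small check shows:

```python
import json

def wrap(tool: dict) -> dict:
    # Leave already-wrapped tools untouched; wrap bare schemas otherwise.
    return tool if tool.get("type") == "function" else {"type": "function", "function": tool}

bare = {"name": "get_time", "parameters": {"type": "object", "properties": {}}}
wrapped = {"type": "function", "function": bare}
assert wrap(bare) == wrapped
assert wrap(wrapped) == wrapped  # applying the guard twice changes nothing
print(json.dumps(wrap(bare), ensure_ascii=False))
```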
......@@ -513,7 +513,7 @@ register_model_group(
register_model_group(
models={
"DeepSeek-V2-236B-Chat-0628": {
"DeepSeek-V2-236B-0628-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat-0628",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat-0628",
},
......@@ -521,7 +521,7 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5",
},
"DeepSeek-V2.5-236B-Chat-1210": {
"DeepSeek-V2.5-236B-1210-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
},
......@@ -533,6 +533,17 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
},
"DeepSeek-V3-671B-0324-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
},
},
template="deepseek3",
)
register_model_group(
models={
"DeepSeek-R1-1.5B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
......@@ -545,6 +556,10 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
},
"DeepSeek-R1-8B-0528-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
},
"DeepSeek-R1-14B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
......@@ -565,8 +580,12 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
},
"DeepSeek-R1-671B-0528-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528",
},
},
template="deepseek3",
template="deepseekr1",
)
......@@ -673,6 +692,10 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-3-1b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
},
"MedGemma-27B-Instruct": {
DownloadSource.DEFAULT: "google/medgemma-27b-text-it",
DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
},
},
template="gemma",
)
......@@ -704,6 +727,14 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-3-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
},
"MedGemma-4B": {
DownloadSource.DEFAULT: "google/medgemma-4b-pt",
DownloadSource.MODELSCOPE: "google/medgemma-4b-pt",
},
"MedGemma-4B-Instruct": {
DownloadSource.DEFAULT: "google/medgemma-4b-it",
DownloadSource.MODELSCOPE: "google/medgemma-4b-it",
},
},
template="gemma3",
multimodal=True,
......@@ -737,6 +768,13 @@ register_model_group(
DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
},
},
template="glm4",
)
register_model_group(
models={
"GLM-Z1-9B-0414-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
......@@ -746,7 +784,7 @@ register_model_group(
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
},
},
template="glm4",
template="glmz1",
)
......@@ -869,12 +907,13 @@ register_model_group(
register_model_group(
models={
"Granite-3.2-1B-A400M-Base": {
"Granite-Vision-3.2-2B": {
DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
},
},
template="granite3_vision",
multimodal=True,
)
......@@ -1398,6 +1437,45 @@ register_model_group(
)
register_model_group(
models={
"MiMo-7B-Base": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
},
"MiMo-7B-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
},
"MiMo-7B-Instruct-RL": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
},
"MiMo-7B-RL-ZERO": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
},
},
template="mimo",
)
register_model_group(
models={
"MiMo-7B-VL-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
},
"MiMo-7B-VL-RL": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
},
},
template="mimo_vl",
multimodal=True,
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
......@@ -2461,6 +2539,38 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
},
"Qwen3-0.6B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
},
"Qwen3-1.7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
},
"Qwen3-4B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
},
"Qwen3-8B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
},
"Qwen3-14B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
},
"Qwen3-32B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
},
"Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
},
"Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
},
},
template="qwen3",
)
......@@ -2484,10 +2594,22 @@ register_model_group(
register_model_group(
models={
"Qwen2.5-Omni-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
},
"Qwen2.5-Omni-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
}
},
"Qwen2.5-Omni-7B-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
},
"Qwen2.5-Omni-7B-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
},
},
template="qwen2_omni",
multimodal=True,
......@@ -2598,15 +2720,17 @@ register_model_group(
register_model_group(
models={
"SOLAR-10.7B-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
"Seed-Coder-8B-Base": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
},
"SOLAR-10.7B-Instruct-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
"Seed-Coder-8B-Instruct": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
},
"Seed-Coder-8B-Instruct-Reasoning": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
},
},
template="solar",
template="seed_coder",
)
......@@ -2631,6 +2755,82 @@ register_model_group(
)
register_model_group(
models={
"SmolLM-135M": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M",
},
"SmolLM-360M": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M",
},
"SmolLM-1.7B": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B",
},
"SmolLM-135M-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M-Instruct",
},
"SmolLM-360M-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M-Instruct",
},
"SmolLM-1.7B-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B-Instruct",
},
},
template="smollm",
)
register_model_group(
models={
"SmolLM2-135M": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M",
},
"SmolLM2-360M": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M",
},
"SmolLM2-1.7B": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B",
},
"SmolLM2-135M-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M-Instruct",
},
"SmolLM2-360M-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M-Instruct",
},
"SmolLM2-1.7B-Instruct": {
DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
},
},
template="smollm2",
)
register_model_group(
models={
"SOLAR-10.7B-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
},
"SOLAR-10.7B-Instruct-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
},
},
template="solar",
)
register_model_group(
models={
"StarCoder2-3B": {
......
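All of the registry edits above share one shape: a name-keyed dict of download sources plus a chat template. A toy sketch of that pattern (not the project's real implementation):

```python
from enum import Enum, unique
from typing import Optional

@unique
class DownloadSource(str, Enum):
    DEFAULT = "hf"     # Hugging Face Hub
    MODELSCOPE = "ms"  # ModelScope mirror

SUPPORTED_MODELS: dict = {}
DEFAULT_TEMPLATE: dict = {}

def register_model_group(models: dict, template: Optional[str] = None, multimodal: bool = False) -> None:
    # Toy version: record each model's download sources and its chat template.
    for name, sources in models.items():
        SUPPORTED_MODELS[name] = sources
        if template is not None:
            DEFAULT_TEMPLATE[name] = template

register_model_group(
    models={
        "SmolLM2-135M": {
            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M",
            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M",
        },
    },
    template="smollm2",
)
print(DEFAULT_TEMPLATE["SmolLM2-135M"])  # -> smollm2
```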
......@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
import accelerate
......@@ -83,4 +84,9 @@ def print_env() -> None:
except Exception:
pass
if os.path.exists("data"):
info["Default data directory"] = "detected"
else:
info["Default data directory"] = "not detected"
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
......@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
if "gptmodel" in requirement or "autoawq" in requirement:
pip_command = f"pip install {requirement} --no-build-isolation"
else:
pip_command = f"pip install {requirement}"
if mandatory:
hint = f"To fix: run `pip install {requirement}`."
hint = f"To fix: run `{pip_command}`."
else:
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
require_version(requirement, hint)
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.5.0")
check_version("accelerate>=0.34.0,<=1.6.0")
check_version("peft>=0.14.0,<=0.15.1")
check_version(
"transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
)
check_version("datasets>=2.16.0,<=3.6.0")
check_version("accelerate>=0.34.0,<=1.7.0")
check_version("peft>=0.14.0,<=0.15.2")
check_version("trl>=0.8.6,<=0.9.6")
if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
......
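A standalone sketch of the adjusted hint logic, with the final `require_version` call replaced by a return so the snippet runs on its own:

```python
def build_hint(requirement: str, mandatory: bool = False) -> str:
    # gptqmodel and autoawq must be installed without build isolation.
    if "gptqmodel" in requirement or "autoawq" in requirement:
        pip_command = f"pip install {requirement} --no-build-isolation"
    else:
        pip_command = f"pip install {requirement}"
    if mandatory:
        return f"To fix: run `{pip_command}`."
    return f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

print(build_hint("autoawq>=0.2.0", mandatory=True))
# To fix: run `pip install autoawq>=0.2.0 --no-build-isolation`.
```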
......@@ -99,6 +99,10 @@ class DataArguments:
default=0.0,
metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
)
eval_on_each_dataset: bool = field(
default=False,
metadata={"help": "Whether or not to evaluate on each dataset separately."},
)
packing: Optional[bool] = field(
default=None,
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
......@@ -111,6 +115,14 @@ class DataArguments:
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
)
default_system: Optional[str] = field(
default=None,
metadata={"help": "Override the default system message in the template."},
)
enable_thinking: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
)
tokenized_path: Optional[str] = field(
default=None,
metadata={
......@@ -121,6 +133,10 @@ class DataArguments:
)
},
)
data_shared_file_system: bool = field(
default=False,
metadata={"help": "Whether or not to use a shared file system for the datasets."},
)
def __post_init__(self):
def split_arg(arg):
......
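The four new knobs in one place; a pared-down dataclass sketch (the real `DataArguments` has many more fields):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class DataArgumentsSketch:
    # Only the fields added in this commit, with their defaults.
    eval_on_each_dataset: bool = False      # evaluate each eval dataset separately
    default_system: Optional[str] = None    # override the template's system message
    enable_thinking: Optional[bool] = True  # thinking mode for reasoning models
    data_shared_file_system: bool = False   # datasets live on a shared file system

args = DataArgumentsSketch(default_system="You are a helpful assistant.", enable_thinking=False)
print(args)
```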
......@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import GenerationConfig
......@@ -62,10 +62,6 @@ class GeneratingArguments:
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
)
default_system: Optional[str] = field(
default=None,
metadata={"help": "Default system message to use in chat completion."},
)
skip_special_tokens: bool = field(
default=True,
metadata={"help": "Whether or not to remove special tokens in the decoding."},
......
......@@ -235,10 +235,6 @@ class ProcessorArguments:
default=False,
metadata={"help": "Whether to crop the image to patches for internvl."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
video_max_pixels: int = field(
default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."},
......@@ -255,6 +251,10 @@ class ProcessorArguments:
default=128,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
audio_sampling_rate: int = field(
default=16000,
metadata={"help": "The sampling rate of audio inputs."},
......@@ -364,6 +364,12 @@ class SGLangArguments:
default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
)
sglang_lora_backend: Literal["triton", "flashinfer"] = field(
default="triton",
metadata={
"help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
},
)
def __post_init__(self):
if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
......
......@@ -148,10 +148,10 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.8.4")
check_version("vllm>=0.4.3,<=0.8.6")
check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.4")
check_version("sglang>=0.4.5")
check_version("sglang", mandatory=True)
if finetuning_args.use_galore:
......
......@@ -64,6 +64,7 @@ class RayArguments:
raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
)
import pyarrow.fs as fs
if self.ray_storage_filesystem == "s3":
......
......@@ -29,10 +29,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if getattr(config, "model_type", None) == "gemma2":
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available():
if model_args.flash_attn != AttentionFunction.FA2:
......
......@@ -45,16 +45,24 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "glm4":
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "granite":
from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "llava":
from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
elif model_type == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "mllama":
from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
elif model_type == "olmo2":
from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2":
......@@ -63,6 +71,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
elif model_type == "qwen2_5_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
elif model_type == "qwen3":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
else:
logger.warning_rank0("Current model does not support liger kernel.")
return
......