Commit 0722acf1 authored by chenych

Update 0604

parent c4ba4563
@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False
         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
...
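For review context: a sketch (not part of the commit) of what the new LoRA branch produces at runtime, assuming a hypothetical base model and adapter directory. The server registers the adapter under the name `lora0`, and each `/generate` call selects it via `lora_request`:

```python
# Hypothetical values; only the flag and payload shapes come from the diff.
launch_cmd = " ".join(
    [
        "python3 -m sglang.launch_server",
        "--model-path meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical base model
        "--max-loras-per-batch 1",
        "--lora-backend triton",  # model_args.sglang_lora_backend (default "triton")
        "--lora-paths lora0=saves/llama3-8b/lora/sft",  # adapter_name_or_path[0], hypothetical
        "--disable-radix-cache",
    ]
)

sampling_params = {"temperature": 0.7, "max_new_tokens": 128}
json_data = {
    # ... prompt fields as assembled by the engine ...
    "sampling_params": sampling_params,
    "stream": True,
    "lora_request": ["lora0"],  # selects the adapter registered via --lora-paths
}
```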
@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
...
@@ -73,7 +73,7 @@ def main():
         "help": partial(print, USAGE),
     }

-    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
...
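The guard change above fixes a crash when the CLI is invoked with no sub-command: `sys.argv` then has length 1, the old `>= 1` test still passed, and `sys.argv.pop(1)` raised an `IndexError` instead of printing usage. A minimal repro:

```python
import sys

sys.argv = ["llamafactory-cli"]  # invoked with no sub-command
# old guard: len(sys.argv) >= 1 is True, so sys.argv.pop(1) raises IndexError
# new guard: len(sys.argv) > 1 is False, so the CLI falls back to "help"
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
assert command == "help"
```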
@@ -51,12 +51,27 @@ class DatasetConverter:
         else:
             medias = medias[:]

-        if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
-            for i in range(len(medias)):
-                if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
-                    medias[i] = os.path.join(self.data_args.media_dir, medias[i])
-                else:
-                    logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")
+        if self.dataset_attr.load_from in ["script", "file"]:
+            if isinstance(medias[0], str):
+                for i in range(len(medias)):
+                    media_path = os.path.join(self.data_args.media_dir, medias[i])
+                    if os.path.isfile(media_path):
+                        medias[i] = media_path
+                    else:
+                        logger.warning_rank0_once(
+                            f"Media {medias[i]} does not exist in `media_dir`. Use original path."
+                        )
+            elif isinstance(medias[0], list):  # for processed video frames
+                # medias is a list of lists, e.g., [[frame1.jpg, frame2.jpg], [frame3.jpg, frame4.jpg]]
+                for i in range(len(medias)):
+                    for j in range(len(medias[i])):
+                        media_path = os.path.join(self.data_args.media_dir, medias[i][j])
+                        if os.path.isfile(media_path):
+                            medias[i][j] = media_path
+                        else:
+                            logger.warning_rank0_once(
+                                f"Media {medias[i][j]} does not exist in `media_dir`. Use original path."
+                            )

         return medias
...
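The new `elif` branch handles datasets whose video column already holds extracted frame paths (a list of lists). A self-contained sketch of the resolution rule, with a hypothetical helper name: a path is prefixed with `media_dir` only when the prefixed file actually exists, otherwise the original path is kept.

```python
import os

def resolve_medias(medias, media_dir):
    # hypothetical helper mirroring the converter logic above (warnings omitted)
    def resolve(path):
        full_path = os.path.join(media_dir, path)
        return full_path if os.path.isfile(full_path) else path

    if isinstance(medias[0], str):  # plain image/video paths
        return [resolve(m) for m in medias]
    elif isinstance(medias[0], list):  # pre-extracted video frames
        return [[resolve(f) for f in frames] for frames in medias]
    return medias

print(resolve_medias([["vid0/frame1.jpg", "vid0/frame2.jpg"]], "data/media"))
```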
@@ -14,7 +14,7 @@
 import json
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union

 import fsspec
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
@@ -142,48 +142,49 @@ def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
     return dataset_module


-def setup_fs(path, anon=False):
-    """Set up a filesystem object based on the path protocol."""
+def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
+    r"""Set up a filesystem object based on the path protocol."""
     storage_options = {"anon": anon} if anon else {}
     if path.startswith("s3://"):
         fs = fsspec.filesystem("s3", **storage_options)
     elif path.startswith(("gs://", "gcs://")):
         fs = fsspec.filesystem("gcs", **storage_options)
     else:
-        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'")
+        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")
+
+    if not fs.exists(path):
+        raise ValueError(f"Path does not exist: {path}.")

     return fs


-def read_cloud_json(cloud_path):
-    """Read a JSON/JSONL file from cloud storage (S3 or GCS).
+def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
+    r"""Helper function to read JSON/JSONL files using fsspec."""
+    with fs.open(path, "r") as f:
+        if path.endswith(".jsonl"):
+            return [json.loads(line) for line in f if line.strip()]
+        else:
+            return json.load(f)
+
+
+def read_cloud_json(cloud_path: str) -> list[Any]:
+    r"""Read a JSON/JSONL file from cloud storage (S3 or GCS).

     Args:
-        cloud_path : str
+        cloud_path: str
             Cloud path in the format:
             - 's3://bucket-name/file.json' for AWS S3
             - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage
-        lines : bool, default=True
-            If True, read the file as JSON Lines format (one JSON object per line)
     """
     try:
-        # Try with anonymous access first
-        fs = setup_fs(cloud_path, anon=True)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+        fs = setup_fs(cloud_path, anon=True)  # try with anonymous access first
     except Exception:
-        # Try again with credentials
-        fs = setup_fs(cloud_path)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
-
-
-def _read_json_with_fs(fs, path, lines=True):
-    """Helper function to read JSON/JSONL files using fsspec."""
-    with fs.open(path, "r") as f:
-        if lines:
-            # Read JSONL (JSON Lines) format - one JSON object per line
-            data = [json.loads(line) for line in f if line.strip()]
-        else:
-            # Read regular JSON format
-            data = json.load(f)
-
-    return data
+        fs = setup_fs(cloud_path)  # try again with credentials
+
+    # filter out non-JSON files
+    files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
+    files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
+    if not files:
+        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
+
+    return sum([_read_json_with_fs(fs, file) for file in files], [])
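`read_cloud_json` now accepts a directory as well as a single file and concatenates every `.json`/`.jsonl` object it finds into one flat list. Hypothetical usage (bucket names invented):

```python
records = read_cloud_json("s3://my-bucket/data/alpaca.jsonl")  # one file
records = read_cloud_json("gs://my-bucket/data/")  # every JSON/JSONL file under the prefix
print(len(records), records[0])
```

One caveat: `filter(...)` returns a lazy iterator, which is always truthy, so the `if not files:` guard can never fire; materializing with `files = list(filter(...))` would make the emptiness check effective.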
@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)

-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)


 def _get_dataset_processor(
@@ -300,13 +300,18 @@ def get_dataset(
         raise ValueError("Turn off `streaming` when saving dataset to disk.")

     # Load and preprocess dataset
-    with training_args.main_process_first(desc="load dataset"):
+    with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
-        eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
-        )
+        eval_dataset = _get_merged_dataset(
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
+        )

-    with training_args.main_process_first(desc="pre-process dataset"):
+    with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
...
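Renaming `merge` to `return_dict` also inverts the default and decouples it from `do_predict`: eval datasets are merged unless the user opts into per-dataset evaluation via the new `eval_on_each_dataset` flag. A sketch with hypothetical dataset names:

```python
# return_dict=False (default): one merged dataset, as before
eval_dataset = _get_merged_dataset(
    ["alpaca_eval", "gsm8k_eval"], model_args, data_args, training_args, stage="sft"
)

# return_dict=True (driven by `eval_on_each_dataset: true`): keep the datasets
# separate so the trainer can report metrics per dataset
eval_dataset = _get_merged_dataset(
    ["alpaca_eval", "gsm8k_eval"], model_args, data_args, training_args, stage="sft", return_dict=True
)
# -> {"alpaca_eval": Dataset(...), "gsm8k_eval": Dataset(...)}
```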
@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
     dataset_list: list[DatasetAttr] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            if use_modelscope():
-                load_from = "ms_hub"
-            elif use_openmind():
-                load_from = "om_hub"
-            else:
-                load_from = "hf_hub"
+            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue
...
@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
         tool_text = ""
         tool_names = []
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             param_text = ""
             for name, param in tool["parameters"]["properties"].items():
                 required, enum, items = "", "", ""
@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_text = ""
-        for name, arguments in functions:
-            function_text += f"Action: {name}\nAction Input: {arguments}\n"
-
-        return function_text
+        return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])

     @override
     @staticmethod
@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
                 name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
             )
@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
         date = datetime.now().strftime("%d %b %Y")
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"

         return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        if len(functions) > 1:
-            raise ValueError("Llama-3 does not support parallel functions.")
-
-        return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
+        function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
+        return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)

     @override
     @staticmethod
     def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
-            tool = json.loads(content.strip())
+            tools = json.loads(content.strip())
         except json.JSONDecodeError:
             return content

-        if "name" not in tool or "parameters" not in tool:
-            return content
-
-        return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content


 class MistralToolUtils(ToolUtils):
     r"""Mistral v0.3 tool using template."""

@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         wrapped_tools = []
         for tool in tools:
-            wrapped_tools.append({"type": "function", "function": tool})
+            wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})

         return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"

     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
-
-        return "[" + ", ".join(function_texts) + "]"
+        return json.dumps(
+            [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
+        )

     @override
     @staticmethod
@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils):
         except json.JSONDecodeError:
             return content

-        if not isinstance(tools, list):
-            tools = [tools]
-
-        results = []
-        for tool in tools:
-            if "name" not in tool or "arguments" not in tool:
-                return content
-
-            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
-
-        return results
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content


 class QwenToolUtils(ToolUtils):
@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)

         return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(
-                "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
-            )
-
-        return "\n".join(function_texts)
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])

     @override
     @staticmethod
...
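The formatter rewrites share one idea: build the payload as a Python object and serialize it with `json.dumps`, instead of interpolating argument strings into hand-written JSON (which produced invalid output when arguments contained quotes). Llama-3's formatter additionally gains parallel-call support, and the shared `tool.get("function", ...)` guard lets every `tool_formatter` accept both bare schemas and OpenAI-style `{"type": "function", "function": {...}}` wrappers. A standalone sketch of the Qwen-style output, with a stand-in for the project's `FunctionCall` type:

```python
import json
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])  # stand-in for the project type

functions = [
    FunctionCall("get_weather", json.dumps({"city": 'Paris "FR"'})),  # quotes survive round-tripping
    FunctionCall("get_time", json.dumps({"tz": "CET"})),
]

function_texts = [
    json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
    for name, arguments in functions
]
print("\n".join(f"<tool_call>\n{text}\n</tool_call>" for text in function_texts))
```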
@@ -513,7 +513,7 @@ register_model_group(
 register_model_group(
     models={
-        "DeepSeek-V2-236B-Chat-0628": {
+        "DeepSeek-V2-236B-0628-Chat": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat-0628",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat-0628",
         },
@@ -521,7 +521,7 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5",
         },
-        "DeepSeek-V2.5-236B-Chat-1210": {
+        "DeepSeek-V2.5-236B-1210-Chat": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
         },
@@ -533,6 +533,17 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
         },
+        "DeepSeek-V3-671B-0324-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
+        },
+    },
+    template="deepseek3",
+)
+
+
+register_model_group(
+    models={
         "DeepSeek-R1-1.5B-Distill": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
@@ -545,6 +556,10 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         },
+        "DeepSeek-R1-8B-0528-Distill": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
+        },
         "DeepSeek-R1-14B-Distill": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
@@ -565,8 +580,12 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
         },
+        "DeepSeek-R1-671B-0528-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-0528",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-0528",
+        },
     },
-    template="deepseek3",
+    template="deepseekr1",
 )
@@ -673,6 +692,10 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-3-1b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
         },
+        "MedGemma-27B-Instruct": {
+            DownloadSource.DEFAULT: "google/medgemma-27b-text-it",
+            DownloadSource.MODELSCOPE: "google/medgemma-27b-text-it",
+        },
     },
     template="gemma",
 )
@@ -704,6 +727,14 @@ register_model_group(
             DownloadSource.DEFAULT: "google/gemma-3-27b-it",
             DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
         },
+        "MedGemma-4B": {
+            DownloadSource.DEFAULT: "google/medgemma-4b-pt",
+            DownloadSource.MODELSCOPE: "google/medgemma-4b-pt",
+        },
+        "MedGemma-4B-Instruct": {
+            DownloadSource.DEFAULT: "google/medgemma-4b-it",
+            DownloadSource.MODELSCOPE: "google/medgemma-4b-it",
+        },
     },
     template="gemma3",
     multimodal=True,
@@ -737,6 +768,13 @@ register_model_group(
             DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
         },
+    },
+    template="glm4",
+)
+
+
+register_model_group(
+    models={
         "GLM-Z1-9B-0414-Chat": {
             DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
@@ -746,7 +784,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
         },
     },
-    template="glm4",
+    template="glmz1",
 )
@@ -869,12 +907,13 @@ register_model_group(
 register_model_group(
     models={
-        "Granite-3.2-1B-A400M-Base": {
+        "Granite-Vision-3.2-2B": {
             DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
         },
     },
     template="granite3_vision",
+    multimodal=True,
 )
@@ -1398,6 +1437,45 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "MiMo-7B-Base": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
+        },
+        "MiMo-7B-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
+        },
+        "MiMo-7B-Instruct-RL": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
+        },
+        "MiMo-7B-RL-ZERO": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+        },
+    },
+    template="mimo",
+)
+
+
+register_model_group(
+    models={
+        "MiMo-7B-VL-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-SFT",
+        },
+        "MiMo-7B-VL-RL": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-VL-7B-RL",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-VL-7B-RL",
+        },
+    },
+    template="mimo_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "MiniCPM-2B-SFT-Chat": {
@@ -2461,6 +2539,38 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
         },
+        "Qwen3-0.6B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+        },
+        "Qwen3-1.7B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+        },
+        "Qwen3-4B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
+        },
+        "Qwen3-8B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
+        },
+        "Qwen3-14B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
+        },
+        "Qwen3-32B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
+        },
+        "Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+        },
+        "Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+        },
     },
     template="qwen3",
 )
@@ -2484,10 +2594,22 @@ register_model_group(
 register_model_group(
     models={
+        "Qwen2.5-Omni-3B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
+        },
         "Qwen2.5-Omni-7B": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
-        }
+        },
+        "Qwen2.5-Omni-7B-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+        },
+        "Qwen2.5-Omni-7B-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
+        },
     },
     template="qwen2_omni",
     multimodal=True,
@@ -2598,15 +2720,17 @@ register_model_group(
 register_model_group(
     models={
-        "SOLAR-10.7B-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
-        },
-        "SOLAR-10.7B-Instruct-v1.0": {
-            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
+        "Seed-Coder-8B-Base": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
+        },
+        "Seed-Coder-8B-Instruct": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
+        },
+        "Seed-Coder-8B-Instruct-Reasoning": {
+            DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
         },
     },
-    template="solar",
+    template="seed_coder",
 )
@@ -2631,6 +2755,82 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "SmolLM-135M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M",
+        },
+        "SmolLM-360M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M",
+        },
+        "SmolLM-1.7B": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B",
+        },
+        "SmolLM-135M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-135M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-135M-Instruct",
+        },
+        "SmolLM-360M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-360M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-360M-Instruct",
+        },
+        "SmolLM-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM-1.7B-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM-1.7B-Instruct",
+        },
+    },
+    template="smollm",
+)
+
+
+register_model_group(
+    models={
+        "SmolLM2-135M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M",
+        },
+        "SmolLM2-360M": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M",
+        },
+        "SmolLM2-1.7B": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B",
+        },
+        "SmolLM2-135M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-135M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-135M-Instruct",
+        },
+        "SmolLM2-360M-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-360M-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-360M-Instruct",
+        },
+        "SmolLM2-1.7B-Instruct": {
+            DownloadSource.DEFAULT: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+            DownloadSource.MODELSCOPE: "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+        },
+    },
+    template="smollm2",
+)
+
+
+register_model_group(
+    models={
+        "SOLAR-10.7B-v1.0": {
+            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
+        },
+        "SOLAR-10.7B-Instruct-v1.0": {
+            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
+        },
+    },
+    template="solar",
+)
+
+
 register_model_group(
     models={
         "StarCoder2-3B": {
...
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import platform

 import accelerate
@@ -83,4 +84,9 @@ def print_env() -> None:
     except Exception:
         pass

+    if os.path.exists("data"):
+        info["Default data directory"] = "detected"
+    else:
+        info["Default data directory"] = "not detected"
+
     print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return

+    if "gptqmodel" in requirement or "autoawq" in requirement:
+        pip_command = f"pip install {requirement} --no-build-isolation"
+    else:
+        pip_command = f"pip install {requirement}"
+
     if mandatory:
-        hint = f"To fix: run `pip install {requirement}`."
+        hint = f"To fix: run `{pip_command}`."
     else:
-        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

     require_version(requirement, hint)


 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.5.0")
-    check_version("accelerate>=0.34.0,<=1.6.0")
-    check_version("peft>=0.14.0,<=0.15.1")
+    check_version(
+        "transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
+    )
+    check_version("datasets>=2.16.0,<=3.6.0")
+    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
...
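A standalone restatement of the new hint logic: packages that compile extensions against the installed torch (gptqmodel, autoawq) typically need `pip install --no-build-isolation`, so the hint now spells that out.

```python
def pip_hint(requirement: str, mandatory: bool = False) -> str:
    # mirrors the hint construction in check_version() above
    if "gptqmodel" in requirement or "autoawq" in requirement:
        pip_command = f"pip install {requirement} --no-build-isolation"
    else:
        pip_command = f"pip install {requirement}"

    if mandatory:
        return f"To fix: run `{pip_command}`."
    return f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

assert "--no-build-isolation" in pip_hint("autoawq>=0.2.0")
```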
@@ -99,6 +99,10 @@ class DataArguments:
         default=0.0,
         metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
     )
+    eval_on_each_dataset: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to evaluate on each dataset separately."},
+    )
     packing: Optional[bool] = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
@@ -111,6 +115,14 @@ class DataArguments:
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Override the default system message in the template."},
+    )
+    enable_thinking: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={
@@ -121,6 +133,10 @@ class DataArguments:
             )
         },
     )
+    data_shared_file_system: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use a shared file system for the datasets."},
+    )

     def __post_init__(self):
         def split_arg(arg):
...
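Together with the removals below, this moves `default_system` from `GeneratingArguments` into `DataArguments`, so the system message is resolved where the template lives rather than in each inference engine (hence the deleted `system = system or ...` lines in the SGLang and vLLM engines above). Hypothetical usage of the new fields, assuming the usual import path:

```python
from llamafactory.hparams import DataArguments  # assumed import path

data_args = DataArguments(
    dataset="alpaca_en_demo",
    eval_dataset="alpaca_en_demo,mmlu_test",  # hypothetical dataset names
    eval_on_each_dataset=True,  # report metrics for each eval dataset separately
    default_system="You are a helpful assistant.",  # overrides the template's system message
    enable_thinking=False,  # disable thinking mode for reasoning models
    data_shared_file_system=True,  # dataset cache is on a file system shared by all nodes
)
```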
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any

 from transformers import GenerationConfig

@@ -62,10 +62,6 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    default_system: Optional[str] = field(
-        default=None,
-        metadata={"help": "Default system message to use in chat completion."},
-    )
     skip_special_tokens: bool = field(
         default=True,
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
...
@@ -235,10 +235,6 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Whether to crop the image to patches for internvl."},
     )
-    use_audio_in_video: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use audio in video inputs."},
-    )
     video_max_pixels: int = field(
         default=256 * 256,
         metadata={"help": "The maximum number of pixels of video inputs."},
@@ -255,6 +251,10 @@ class ProcessorArguments:
         default=128,
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
+    use_audio_in_video: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use audio in video inputs."},
+    )
     audio_sampling_rate: int = field(
         default=16000,
         metadata={"help": "The sampling rate of audio inputs."},
@@ -364,6 +364,12 @@ class SGLangArguments:
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
+    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
+        default="triton",
+        metadata={
+            "help": "The backend used to run GEMM kernels for LoRA modules. The Triton backend is recommended for better performance and stability."
+        },
+    )

     def __post_init__(self):
         if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
...
@@ -148,10 +148,10 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.4")
+        check_version("vllm>=0.4.3,<=0.8.6")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
-        check_version("sglang>=0.4.4")
+        check_version("sglang>=0.4.5")
         check_version("sglang", mandatory=True)

     if finetuning_args.use_galore:
...
@@ -64,6 +64,7 @@ class RayArguments:
                 raise ValueError(
                     f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
                 )
+
             import pyarrow.fs as fs

             if self.ray_storage_filesystem == "s3":
...
@@ -29,10 +29,8 @@ if TYPE_CHECKING:

 logger = logging.get_logger(__name__)


-def configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
-) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
...
@@ -45,16 +45,24 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
     elif model_type == "gemma3_text":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    elif model_type == "paligemma":
-        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
+    elif model_type == "glm4":
+        from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
+    elif model_type == "granite":
+        from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
     elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
+    elif model_type == "llava":
+        from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
     elif model_type == "mistral":
         from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
     elif model_type == "mixtral":
         from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
     elif model_type == "mllama":
         from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
+    elif model_type == "olmo2":
+        from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
+    elif model_type == "paligemma":
+        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
     elif model_type == "phi3":
         from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
     elif model_type == "qwen2":
@@ -63,6 +71,8 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     elif model_type == "qwen2_5_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
+    elif model_type == "qwen3":
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return
...