Commit 24534501 authored by mashun1

parallel_tool

parent c4ba4563
@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False

         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
......
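For reference, a minimal sketch of the request the engine now issues when an adapter is loaded (the base URL, prompt ids, and sampling parameters are placeholders; "lora0" matches the name registered via --lora-paths lora0=<adapter_path> in the launch command above):

import requests

json_data = {
    "input_ids": [1, 2, 3],  # assumed prompt field; placeholder token ids
    "sampling_params": {"max_new_tokens": 128, "temperature": 0.7},  # placeholder values
    "stream": False,
    "lora_request": ["lora0"],  # attached only when self.lora_request is set
}
response = requests.post("http://127.0.0.1:30000/generate", json=json_data)
print(response.json())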
@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
......
@@ -73,7 +73,7 @@ def main():
         "help": partial(print, USAGE),
     }
-    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
......
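The `>= 1` to `> 1` change above fixes an off-by-one: `sys.argv` always holds at least the program name, so the old condition called `sys.argv.pop(1)` even when no subcommand was given and raised an IndexError. A small demonstration (argv contents are illustrative):

import sys

sys.argv = ["llamafactory-cli"]  # program name only, no subcommand
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"  # corrected condition
print(command)  # -> "help"; the old `>= 1` check would have raised IndexError here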
@@ -169,11 +169,22 @@ def read_cloud_json(cloud_path):
     try:
         # Try with anonymous access first
         fs = setup_fs(cloud_path, anon=True)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
     except Exception:
         # Try again with credentials
         fs = setup_fs(cloud_path)
-        return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
+    if fs.isdir(cloud_path):
+        files = [x["Key"] for x in fs.listdir(cloud_path)]
+    else:
+        files = [cloud_path]
+    # filter out non-JSON files
+    files = [file for file in files if file.endswith(".json") or file.endswith(".jsonl")]
+    if not files:
+        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}")
+    data = []
+    for file in files:
+        data.extend(_read_json_with_fs(fs, file, lines=file.endswith(".jsonl")))
+    return data

 def _read_json_with_fs(fs, path, lines=True):
......
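A usage sketch of the extended read_cloud_json (bucket and key names are hypothetical): a cloud path may now point at a directory, in which case every .json/.jsonl file under it is read and concatenated into one list of records, while a single-file path behaves as before.

records = read_cloud_json("s3://my-bucket/datasets/sft_shards/")  # directory of JSONL shards, merged
sample = read_cloud_json("s3://my-bucket/datasets/sft.jsonl")     # single file, unchanged behavior
print(len(records), len(sample))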
@@ -168,7 +168,7 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
-    merge: bool = True,
+    return_dict: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
@@ -181,10 +181,10 @@ def _get_merged_dataset(
         datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
-    if merge:
-        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
-    else:
+    if return_dict:
         return datasets
+    else:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)

 def _get_dataset_processor(
@@ -303,7 +303,12 @@ def get_dataset(
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
         eval_dataset = _get_merged_dataset(
-            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
+            data_args.eval_dataset,
+            model_args,
+            data_args,
+            training_args,
+            stage,
+            return_dict=data_args.eval_on_each_dataset,
         )
     with training_args.main_process_first(desc="pre-process dataset"):
......
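A sketch of the renamed flag (dataset names are illustrative, and model_args/data_args/training_args are assumed to be in scope): with eval_on_each_dataset enabled, _get_merged_dataset returns a dict keyed by dataset name so each evaluation set is scored separately; with the default, the datasets are merged into one as before.

eval_dataset = _get_merged_dataset(
    ["alpaca_en_demo", "identity"], model_args, data_args, training_args, "sft", return_dict=True
)
# -> {"alpaca_en_demo": <Dataset>, "identity": <Dataset>}, evaluated per dataset
merged = _get_merged_dataset(["alpaca_en_demo", "identity"], model_args, data_args, training_args, "sft")
# -> a single merged Dataset (default return_dict=False)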
@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
     dataset_list: list[DatasetAttr] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            if use_modelscope():
-                load_from = "ms_hub"
-            elif use_openmind():
-                load_from = "om_hub"
-            else:
-                load_from = "hf_hub"
+            load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue
......
@@ -13,6 +13,7 @@
 # limitations under the License.

 import re
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
@@ -51,6 +52,7 @@ class Template:
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
+    enable_thinking: Optional[bool]
     mm_plugin: "BasePlugin"

     def encode_oneturn(
@@ -61,7 +63,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         prompt_ids = []
         for encoded_ids in encoded_messages[:-1]:
             prompt_ids += encoded_ids
@@ -77,7 +79,7 @@ class Template:
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
-        encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

     def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
@@ -92,6 +94,19 @@ class Template:
         return list(stop_token_ids)

+    def add_thought(self, content: str = "") -> str:
+        r"""Add empty thought to assistant message."""
+        return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
+
+    def remove_thought(self, content: str) -> str:
+        r"""Remove thought from assistant message."""
+        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+        return re.sub(pattern, "", content).lstrip("\n")
+
+    def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
+        r"""Get the token ids of thought words."""
+        return tokenizer.encode(self.add_thought(), add_special_tokens=False)
+
     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
         r"""Convert elements to token ids."""
         token_ids = []
@@ -111,18 +126,12 @@ class Template:
         return token_ids

-    def _remove_thought(self, content: str) -> str:
-        r"""Remove thought from assistant message."""
-        pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
-        return re.sub(pattern, "", content).lstrip("\n")
-
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        remove_thought: bool,
     ) -> list[list[int]]:
         r"""Encode formatted inputs to pairs of token ids.
@@ -140,18 +149,14 @@ class Template:
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 elements += self.format_system.apply(content=(system + tool_text))

-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=content, idx=str(i // 2))
+                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -162,6 +167,9 @@ class Template:
     @staticmethod
     def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
         r"""Add or replace eos token to the tokenizer."""
+        if tokenizer.eos_token == eos_token:
+            return
+
         is_added = tokenizer.eos_token_id is None
         num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
@@ -328,7 +336,6 @@ class Llama2Template(Template):
         messages: list[dict[str, str]],
         system: str,
         tools: str,
-        remove_thought: bool,
     ) -> list[list[int]]:
         system = system or self.default_system
         encoded_messages = []
@@ -342,18 +349,14 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]

-            content = message["content"]
-            if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                content = self._remove_thought(content)
-
             if message["role"] == Role.USER:
-                elements += self.format_user.apply(content=system_text + content)
+                elements += self.format_user.apply(content=system_text + message["content"])
             elif message["role"] == Role.ASSISTANT:
-                elements += self.format_assistant.apply(content=content)
+                elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION:
-                elements += self.format_observation.apply(content=content)
+                elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
-                elements += self.format_function.apply(content=content)
+                elements += self.format_function.apply(content=message["content"])
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -392,6 +395,64 @@ class Llama2Template(Template):
         return jinja_template
+@dataclass
+class ReasoningTemplate(Template):
+    r"""A template that adds thought to assistant message."""
+
+    @override
+    def encode_oneturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> tuple[list[int], list[int]]:
+        messages = deepcopy(messages)
+        for i in range(1, len(messages) - 2, 2):
+            messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        if self.enable_thinking is False:  # remove all cot
+            messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
+        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+        if (
+            self.thought_words[0] not in messages[-1]["content"]
+            and self.thought_words[1] not in messages[-1]["content"]
+        ):  # add empty cot
+            if not self.enable_thinking:  # do not compute loss
+                prompt_ids += self.get_thought_word_ids(tokenizer)
+            else:  # do compute loss
+                response_ids = self.get_thought_word_ids(tokenizer) + response_ids
+
+        return prompt_ids, response_ids
+
+    @override
+    def encode_multiturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+    ) -> list[tuple[list[int], list[int]]]:
+        messages = deepcopy(messages)
+        if self.enable_thinking is False:  # remove all cot
+            for i in range(1, len(messages), 2):
+                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
+        encoded_messages = self._encode(tokenizer, messages, system, tools)
+        for i in range(0, len(messages), 2):
+            if (
+                self.thought_words[0] not in messages[i + 1]["content"]
+                and self.thought_words[1] not in messages[i + 1]["content"]
+            ):  # add empty cot
+                if not self.enable_thinking:  # do not compute loss
+                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                else:  # do compute loss
+                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
+
+        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+
 TEMPLATES: dict[str, "Template"] = {}
@@ -410,6 +471,7 @@ def register_template(
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
+    enable_thinking: Optional[bool] = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
 ) -> None:
@@ -456,6 +518,7 @@ def register_template(
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
+        enable_thinking=enable_thinking,
         mm_plugin=mm_plugin,
     )
@@ -492,6 +555,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
     assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags

     if len(user_slot) > len(user_slot_empty_system):
@@ -501,7 +565,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     else:  # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
         default_system = ""

-    return Template(
+    return template_class(
         format_user=StringFormatter(slots=[user_slot]),
         format_assistant=StringFormatter(slots=[assistant_slot]),
         format_system=StringFormatter(slots=[system_slot]),
@@ -515,6 +579,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
         efficient_eos=False,
         replace_eos=False,
         replace_jinja_template=False,
+        enable_thinking=True,
         mm_plugin=get_mm_plugin(name="base"),
     )
@@ -543,6 +608,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

+    if data_args.default_system is not None:
+        logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+        template.default_system = data_args.default_system
+
+    template.enable_thinking = data_args.enable_thinking
     template.fix_special_tokens(tokenizer)
     template.fix_jinja_template(tokenizer)
     return template
@@ -756,6 +826,7 @@ register_template(
         "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
     ),
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
@@ -774,6 +845,15 @@ register_template(
 )

+# copied from deepseek3 template
+register_template(
+    name="deepseekr1",
+    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=ReasoningTemplate,
+)
+
+
 register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@@ -838,6 +918,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    replace_eos=True,
     template_class=Llama2Template,
 )
@@ -853,6 +934,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
     template_class=Llama2Template,
 )
@@ -872,6 +954,22 @@ register_template(
 )

+# copied from glm4 template
+register_template(
+    name="glmz1",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    template_class=ReasoningTemplate,
+)
+
+
 register_template(
     name="granite3",
     format_user=StringFormatter(
@@ -1018,6 +1116,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>", "<|eom_id|>"],
+    replace_eos=True,
 )
@@ -1037,6 +1136,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot|>", "<|eom|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
 )
@@ -1066,6 +1166,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>", "<|eom_id|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
 )
@@ -1079,6 +1180,7 @@ register_template(
     format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
     default_system="You are a helpful assistant provided by Moonshot-AI.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
@@ -1131,6 +1233,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>", "<|eom_id|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
 )
@@ -1163,6 +1266,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
 )
@@ -1233,6 +1337,24 @@ register_template(
 )

+# copied from qwen template
+register_template(
+    name="mimo",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    template_class=ReasoningTemplate,
+)
+
+
 # copied from chatml template
 register_template(
     name="minicpm_v",
@@ -1363,6 +1485,7 @@ register_template(
     ),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
     template_class=Llama2Template,
 )
@@ -1374,6 +1497,7 @@ register_template(
     format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
     stop_words=["<|end|>"],
+    replace_eos=True,
 )
@@ -1384,6 +1508,7 @@ register_template(
     format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
     format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
     stop_words=["<|end|>"],
+    replace_eos=True,
 )
@@ -1395,6 +1520,7 @@ register_template(
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
     format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
@@ -1425,6 +1551,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
 )
@@ -1440,6 +1567,8 @@ register_template(
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
     stop_words=["<|im_end|>"],
+    replace_eos=True,
+    template_class=ReasoningTemplate,
 )
@@ -1451,6 +1580,7 @@ register_template(
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
 )
@@ -1468,6 +1598,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(
         name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
     ),
@@ -1486,6 +1617,7 @@ register_template(
     format_tools=ToolFormatter(tool_format="qwen"),
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
+    replace_eos=True,
     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 )
@@ -1503,6 +1635,20 @@ register_template(
 )

+register_template(
+    name="seed_coder",
+    format_user=StringFormatter(
+        slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
+    ),
+    format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
+    default_system=(
+        "You are an AI programming assistant, utilizing the Seed-Coder model, developed by ByteDance Seed, "
+        "and you only answer questions related to computer science. For politically sensitive questions, "
+        "security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\n"
+    ),
+)
+
+
 # copied from llama3 template
 register_template(
     name="skywork_o1",
......
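A sketch of what ReasoningTemplate does to a sample without chain-of-thought (the template name, tokenizer, and the <think>/</think> thought words are assumptions): an empty think block is inserted, and enable_thinking decides whether its tokens count toward the loss.

template = TEMPLATES["qwen3"]  # assumes the registration above and a matching tokenizer in scope
messages = [
    {"role": "user", "content": "1+1=?"},
    {"role": "assistant", "content": "2"},  # no <think>...</think> in the label
]
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)
# enable_thinking=True  -> ids of "<think>\n\n</think>\n\n" are prepended to response_ids (loss is computed on them)
# enable_thinking=False -> the same ids are appended to prompt_ids instead (no loss)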
@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
         tool_text = ""
         tool_names = []
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             param_text = ""
             for name, param in tool["parameters"]["properties"].items():
                 required, enum, items = "", "", ""
@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_text = ""
-        for name, arguments in functions:
-            function_text += f"Action: {name}\nAction Input: {arguments}\n"
-
-        return function_text
+        return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])

     @override
     @staticmethod
@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
             tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
                 name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
             )
@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
         date = datetime.now().strftime("%d %b %Y")
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"

         return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        if len(functions) > 1:
-            raise ValueError("Llama-3 does not support parallel functions.")
-
-        return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
+        function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
+        return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)

     @override
     @staticmethod
     def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
-            tool = json.loads(content.strip())
+            tools = json.loads(content.strip())
         except json.JSONDecodeError:
             return content

-        if "name" not in tool or "parameters" not in tool:
-            return content
-
-        return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content
 class MistralToolUtils(ToolUtils):
     r"""Mistral v0.3 tool using template."""
@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         wrapped_tools = []
         for tool in tools:
-            wrapped_tools.append({"type": "function", "function": tool})
+            wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})

         return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"

     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
-
-        return "[" + ", ".join(function_texts) + "]"
+        return json.dumps(
+            [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
+        )

     @override
     @staticmethod
@@ -256,18 +251,12 @@ class MistralToolUtils(ToolUtils):
         except json.JSONDecodeError:
             return content

-        if not isinstance(tools, list):
-            tools = [tools]
-
-        results = []
-        for tool in tools:
-            if "name" not in tool or "arguments" not in tool:
-                return content
-
-            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
-
-        return results
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content


 class QwenToolUtils(ToolUtils):
     r"""Qwen 2.5 tool using template."""
@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
     def tool_formatter(tools: list[dict[str, Any]]) -> str:
         tool_text = ""
         for tool in tools:
-            wrapped_tool = {"type": "function", "function": tool}
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
             tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)

         return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(
-                "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
-            )
-
-        return "\n".join(function_texts)
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])

     @override
     @staticmethod
......
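The commit title (parallel_tool) shows up in the formatter rewrites above: several function calls are now serialized instead of raising. A sketch using the QwenToolUtils class from this file (plain (name, arguments) tuples stand in for FunctionCall objects, and the call arguments are made up):

calls = [
    ("get_weather", '{"city": "Beijing"}'),
    ("get_time", '{"timezone": "Asia/Shanghai"}'),
]
print(QwenToolUtils.function_formatter(calls))
# <tool_call>
# {"name": "get_weather", "arguments": {"city": "Beijing"}}
# </tool_call>
# <tool_call>
# {"name": "get_time", "arguments": {"timezone": "Asia/Shanghai"}}
# </tool_call>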
@@ -533,6 +533,17 @@ register_model_group(
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
         },
+        "DeepSeek-V3-671B-0324-Chat": {
+            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
+            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
+        },
+    },
+    template="deepseek3",
+)
+
+
+register_model_group(
+    models={
         "DeepSeek-R1-1.5B-Distill": {
             DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
@@ -566,7 +577,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
         },
     },
-    template="deepseek3",
+    template="deepseekr1",
 )
@@ -737,6 +748,13 @@ register_model_group(
             DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
         },
+    },
+    template="glm4",
+)
+
+
+register_model_group(
+    models={
         "GLM-Z1-9B-0414-Chat": {
             DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
@@ -746,7 +764,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
         },
     },
-    template="glm4",
+    template="glmz1",
 )
@@ -869,12 +887,13 @@ register_model_group(
 register_model_group(
     models={
-        "Granite-3.2-1B-A400M-Base": {
+        "Granite-Vision-3.2-2B": {
             DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
             DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
         },
     },
     template="granite3_vision",
+    multimodal=True,
 )
@@ -1398,6 +1417,29 @@ register_model_group(
 )

+register_model_group(
+    models={
+        "MiMo-7B-Base": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
+        },
+        "MiMo-7B-Instruct": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
+        },
+        "MiMo-7B-Instruct-RL": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
+        },
+        "MiMo-7B-RL-ZERO": {
+            DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+            DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
+        },
+    },
+    template="mimo",
+)
+
+
 register_model_group(
     models={
         "MiniCPM-2B-SFT-Chat": {
@@ -2461,6 +2503,38 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
         },
+        "Qwen3-0.6B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
+        },
+        "Qwen3-1.7B-Instruct-GPTQ-Int8": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
+        },
+        "Qwen3-4B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
+        },
+        "Qwen3-8B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
+        },
+        "Qwen3-14B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
+        },
+        "Qwen3-32B-Instruct-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
+        },
+        "Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
+        },
+        "Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
+        },
     },
     template="qwen3",
 )
@@ -2484,10 +2558,22 @@ register_model_group(
 register_model_group(
     models={
+        "Qwen2.5-Omni-3B": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
+        },
         "Qwen2.5-Omni-7B": {
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
-        }
+        },
+        "Qwen2.5-Omni-7B-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+        },
+        "Qwen2.5-Omni-7B-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
+        },
     },
     template="qwen2_omni",
     multimodal=True,
register_model_group( register_model_group(
models={ models={
"SOLAR-10.7B-v1.0": { "Seed-Coder-8B-Base": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0", DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
}, },
"SOLAR-10.7B-Instruct-v1.0": { "Seed-Coder-8B-Instruct": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0", DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
}, },
"Seed-Coder-8B-Instruct-Reasoning": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
}, },
template="solar", },
template="seed_coder",
) )
...@@ -2631,6 +2719,20 @@ register_model_group( ...@@ -2631,6 +2719,20 @@ register_model_group(
) )
register_model_group(
models={
"SOLAR-10.7B-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
},
"SOLAR-10.7B-Instruct-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
},
},
template="solar",
)
register_model_group( register_model_group(
models={ models={
"StarCoder2-3B": { "StarCoder2-3B": {
......
@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
         logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
         return

+    if "gptmodel" in requirement or "autoawq" in requirement:
+        pip_command = f"pip install {requirement} --no-build-isolation"
+    else:
+        pip_command = f"pip install {requirement}"
+
     if mandatory:
-        hint = f"To fix: run `pip install {requirement}`."
+        hint = f"To fix: run `{pip_command}`."
     else:
-        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

     require_version(requirement, hint)


 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
-    check_version("datasets>=2.16.0,<=3.5.0")
-    check_version("accelerate>=0.34.0,<=1.6.0")
-    check_version("peft>=0.14.0,<=0.15.1")
+    check_version(
+        "transformers>=4.45.0,<=4.52.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
+    )
+    check_version("datasets>=2.16.0,<=3.6.0")
+    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
     if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
......
@@ -99,6 +99,10 @@ class DataArguments:
         default=0.0,
         metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
     )
+    eval_on_each_dataset: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to evaluate on each dataset separately."},
+    )
     packing: Optional[bool] = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
@@ -111,6 +115,14 @@ class DataArguments:
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Override the default system message in the template."},
+    )
+    enable_thinking: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={
......
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any

 from transformers import GenerationConfig
@@ -62,10 +62,6 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    default_system: Optional[str] = field(
-        default=None,
-        metadata={"help": "Default system message to use in chat completion."},
-    )
     skip_special_tokens: bool = field(
         default=True,
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
......
@@ -235,10 +235,6 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Whether to crop the image to patches for internvl."},
     )
-    use_audio_in_video: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use audio in video inputs."},
-    )
     video_max_pixels: int = field(
         default=256 * 256,
         metadata={"help": "The maximum number of pixels of video inputs."},
@@ -255,6 +251,10 @@ class ProcessorArguments:
         default=128,
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
+    use_audio_in_video: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use audio in video inputs."},
+    )
     audio_sampling_rate: int = field(
         default=16000,
         metadata={"help": "The sampling rate of audio inputs."},
@@ -364,6 +364,12 @@ class SGLangArguments:
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
+    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
+        default="triton",
+        metadata={
+            "help": "The backend for running GEMM kernels of LoRA modules. The Triton backend is recommended for better performance and stability."
+        },
+    )

     def __post_init__(self):
         if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
......
@@ -148,10 +148,10 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)

     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.8.4")
+        check_version("vllm>=0.4.3,<=0.8.6")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
-        check_version("sglang>=0.4.4")
+        check_version("sglang>=0.4.5")
         check_version("sglang", mandatory=True)

     if finetuning_args.use_galore:
......
@@ -64,6 +64,7 @@ class RayArguments:
             raise ValueError(
                 f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
             )
+
         import pyarrow.fs as fs

         if self.ray_storage_filesystem == "s3":
......
@@ -29,10 +29,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)

-def configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
-) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
......
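With is_trainable dropped from the signature above, the gemma2 branch now applies during both training and inference; below is a minimal, generic sketch of the availability check it builds on, not the project's full fallback logic.

from transformers.utils import is_flash_attn_2_available


def pick_attn_implementation(requested: str) -> str:
    """Fall back to eager attention when FlashAttention-2 is not installed."""
    if requested in ("auto", "fa2") and is_flash_attn_2_available():
        return "flash_attention_2"
    return "eager"


print(pick_attn_implementation("auto"))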
...@@ -45,16 +45,24 @@ def apply_liger_kernel( ...@@ -45,16 +45,24 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text": elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "paligemma": elif model_type == "glm4":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "granite":
from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
elif model_type == "llama": elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "llava":
from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
elif model_type == "mistral": elif model_type == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral": elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "mllama": elif model_type == "mllama":
from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
elif model_type == "olmo2":
from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "phi3": elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2": elif model_type == "qwen2":
...@@ -63,6 +71,8 @@ def apply_liger_kernel( ...@@ -63,6 +71,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
elif model_type == "qwen2_5_vl": elif model_type == "qwen2_5_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
elif model_type == "qwen3":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
else: else:
logger.warning_rank0("Current model does not support liger kernel.") logger.warning_rank0("Current model does not support liger kernel.")
return return
......
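The if/elif chain above keeps growing (glm4, granite, llava, olmo2, qwen3 are new in this commit); an equivalent table-driven lookup is sketched below, reusing only the entry-point names that appear in this hunk. It is an alternative illustration, not the shape of the project's code.

import importlib

# model_type -> liger-kernel patch function, names taken from the imports in this hunk
_LIGER_PATCHES = {
    "glm4": "apply_liger_kernel_to_glm4",
    "granite": "apply_liger_kernel_to_granite",
    "llama": "apply_liger_kernel_to_llama",
    "llava": "apply_liger_kernel_to_llava",
    "olmo2": "apply_liger_kernel_to_olmo2",
    "qwen3": "apply_liger_kernel_to_qwen3",
}


def resolve_liger_patch(model_type: str):
    """Return the patch function for model_type, or None when unsupported."""
    name = _LIGER_PATCHES.get(model_type)
    if name is None:
        return None
    return getattr(importlib.import_module("liger_kernel.transformers"), name)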
...@@ -99,8 +99,10 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None: ...@@ -99,8 +99,10 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.moe_aux_loss_coef:
return
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
if model_args.moe_aux_loss_coef is not None:
if model_type in [ if model_type in [
"dbrx", "dbrx",
"granitemoe", "granitemoe",
...@@ -113,7 +115,7 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t ...@@ -113,7 +115,7 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
"qwen2_moe", "qwen2_moe",
"qwen3_moe", "qwen3_moe",
]: ]:
setattr(config, "output_router_logits", is_trainable) setattr(config, "output_router_logits", True)
if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]: if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef) setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
......
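After the guard clause introduced above, router logits and the aux-loss coefficient are only touched when the run is trainable and a coefficient is set; a small self-contained sketch of that behaviour, using SimpleNamespace in place of PretrainedConfig.

from types import SimpleNamespace


def configure_moe_sketch(config, moe_aux_loss_coef, is_trainable):
    # mirrors the early return added in this hunk
    if not is_trainable or not moe_aux_loss_coef:
        return
    config.output_router_logits = True
    config.router_aux_loss_coef = moe_aux_loss_coef


cfg = SimpleNamespace(model_type="qwen2_moe")
configure_moe_sketch(cfg, moe_aux_loss_coef=0.001, is_trainable=True)
print(cfg.output_router_logits, cfg.router_aux_loss_coef)  # True 0.001
# with is_trainable=False or a zero coefficient, the config is left untouched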
...@@ -97,7 +97,7 @@ def configure_quantization( ...@@ -97,7 +97,7 @@ def configure_quantization(
quant_method = quantization_config.get("quant_method", "") quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ: if quant_method == QuantizationMethod.GPTQ:
check_version("auto_gptq>=0.5.0", mandatory=True) check_version("gptqmodel>=2.0.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama quantization_config["use_exllama"] = False # disable exllama
...@@ -111,12 +111,12 @@ def configure_quantization( ...@@ -111,12 +111,12 @@ def configure_quantization(
quant_bits = quantization_config.get("bits", "?") quant_bits = quantization_config.get("bits", "?")
logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.") logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq elif model_args.export_quantization_bit is not None: # gptqmodel
if model_args.export_quantization_bit not in [8, 4, 3, 2]: if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.") raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
check_version("optimum>=1.17.0", mandatory=True) check_version("optimum>=1.24.0", mandatory=True)
check_version("auto_gptq>=0.5.0", mandatory=True) check_version("gptqmodel>=2.0.0", mandatory=True)
from accelerate.utils import get_max_memory from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm": if getattr(config, "model_type", None) == "chatglm":
...@@ -142,7 +142,8 @@ def configure_quantization( ...@@ -142,7 +142,8 @@ def configure_quantization(
) )
init_kwargs["device_map"] = "auto" init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory() init_kwargs["max_memory"] = get_max_memory()
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.") model_args.compute_dtype = torch.float16 # force fp16 for gptqmodel
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")
elif model_args.quantization_bit is not None: # on-the-fly elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BNB: if model_args.quantization_method == QuantizationMethod.BNB:
......
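The export path now depends on gptqmodel>=2.0.0 plus optimum>=1.24.0 and forces fp16 compute; for orientation, a hedged sketch of the underlying transformers GPTQ quantization API. The model id and calibration dataset are placeholders, not values taken from this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "meta-llama/Llama-3.2-1B"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)  # 2/3/4/8 bits accepted

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,  # mirrors the forced fp16 compute dtype in this hunk
    quantization_config=gptq_config,
)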