Commit 24534501 authored by mashun1

parallel_tool

parent c4ba4563
......@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
self.generating_args = generating_args.to_dict()
if model_args.adapter_name_or_path is not None:
self.lora_request = True
else:
self.lora_request = False
launch_cmd = [
"python3 -m sglang.launch_server",
......@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
f"--download-dir {model_args.cache_dir}",
"--log-level error",
]
if self.lora_request:
launch_cmd.extend(
[
"--max-loras-per-batch 1",
f"--lora-backend {model_args.sglang_lora_backend}",
f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
"--disable-radix-cache",
]
)
launch_cmd = " ".join(launch_cmd)
logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
try:
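A minimal sketch of the command this block assembles when an adapter is configured (the model path, cache dir, and adapter path below are placeholders, and `triton` is the default `sglang_lora_backend`):

```python
# Sketch only: reconstructs the launch command built above with assumed values.
launch_cmd = [
    "python3 -m sglang.launch_server",
    "--model-path Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    "--download-dir /path/to/cache",  # placeholder cache dir
    "--log-level error",
]
adapter_name_or_path = ["/path/to/lora_adapter"]  # placeholder adapter
if adapter_name_or_path:
    launch_cmd.extend(
        [
            "--max-loras-per-batch 1",
            "--lora-backend triton",
            f"--lora-paths lora0={adapter_name_or_path[0]}",
            "--disable-radix-cache",  # radix cache is disabled alongside LoRA here
        ]
    )
print(" ".join(launch_cmd))
```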
......@@ -147,7 +160,6 @@ class SGLangEngine(BaseEngine):
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
......@@ -200,6 +212,8 @@ class SGLangEngine(BaseEngine):
"sampling_params": sampling_params,
"stream": True,
}
if self.lora_request:
json_data["lora_request"] = ["lora0"]
response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
if response.status_code != 200:
raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
......
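A sketch of the streaming request this code sends once a LoRA adapter is active; the server address, token ids, and sampling params are placeholders, and the SSE parsing loop is an assumption based on SGLang's native `/generate` API:

```python
import json

import requests

base_url = "http://127.0.0.1:30000"  # placeholder server address
json_data = {
    "input_ids": [151644, 872, 198],  # placeholder prompt token ids
    "sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
    "stream": True,
    "lora_request": ["lora0"],  # the adapter name registered via --lora-paths
}
response = requests.post(f"{base_url}/generate", json=json_data, stream=True)
if response.status_code != 200:
    raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")

# Assumed SSE framing: each chunk is "data: {json}", terminated by "data: [DONE]".
for chunk in response.iter_lines(decode_unicode=True):
    if chunk and chunk.startswith("data:") and "[DONE]" not in chunk:
        print(json.loads(chunk[len("data:"):].strip()).get("text", ""))
```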
......@@ -124,7 +124,6 @@ class VllmEngine(BaseEngine):
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"]
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
......
......@@ -73,7 +73,7 @@ def main():
"help": partial(print, USAGE),
}
command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
# launch distributed training
nnodes = os.getenv("NNODES", "1")
......
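The stricter guard matters because `sys.argv` always holds at least the program name, so `len(sys.argv) >= 1` was vacuously true and `pop(1)` could raise `IndexError`:

```python
import sys

# With no subcommand, sys.argv == ["prog"]; pop(1) would raise IndexError
# under the old ">= 1" guard. The "> 1" guard only pops a real subcommand.
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
print(command)
```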
......@@ -169,11 +169,22 @@ def read_cloud_json(cloud_path):
try:
# Try with anonymous access first
fs = setup_fs(cloud_path, anon=True)
return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
except Exception:
# Try again with credentials
fs = setup_fs(cloud_path)
return _read_json_with_fs(fs, cloud_path, lines=cloud_path.endswith(".jsonl"))
if fs.isdir(cloud_path):
files = [x["Key"] for x in fs.listdir(cloud_path)]
else:
files = [cloud_path]
# filter out non-JSON files
files = [file for file in files if file.endswith(".json") or file.endswith(".jsonl")]
if not files:
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}")
data = []
for file in files:
data.extend(_read_json_with_fs(fs, file, lines=file.endswith(".jsonl")))
return data
def _read_json_with_fs(fs, path, lines=True):
......
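A hypothetical call against the updated function (bucket and key are placeholders; it requires the patched module plus cloud credentials or anonymous access via `setup_fs`):

```python
# Directory paths now expand to every .json/.jsonl object inside them,
# and all records are concatenated into a single list.
data = read_cloud_json("s3://my-bucket/datasets/alpaca/")  # placeholder path
print(len(data), data[0] if data else None)
```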
......@@ -168,7 +168,7 @@ def _get_merged_dataset(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
merge: bool = True,
return_dict: bool = False,
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
r"""Return the merged datasets in the standard format."""
if dataset_names is None:
......@@ -181,10 +181,10 @@ def _get_merged_dataset(
datasets[dataset_name] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
if merge:
return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
else:
if return_dict:
return datasets
else:
return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
def _get_dataset_processor(
......@@ -303,7 +303,12 @@ def get_dataset(
with training_args.main_process_first(desc="load dataset"):
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
eval_dataset = _get_merged_dataset(
data_args.eval_dataset, model_args, data_args, training_args, stage, merge=training_args.do_predict
data_args.eval_dataset,
model_args,
data_args,
training_args,
stage,
return_dict=data_args.eval_on_each_dataset,
)
with training_args.main_process_first(desc="pre-process dataset"):
......
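The behavioral difference in one sketch (dataset objects are stand-ins): with `eval_on_each_dataset` enabled, callers receive a name-keyed dict and can report metrics per dataset instead of on one merged split.

```python
datasets = {"alpaca_en": ["ex0"], "alpaca_zh": ["ex1"]}  # name -> Dataset stand-ins
return_dict = True  # mirrors data_args.eval_on_each_dataset
if return_dict:
    eval_dataset = datasets  # dict: evaluate each dataset separately
else:
    eval_dataset = sum(datasets.values(), [])  # stand-in for merge_dataset(...)
print(eval_dataset)
```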
......@@ -115,12 +115,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
dataset_list: list[DatasetAttr] = []
for name in dataset_names:
if dataset_info is None: # dataset_dir is ONLINE
if use_modelscope():
load_from = "ms_hub"
elif use_openmind():
load_from = "om_hub"
else:
load_from = "hf_hub"
load_from = "ms_hub" if use_modelscope() else "om_hub" if use_openmind() else "hf_hub"
dataset_attr = DatasetAttr(load_from, dataset_name=name)
dataset_list.append(dataset_attr)
continue
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
......@@ -51,6 +52,7 @@ class Template:
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
enable_thinking: Optional[bool]
mm_plugin: "BasePlugin"
def encode_oneturn(
......@@ -61,7 +63,7 @@ class Template:
tools: Optional[str] = None,
) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=True)
encoded_messages = self._encode(tokenizer, messages, system, tools)
prompt_ids = []
for encoded_ids in encoded_messages[:-1]:
prompt_ids += encoded_ids
......@@ -77,7 +79,7 @@ class Template:
tools: Optional[str] = None,
) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools, remove_thought=False)
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
......@@ -92,6 +94,19 @@ class Template:
return list(stop_token_ids)
def add_thought(self, content: str = "") -> str:
r"""Add empty thought to assistant message."""
return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
def remove_thought(self, content: str) -> str:
r"""Remove thought from assistant message."""
pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
return re.sub(pattern, "", content).lstrip("\n")
def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r"""Get the token ids of thought words."""
return tokenizer.encode(self.add_thought(), add_special_tokens=False)
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r"""Convert elements to token ids."""
token_ids = []
......@@ -111,18 +126,12 @@ class Template:
return token_ids
def _remove_thought(self, content: str) -> str:
r"""Remove thought from assistant message."""
pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
return re.sub(pattern, "", content).lstrip("\n")
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: list[dict[str, str]],
system: Optional[str],
tools: Optional[str],
remove_thought: bool,
) -> list[list[int]]:
r"""Encode formatted inputs to pairs of token ids.
......@@ -140,18 +149,14 @@ class Template:
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
content = message["content"]
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
content = self._remove_thought(content)
if message["role"] == Role.USER:
elements += self.format_user.apply(content=content, idx=str(i // 2))
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=content)
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=content)
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=content)
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
......@@ -162,6 +167,9 @@ class Template:
@staticmethod
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
r"""Add or replace eos token to the tokenizer."""
if tokenizer.eos_token == eos_token:
return
is_added = tokenizer.eos_token_id is None
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
......@@ -328,7 +336,6 @@ class Llama2Template(Template):
messages: list[dict[str, str]],
system: str,
tools: str,
remove_thought: bool,
) -> list[list[int]]:
system = system or self.default_system
encoded_messages = []
......@@ -342,18 +349,14 @@ class Llama2Template(Template):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0]
content = message["content"]
if remove_thought and message["role"] == Role.ASSISTANT and (i != len(messages) - 1):
content = self._remove_thought(content)
if message["role"] == Role.USER:
elements += self.format_user.apply(content=system_text + content)
elements += self.format_user.apply(content=system_text + message["content"])
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=content)
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=content)
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
elements += self.format_function.apply(content=content)
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
......@@ -392,6 +395,64 @@ class Llama2Template(Template):
return jinja_template
@dataclass
class ReasoningTemplate(Template):
r"""A template that add thought to assistant message."""
@override
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> tuple[list[int], list[int]]:
messages = deepcopy(messages)
for i in range(1, len(messages) - 2, 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"])
if self.enable_thinking is False: # remove all cot
messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
if (
self.thought_words[0] not in messages[-1]["content"]
and self.thought_words[1] not in messages[-1]["content"]
): # add empty cot
if not self.enable_thinking: # do not compute loss
prompt_ids += self.get_thought_word_ids(tokenizer)
else: # do compute loss
response_ids = self.get_thought_word_ids(tokenizer) + response_ids
return prompt_ids, response_ids
@override
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
) -> list[tuple[list[int], list[int]]]:
messages = deepcopy(messages)
if self.enable_thinking is False: # remove all cot
for i in range(1, len(messages), 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"])
encoded_messages = self._encode(tokenizer, messages, system, tools)
for i in range(0, len(messages), 2):
if (
self.thought_words[0] not in messages[i + 1]["content"]
and self.thought_words[1] not in messages[i + 1]["content"]
): # add empty cot
if not self.enable_thinking: # do not compute loss
encoded_messages[i] += self.get_thought_word_ids(tokenizer)
else: # do compute loss
encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
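A standalone sketch of the thought handling above at string level, using Qwen3-style markers (`<think>`/`</think>` are an assumption for illustration; in the repo they come from `template.thought_words`):

```python
import re

thought_words = ("<think>", "</think>")  # assumed markers

def remove_thought(content: str) -> str:
    # Mirrors Template.remove_thought: strip <think>...</think> spans (non-greedy).
    pattern = re.compile(f"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
    return re.sub(pattern, "", content).lstrip("\n")

def add_thought(content: str = "") -> str:
    # Mirrors Template.add_thought: prepend an empty thought block.
    return f"{thought_words[0]}\n\n{thought_words[1]}\n\n" + content

history = "<think>\nLet me reason.\n</think>\n\nThe answer is 4."
print(remove_thought(history))  # -> "The answer is 4."
print(add_thought("The answer is 4."))  # well-formed empty CoT prepended
```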
TEMPLATES: dict[str, "Template"] = {}
......@@ -410,6 +471,7 @@ def register_template(
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
enable_thinking: Optional[bool] = True,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: type["Template"] = Template,
) -> None:
......@@ -456,6 +518,7 @@ def register_template(
efficient_eos=efficient_eos,
replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
enable_thinking=enable_thinking,
mm_plugin=mm_plugin,
)
......@@ -492,6 +555,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n") # remove thought tags
if len(user_slot) > len(user_slot_empty_system):
......@@ -501,7 +565,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
else: # if default_system is empty, user_slot_empty_system will be longer than user_slot
default_system = ""
return Template(
return template_class(
format_user=StringFormatter(slots=[user_slot]),
format_assistant=StringFormatter(slots=[assistant_slot]),
format_system=StringFormatter(slots=[system_slot]),
......@@ -515,6 +579,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
efficient_eos=False,
replace_eos=False,
replace_jinja_template=False,
enable_thinking=True,
mm_plugin=get_mm_plugin(name="base"),
)
......@@ -543,6 +608,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
if data_args.default_system is not None:
logger.info_rank0(f"Using default system message: {data_args.default_system}.")
template.default_system = data_args.default_system
template.enable_thinking = data_args.enable_thinking
template.fix_special_tokens(tokenizer)
template.fix_jinja_template(tokenizer)
return template
......@@ -756,6 +826,7 @@ register_template(
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
......@@ -774,6 +845,15 @@ register_template(
)
# copied from deepseek3 template
register_template(
name="deepseekr1",
format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
template_class=ReasoningTemplate,
)
register_template(
name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
......@@ -838,6 +918,7 @@ register_template(
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
replace_eos=True,
template_class=Llama2Template,
)
......@@ -853,6 +934,7 @@ register_template(
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
replace_eos=True,
mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
template_class=Llama2Template,
)
......@@ -872,6 +954,22 @@ register_template(
)
# copied from glm4 template
register_template(
name="glmz1",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True,
template_class=ReasoningTemplate,
)
register_template(
name="granite3",
format_user=StringFormatter(
......@@ -1018,6 +1116,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
)
......@@ -1037,6 +1136,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot|>", "<|eom|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
)
......@@ -1066,6 +1166,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)
......@@ -1079,6 +1180,7 @@ register_template(
format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
default_system="You are a helpful assistant provided by Moonshot-AI.",
stop_words=["<|im_end|>"],
replace_eos=True,
)
......@@ -1131,6 +1233,7 @@ register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
......@@ -1163,6 +1266,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
......@@ -1233,6 +1337,24 @@ register_template(
)
# copied from qwen template
register_template(
name="mimo",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
template_class=ReasoningTemplate,
)
# copied from chatml template
register_template(
name="minicpm_v",
......@@ -1363,6 +1485,7 @@ register_template(
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
template_class=Llama2Template,
)
......@@ -1374,6 +1497,7 @@ register_template(
format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
stop_words=["<|end|>"],
replace_eos=True,
)
......@@ -1384,6 +1508,7 @@ register_template(
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
stop_words=["<|end|>"],
replace_eos=True,
)
......@@ -1395,6 +1520,7 @@ register_template(
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
stop_words=["<|im_end|>"],
replace_eos=True,
)
......@@ -1425,6 +1551,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
)
......@@ -1440,6 +1567,8 @@ register_template(
),
format_tools=ToolFormatter(tool_format="qwen"),
stop_words=["<|im_end|>"],
replace_eos=True,
template_class=ReasoningTemplate,
)
......@@ -1451,6 +1580,7 @@ register_template(
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
)
......@@ -1468,6 +1598,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
),
......@@ -1486,6 +1617,7 @@ register_template(
format_tools=ToolFormatter(tool_format="qwen"),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
......@@ -1503,6 +1635,20 @@ register_template(
)
register_template(
name="seed_coder",
format_user=StringFormatter(
slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
),
format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
default_system=(
"You are an AI programming assistant, utilizing the Seed-Coder model, developed by ByteDance Seed, "
"and you only answer questions related to computer science. For politically sensitive questions, "
"security and privacy issues, and other non-computer science questions, you will refuse to answer.\n\n"
),
)
# copied from llama3 template
register_template(
name="skywork_o1",
......
......@@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils):
tool_text = ""
tool_names = []
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
param_text = ""
for name, param in tool["parameters"]["properties"].items():
required, enum, items = "", "", ""
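The new first line lets both OpenAI-style wrapped schemas and bare schemas flow through the formatter; a sketch:

```python
bare = {"name": "get_weather", "parameters": {"properties": {}}}
wrapped = {"type": "function", "function": bare}  # OpenAI-style wrapper
for tool in (bare, wrapped):
    tool = tool.get("function", "") if tool.get("type") == "function" else tool
    print(tool["name"])  # "get_weather" in both cases
```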
......@@ -124,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
return function_text
return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
@override
@staticmethod
......@@ -159,6 +156,7 @@ class GLM4ToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool = tool.get("function", "") if tool.get("type") == "function" else tool
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
)
......@@ -200,7 +198,7 @@ class Llama3ToolUtils(ToolUtils):
date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
......@@ -208,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
try:
tool = json.loads(content.strip())
tools = json.loads(content.strip())
except json.JSONDecodeError:
return content
if "name" not in tool or "parameters" not in tool:
tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
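A sketch of the reworked Llama-3 formatter, which now serializes one call as a single JSON object and several calls as a JSON list (the repo's `FunctionCall` NamedTuple is mimicked locally):

```python
import json
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
functions = [
    FunctionCall("get_weather", '{"city": "Beijing"}'),
    FunctionCall("get_time", '{"tz": "UTC"}'),
]
function_objects = [{"name": n, "parameters": json.loads(a)} for n, a in functions]
# A single call stays an object; parallel calls become a list, which is
# what the updated tool_extractor above accepts on the way back in.
print(json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False))
```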
class MistralToolUtils(ToolUtils):
r"""Mistral v0.3 tool using template."""
......@@ -235,18 +232,16 @@ class MistralToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str:
wrapped_tools = []
for tool in tools:
wrapped_tools.append({"type": "function", "function": tool})
wrapped_tools.append(tool if tool.get("type") == "function" else {"type": "function", "function": tool})
return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return "[" + ", ".join(function_texts) + "]"
return json.dumps(
[{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
)
@override
@staticmethod
......@@ -256,17 +251,11 @@ class MistralToolUtils(ToolUtils):
except json.JSONDecodeError:
return content
if not isinstance(tools, list):
tools = [tools]
results = []
for tool in tools:
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
tools = [tools] if not isinstance(tools, list) else tools
try:
return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
except KeyError:
return content
class QwenToolUtils(ToolUtils):
......@@ -277,7 +266,7 @@ class QwenToolUtils(ToolUtils):
def tool_formatter(tools: list[dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
......@@ -285,13 +274,11 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: list["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return "\n".join(function_texts)
function_texts = [
json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
for name, arguments in functions
]
return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
@override
@staticmethod
......
......@@ -533,6 +533,17 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
},
"DeepSeek-V3-671B-0324-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-0324",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-0324",
},
},
template="deepseek3",
)
register_model_group(
models={
"DeepSeek-R1-1.5B-Distill": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
......@@ -566,7 +577,7 @@ register_model_group(
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
},
},
template="deepseek3",
template="deepseekr1",
)
......@@ -737,6 +748,13 @@ register_model_group(
DownloadSource.DEFAULT: "THUDM/GLM-4-32B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4-32B-0414",
},
},
template="glm4",
)
register_model_group(
models={
"GLM-Z1-9B-0414-Chat": {
DownloadSource.DEFAULT: "THUDM/GLM-Z1-9B-0414",
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-9B-0414",
......@@ -746,7 +764,7 @@ register_model_group(
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-Z1-32B-0414",
},
},
template="glm4",
template="glmz1",
)
......@@ -869,12 +887,13 @@ register_model_group(
register_model_group(
models={
"Granite-3.2-1B-A400M-Base": {
"Granite-Vision-3.2-2B": {
DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
},
},
template="granite3_vision",
multimodal=True,
)
......@@ -1398,6 +1417,29 @@ register_model_group(
)
register_model_group(
models={
"MiMo-7B-Base": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-Base",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-Base",
},
"MiMo-7B-Instruct": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-SFT",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-SFT",
},
"MiMo-7B-Instruct-RL": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL",
},
"MiMo-7B-RL-ZERO": {
DownloadSource.DEFAULT: "XiaomiMiMo/MiMo-7B-RL-ZERO",
DownloadSource.MODELSCOPE: "XiaomiMiMo/MiMo-7B-RL-ZERO",
},
},
template="mimo",
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
......@@ -2461,6 +2503,38 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B",
},
"Qwen3-0.6B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-0.6B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-0.6B-GPTQ-Int8",
},
"Qwen3-1.7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen3-1.7B-GPTQ-Int8",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-1.7B-GPTQ-Int8",
},
"Qwen3-4B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-4B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-4B-AWQ",
},
"Qwen3-8B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-8B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-8B-AWQ",
},
"Qwen3-14B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-14B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-14B-AWQ",
},
"Qwen3-32B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen3-32B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-32B-AWQ",
},
"Qwen3-30B-A3B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-30B-A3B-GPTQ-Int4",
},
"Qwen3-235B-A22B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen3-235B-A22B-GPTQ-Int4",
},
},
template="qwen3",
)
......@@ -2484,10 +2558,22 @@ register_model_group(
register_model_group(
models={
"Qwen2.5-Omni-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-3B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-3B",
},
"Qwen2.5-Omni-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
}
},
"Qwen2.5-Omni-7B-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
},
"Qwen2.5-Omni-7B-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
},
},
template="qwen2_omni",
multimodal=True,
......@@ -2598,15 +2684,17 @@ register_model_group(
register_model_group(
models={
"SOLAR-10.7B-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
"Seed-Coder-8B-Base": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Base",
},
"SOLAR-10.7B-Instruct-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
"Seed-Coder-8B-Instruct": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Instruct",
},
"Seed-Coder-8B-Instruct-Reasoning": {
DownloadSource.DEFAULT: "ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16",
},
},
template="solar",
template="seed_coder",
)
......@@ -2631,6 +2719,20 @@ register_model_group(
)
register_model_group(
models={
"SOLAR-10.7B-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
},
"SOLAR-10.7B-Instruct-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
},
},
template="solar",
)
register_model_group(
models={
"StarCoder2-3B": {
......
......@@ -79,20 +79,27 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return
if "gptmodel" in requirement or "autoawq" in requirement:
pip_command = f"pip install {requirement} --no-build-isolation"
else:
pip_command = f"pip install {requirement}"
if mandatory:
hint = f"To fix: run `pip install {requirement}`."
hint = f"To fix: run `{pip_command}`."
else:
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
require_version(requirement, hint)
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.5.0")
check_version("accelerate>=0.34.0,<=1.6.0")
check_version("peft>=0.14.0,<=0.15.1")
check_version(
"transformers>=4.45.0,<=4.52.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
)
check_version("datasets>=2.16.0,<=3.6.0")
check_version("accelerate>=0.34.0,<=1.7.0")
check_version("peft>=0.14.0,<=0.15.2")
check_version("trl>=0.8.6,<=0.9.6")
if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
......
......@@ -99,6 +99,10 @@ class DataArguments:
default=0.0,
metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
)
eval_on_each_dataset: bool = field(
default=False,
metadata={"help": "Whether or not to evaluate on each dataset separately."},
)
packing: Optional[bool] = field(
default=None,
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
......@@ -111,6 +115,14 @@ class DataArguments:
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
)
default_system: Optional[str] = field(
default=None,
metadata={"help": "Override the default system message in the template."},
)
enable_thinking: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
)
tokenized_path: Optional[str] = field(
default=None,
metadata={
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import GenerationConfig
......@@ -62,10 +62,6 @@ class GeneratingArguments:
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
)
default_system: Optional[str] = field(
default=None,
metadata={"help": "Default system message to use in chat completion."},
)
skip_special_tokens: bool = field(
default=True,
metadata={"help": "Whether or not to remove special tokens in the decoding."},
......
......@@ -235,10 +235,6 @@ class ProcessorArguments:
default=False,
metadata={"help": "Whether to crop the image to patches for internvl."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
video_max_pixels: int = field(
default=256 * 256,
metadata={"help": "The maximum number of pixels of video inputs."},
......@@ -255,6 +251,10 @@ class ProcessorArguments:
default=128,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},
)
audio_sampling_rate: int = field(
default=16000,
metadata={"help": "The sampling rate of audio inputs."},
......@@ -364,6 +364,12 @@ class SGLangArguments:
default=None,
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
)
sglang_lora_backend: Literal["triton", "flashinfer"] = field(
default="triton",
metadata={
"help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
},
)
def __post_init__(self):
if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
......
......@@ -148,10 +148,10 @@ def _check_extra_dependencies(
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == EngineName.VLLM:
check_version("vllm>=0.4.3,<=0.8.4")
check_version("vllm>=0.4.3,<=0.8.6")
check_version("vllm", mandatory=True)
elif model_args.infer_backend == EngineName.SGLANG:
check_version("sglang>=0.4.4")
check_version("sglang>=0.4.5")
check_version("sglang", mandatory=True)
if finetuning_args.use_galore:
......
......@@ -64,6 +64,7 @@ class RayArguments:
raise ValueError(
f"ray_storage_filesystem must be one of ['s3', 'gs', 'gcs'], got {self.ray_storage_filesystem}"
)
import pyarrow.fs as fs
if self.ray_storage_filesystem == "s3":
......
......@@ -29,10 +29,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if getattr(config, "model_type", None) == "gemma2":
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available():
if model_args.flash_attn != AttentionFunction.FA2:
......
......@@ -45,16 +45,24 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "glm4":
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "granite":
from liger_kernel.transformers import apply_liger_kernel_to_granite as apply_liger_kernel
elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "llava":
from liger_kernel.transformers import apply_liger_kernel_to_llava as apply_liger_kernel
elif model_type == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "mllama":
from liger_kernel.transformers import apply_liger_kernel_to_mllama as apply_liger_kernel
elif model_type == "olmo2":
from liger_kernel.transformers import apply_liger_kernel_to_olmo2 as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2":
......@@ -63,6 +71,8 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
elif model_type == "qwen2_5_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
elif model_type == "qwen3":
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
else:
logger.warning_rank0("Current model does not support liger kernel.")
return
......
......@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.moe_aux_loss_coef:
return
model_type = getattr(config, "model_type", None)
if model_args.moe_aux_loss_coef is not None:
if model_type in [
"dbrx",
"granitemoe",
"jamba",
"jetmoe",
"llama4",
"mixtral",
"olmoe",
"phimoe",
"qwen2_moe",
"qwen3_moe",
]:
setattr(config, "output_router_logits", is_trainable)
if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
if model_type in [
"dbrx",
"granitemoe",
"jamba",
"jetmoe",
"llama4",
"mixtral",
"olmoe",
"phimoe",
"qwen2_moe",
"qwen3_moe",
]:
setattr(config, "output_router_logits", True)
if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif model_type == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif model_type == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
......@@ -97,7 +97,7 @@ def configure_quantization(
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
check_version("auto_gptq>=0.5.0", mandatory=True)
check_version("gptqmodel>=2.0.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
......@@ -111,12 +111,12 @@ def configure_quantization(
quant_bits = quantization_config.get("bits", "?")
logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq
elif model_args.export_quantization_bit is not None: # gptqmodel
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
check_version("optimum>=1.17.0", mandatory=True)
check_version("auto_gptq>=0.5.0", mandatory=True)
check_version("optimum>=1.24.0", mandatory=True)
check_version("gptqmodel>=2.0.0", mandatory=True)
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
......@@ -142,7 +142,8 @@ def configure_quantization(
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
model_args.compute_dtype = torch.float16 # force fp16 for gptqmodel
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BNB:
......