"vscode:/vscode.git/clone" did not exist on "a87b33db263049991999209f986364a3445c40db"
Commit 27a7ad86 authored by luopl's avatar luopl
Browse files

update to v0.9.1

parent 731cf9b8
@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.logging import get_logger
 from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen


 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
@@ -34,27 +36,24 @@ def _encode_unsupervised_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
 ) -> Tuple[List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
     if len(response) == 1:
         messages = prompt + response
     else:
         messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]

+    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
     input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
+    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
     source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
     input_ids = input_ids[:source_len]
     labels = labels[:target_len]
@@ -67,24 +66,21 @@ def preprocess_unsupervised_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["token_type_ids"] = []
-
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    model_inputs = defaultdict(list)
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1:
+            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
             continue

         input_ids, labels = _encode_unsupervised_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -93,10 +89,8 @@ def preprocess_unsupervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])

     return model_inputs
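The two mm_plugin calls above define the whole contract this processor relies on. A minimal pass-through sketch inferred only from those call sites (the real BasePlugin in ..mm_plugin may differ in details):

from typing import Any, Dict, List, Optional, Sequence, Tuple


class PassThroughPlugin:
    """Text-only plugin sketch; multimodal plugins override both hooks."""

    def process_messages(
        self, messages: Sequence[Dict[str, str]], images: Sequence[Any], videos: Sequence[Any], processor: Optional[Any]
    ) -> List[Dict[str, str]]:
        # A multimodal plugin would expand image/video placeholders inside the
        # message contents here; the text-only case returns messages unchanged.
        return list(messages)

    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence[Any],
        videos: Sequence[Any],
        tokenizer: Any,
        processor: Optional[Any],
    ) -> Tuple[List[int], Optional[List[int]]]:
        # A PaliGemma-style plugin would prepend image tokens to input_ids
        # here; the text-only case is a no-op.
        return input_ids, labels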
...
@@ -15,15 +15,21 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

+from transformers.utils.versions import require_version
+from typing_extensions import override
+
 from ..extras.logging import get_logger
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+from .mm_plugin import get_mm_plugin


 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer

+    from ..hparams import DataArguments
     from .formatter import SLOTS, Formatter
+    from .mm_plugin import BasePlugin


 logger = get_logger(__name__)
@@ -41,9 +47,10 @@ class Template:
     format_prefix: "Formatter"
     default_system: str
     stop_words: List[str]
-    image_token: str
     efficient_eos: bool
     replace_eos: bool
+    replace_jinja_template: bool
+    mm_plugin: "BasePlugin"

     def encode_oneturn(
         self,
@@ -147,6 +154,7 @@ class Template:

 @dataclass
 class Llama2Template(Template):
+    @override
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
@@ -190,7 +198,7 @@ class Llama2Template(Template):
         return encoded_messages


-TEMPLATES: Dict[str, Template] = {}
+TEMPLATES: Dict[str, "Template"] = {}


 def _register_template(
@@ -205,9 +213,10 @@ def _register_template(
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
     stop_words: Sequence[str] = [],
-    image_token: str = "<image>",
     efficient_eos: bool = False,
     replace_eos: bool = False,
+    replace_jinja_template: bool = True,
+    mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
 ) -> None:
     r"""
     Registers a chat template.
@@ -254,9 +263,10 @@ def _register_template(
         format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
         stop_words=stop_words,
-        image_token=image_token,
        efficient_eos=efficient_eos,
         replace_eos=replace_eos,
+        replace_jinja_template=replace_jinja_template,
+        mm_plugin=mm_plugin,
     )
@@ -300,6 +310,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
 def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
+    r"""
+    Returns the jinja template.
+    """
     jinja_template = ""

     prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
@@ -339,23 +352,29 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     return jinja_template


-def get_template_and_fix_tokenizer(
-    tokenizer: "PreTrainedTokenizer",
-    name: Optional[str] = None,
-    tool_format: Optional[str] = None,
-) -> Template:
-    if name is None:
+def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+    r"""
+    Gets chat template and fixes the tokenizer.
+    """
+    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+        require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
+
+    if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
-        template = TEMPLATES.get(name, None)
+        template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(name))
+            raise ValueError("Template {} does not exist.".format(data_args.template))
+
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")

-    if tool_format is not None:
-        logger.info("Using tool format: {}.".format(tool_format))
+    if data_args.tool_format is not None:
+        logger.info("Using tool format: {}.".format(data_args.tool_format))
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_tools = ToolFormatter(tool_format=tool_format)
-        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
+        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
+        template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

     stop_words = template.stop_words
     if template.replace_eos:
@@ -380,10 +399,11 @@ def get_template_and_fix_tokenizer(
     if num_added_tokens > 0:
         logger.warning("New tokens have been added, make sure `resize_vocab` is True.")

-    try:
-        tokenizer.chat_template = _get_jinja_template(template, tokenizer)
-    except ValueError:
-        logger.info("Cannot add this chat template to tokenizer.")
+    if template.replace_jinja_template:
+        try:
+            tokenizer.chat_template = _get_jinja_template(template, tokenizer)
+        except ValueError:
+            logger.info("Cannot add this chat template to tokenizer.")

     return template
@@ -550,6 +570,15 @@ _register_template(
 )


+_register_template(
+    name="cpm3",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|im_end|>"],
+)
+
+
 _register_template(
     name="dbrx",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -613,6 +642,14 @@ _register_template(
 )


+_register_template(
+    name="exaone",
+    format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
+    format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+)
+
+
 _register_template(
     name="falcon",
     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -637,6 +674,7 @@ _register_template(
     format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     efficient_eos=True,
+    replace_jinja_template=False,
 )
@@ -713,6 +751,119 @@ _register_template(
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>"],
     replace_eos=True,
+    replace_jinja_template=False,
+)
+
+
+_register_template(
+    name="llava",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_llama3",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_observation=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot_id|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_mistral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_qwen",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_yi",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_video",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+)
+
+
+_register_template(
+    name="llava_next_video_mistral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+)
+
+
+_register_template(
+    name="llava_next_video_yi",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
 )
@@ -760,6 +911,19 @@ _register_template(
 )


+_register_template(
+    name="paligemma",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_observation=StringFormatter(
+        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+    ),
+    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+)
+
+
 _register_template(
     name="phi",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -780,6 +944,21 @@ _register_template(
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     replace_eos=True,
+    replace_jinja_template=False,
+)
+
+
+_register_template(
+    name="qwen2_vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 )
@@ -834,6 +1013,17 @@ _register_template(
 )


+_register_template(
+    name="video_llava",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
+)
+
+
 _register_template(
     name="xuanyuan",
     format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
@@ -894,6 +1084,7 @@ _register_template(
     ),
     stop_words=["###"],
     efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
 )
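With the new signature, multimodal behavior and jinja emission become per-template switches. A hypothetical registration using only keywords that appear in this diff ("my_vlm" and the slot strings are made-up placeholders):

# Hypothetical example; every keyword argument exists in the new
# _register_template signature shown above.
_register_template(
    name="my_vlm",  # hypothetical template name
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    stop_words=["</s>"],
    replace_eos=True,
    replace_jinja_template=False,  # keep the tokenizer's own chat template
    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)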
...
@@ -15,9 +15,12 @@
 import json
 import re
 from abc import ABC, abstractmethod
+from collections import namedtuple
 from dataclasses import dataclass
 from typing import Any, Dict, List, Tuple, Union

+from typing_extensions import override
+
 from .data_utils import SLOTS
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
 )


+FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
+
+
 @dataclass
 class ToolUtils(ABC):
+    """
+    Base class for tool utilities.
+    """
+
     @staticmethod
     @abstractmethod
-    def get_function_slots() -> SLOTS: ...
+    def get_function_slots() -> SLOTS:
+        r"""
+        Gets a list of slots corresponding to a single function call.
+        """
+        ...

     @staticmethod
     @abstractmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        r"""
+        Generates the system message describing all the available tools.
+        """
+        ...

     @staticmethod
     @abstractmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extracts all the function calls from the response message.
+        """
+        ...


 class DefaultToolUtils(ToolUtils):
+    @override
     @staticmethod
     def get_function_slots() -> SLOTS:
         return ["Action: {{name}}\nAction Input: {{arguments}}\n"]

+    @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
         tool_text = ""
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):

         return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))

+    @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
         regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
         action_match: List[Tuple[str, str]] = re.findall(regex, content)
         if not action_match:
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):

 class GLM4ToolUtils(ToolUtils):
+    @override
     @staticmethod
     def get_function_slots() -> SLOTS:
         return ["{{name}}\n{{arguments}}"]

+    @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
         tool_text = ""
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):

         return GLM4_TOOL_PROMPT.format(tool_text=tool_text)

+    @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
         if "\n" not in content:
             return content
@@ -138,3 +166,17 @@ class GLM4ToolUtils(ToolUtils):
             return content

         return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
+
+
+TOOLS = {
+    "default": DefaultToolUtils(),
+    "glm4": GLM4ToolUtils(),
+}
+
+
+def get_tool_utils(name: str) -> "ToolUtils":
+    tool_utils = TOOLS.get(name, None)
+    if tool_utils is None:
+        raise ValueError("Tool utils `{}` not found.".format(name))
+
+    return tool_utils
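The registry makes each prompt/extraction pair addressable by name. A minimal usage sketch against the code above ("default" and "glm4" are the two registered names; the reply string is an invented example matching the default Action/Action Input format):

utils = get_tool_utils("default")
reply = 'Action: get_weather\nAction Input: {"city": "Tokyo"}'
calls = utils.tool_extractor(reply)
if isinstance(calls, str):
    print("plain text reply:", calls)  # no tool call detected
else:
    # each entry is a (name, arguments) pair; FunctionCall is a namedtuple
    for name, arguments in calls:
        print(name, arguments)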
@@ -39,7 +39,7 @@
 import json
 import os
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import numpy as np
 import torch
@@ -54,18 +54,22 @@ from ..model import load_model, load_tokenizer
 from .template import get_eval_template


+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
 class Evaluator:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args)
         self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
         self.eval_template = get_eval_template(self.eval_args.lang)
         self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]

     @torch.inference_mode()
-    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
+    def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
         logits = self.model(**batch_input).logits
         lengths = torch.sum(batch_input["attention_mask"], dim=-1)
         word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@@ -132,7 +136,7 @@ class Evaluator:
         pbar.close()
         self._save_results(category_corrects, results)

-    def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
+    def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
         score_info = "\n".join(
             [
                 "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
...
@@ -47,6 +47,8 @@ FILEEXT2TYPE = {
 IGNORE_INDEX = -100

+IMAGE_PLACEHOLDER = "<image>"
+
 LAYERNORM_NAMES = {"norm", "ln"}

 LLAMABOARD_CONFIG = "llamaboard_config.yaml"
@@ -93,6 +95,8 @@ SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

+VIDEO_PLACEHOLDER = "<video>"
+
 V_HEAD_WEIGHTS_NAME = "value_head.bin"

 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
@@ -110,17 +114,12 @@ def register_model_group(
     template: Optional[str] = None,
     vision: bool = False,
 ) -> None:
-    prefix = None
     for name, path in models.items():
-        if prefix is None:
-            prefix = name.split("-")[0]
-        else:
-            assert prefix == name.split("-")[0], "prefix should be identical."
         SUPPORTED_MODELS[name] = path
-
-    if template is not None:
-        DEFAULT_TEMPLATE[prefix] = template
-
-    if vision:
-        VISION_MODELS.add(prefix)
+        if template is not None and any(suffix in name for suffix in ("-Chat", "-Instruct")):
+            DEFAULT_TEMPLATE[name] = template
+        if vision:
+            VISION_MODELS.add(name)


 register_model_group(
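register_model_group no longer derives a shared name prefix: the default template is recorded per model name, and only for names carrying a "-Chat" or "-Instruct" suffix. An illustration of the behavior change, using the Falcon entries registered later in this file:

# Illustration only; the real Falcon group appears further down in this diff.
register_model_group(
    models={
        "Falcon-7B": {DownloadSource.DEFAULT: "tiiuae/falcon-7b"},
        "Falcon-7B-Instruct": {DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct"},
    },
    template="falcon",
)
# Old behavior: DEFAULT_TEMPLATE["Falcon"] = "falcon" (one entry per shared prefix).
# New behavior: DEFAULT_TEMPLATE["Falcon-7B-Instruct"] = "falcon"; the base model
# gets no default template because its name lacks "-Chat"/"-Instruct".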
...@@ -234,7 +233,7 @@ register_model_group( ...@@ -234,7 +233,7 @@ register_model_group(
"Breeze-7B": { "Breeze-7B": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0", DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
}, },
"Breeze-7B-Chat": { "Breeze-7B-Instruct": {
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0", DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
}, },
}, },
...@@ -270,27 +269,27 @@ register_model_group( ...@@ -270,27 +269,27 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"ChineseLLaMA2-1.3B": { "Chinese-Llama-2-1.3B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
}, },
"ChineseLLaMA2-7B": { "Chinese-Llama-2-7B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
}, },
"ChineseLLaMA2-13B": { "Chinese-Llama-2-13B": {
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b", DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
}, },
"ChineseLLaMA2-1.3B-Chat": { "Chinese-Alpaca-2-1.3B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
}, },
"ChineseLLaMA2-7B-Chat": { "Chinese-Alpaca-2-7B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
}, },
"ChineseLLaMA2-13B-Chat": { "Chinese-Alpaca-2-13B-Chat": {
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b", DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b", DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
}, },
...@@ -315,14 +314,14 @@ register_model_group( ...@@ -315,14 +314,14 @@ register_model_group(
"CodeGemma-7B": { "CodeGemma-7B": {
DownloadSource.DEFAULT: "google/codegemma-7b", DownloadSource.DEFAULT: "google/codegemma-7b",
}, },
"CodeGemma-7B-Chat": { "CodeGemma-7B-Instruct": {
DownloadSource.DEFAULT: "google/codegemma-7b-it", DownloadSource.DEFAULT: "google/codegemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it", DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
}, },
"CodeGemma-1.1-2B": { "CodeGemma-1.1-2B": {
DownloadSource.DEFAULT: "google/codegemma-1.1-2b", DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
}, },
"CodeGemma-1.1-7B-Chat": { "CodeGemma-1.1-7B-Instruct": {
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it", DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
}, },
}, },
...@@ -368,7 +367,7 @@ register_model_group( ...@@ -368,7 +367,7 @@ register_model_group(
DownloadSource.DEFAULT: "databricks/dbrx-base", DownloadSource.DEFAULT: "databricks/dbrx-base",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base", DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
}, },
"DBRX-132B-Chat": { "DBRX-132B-Instruct": {
DownloadSource.DEFAULT: "databricks/dbrx-instruct", DownloadSource.DEFAULT: "databricks/dbrx-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct", DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
}, },
...@@ -399,7 +398,7 @@ register_model_group( ...@@ -399,7 +398,7 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
}, },
"DeepSeek-Math-7B-Chat": { "DeepSeek-Math-7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
}, },
...@@ -407,36 +406,36 @@ register_model_group( ...@@ -407,36 +406,36 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
}, },
"DeepSeek-MoE-16B-v2-Base": { "DeepSeek-MoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
"DeepSeek-V2-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
}, },
"DeepSeek-MoE-236B-Base": { "DeepSeek-V2-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
}, },
"DeepSeek-MoE-16B-Chat": { "DeepSeek-V2-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
"DeepSeek-MoE-16B-v2-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
}, },
"DeepSeek-MoE-236B-Chat": { "DeepSeek-V2-236B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
}, },
"DeepSeek-MoE-Coder-16B-Base": { "DeepSeek-Coder-V2-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
}, },
"DeepSeek-MoE-Coder-236B-Base": { "DeepSeek-Coder-V2-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
}, },
"DeepSeek-MoE-Coder-16B-Chat": { "DeepSeek-Coder-V2-16B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
}, },
"DeepSeek-MoE-Coder-236B-Chat": { "DeepSeek-Coder-V2-236B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
}, },
}, },
...@@ -446,25 +445,25 @@ register_model_group( ...@@ -446,25 +445,25 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"DeepSeekCoder-6.7B-Base": { "DeepSeek-Coder-6.7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
}, },
"DeepSeekCoder-7B-Base": { "DeepSeek-Coder-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
}, },
"DeepSeekCoder-33B-Base": { "DeepSeek-Coder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
}, },
"DeepSeekCoder-6.7B-Chat": { "DeepSeek-Coder-6.7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
}, },
"DeepSeekCoder-7B-Chat": { "DeepSeek-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
}, },
"DeepSeekCoder-33B-Chat": { "DeepSeek-Coder-33B-Instruct": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
}, },
...@@ -473,6 +472,16 @@ register_model_group( ...@@ -473,6 +472,16 @@ register_model_group(
) )
register_model_group(
models={
"EXAONE-3.0-7.8B-Instruct": {
DownloadSource.DEFAULT: "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
},
},
template="exaone",
)
register_model_group( register_model_group(
models={ models={
"Falcon-7B": { "Falcon-7B": {
...@@ -490,11 +499,11 @@ register_model_group( ...@@ -490,11 +499,11 @@ register_model_group(
DownloadSource.DEFAULT: "tiiuae/falcon-180b", DownloadSource.DEFAULT: "tiiuae/falcon-180b",
DownloadSource.MODELSCOPE: "modelscope/falcon-180B", DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
}, },
"Falcon-7B-Chat": { "Falcon-7B-Instruct": {
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct", DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct", DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
}, },
"Falcon-40B-Chat": { "Falcon-40B-Instruct": {
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct", DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct", DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
}, },
...@@ -517,18 +526,18 @@ register_model_group( ...@@ -517,18 +526,18 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-7b", DownloadSource.DEFAULT: "google/gemma-7b",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it", DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
}, },
"Gemma-2B-Chat": { "Gemma-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2b-it", DownloadSource.DEFAULT: "google/gemma-2b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b", DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
}, },
"Gemma-7B-Chat": { "Gemma-7B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-7b-it", DownloadSource.DEFAULT: "google/gemma-7b-it",
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it", DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
}, },
"Gemma-1.1-2B-Chat": { "Gemma-1.1-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-1.1-2b-it", DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
}, },
"Gemma-1.1-7B-Chat": { "Gemma-1.1-7B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it", DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
}, },
"Gemma-2-2B": { "Gemma-2-2B": {
...@@ -543,15 +552,15 @@ register_model_group( ...@@ -543,15 +552,15 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-2-27b", DownloadSource.DEFAULT: "google/gemma-2-27b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
}, },
"Gemma-2-2B-Chat": { "Gemma-2-2B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-2b-it", DownloadSource.DEFAULT: "google/gemma-2-2b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-2b-it",
}, },
"Gemma-2-9B-Chat": { "Gemma-2-9B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it", DownloadSource.DEFAULT: "google/gemma-2-9b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
}, },
"Gemma-2-27B-Chat": { "Gemma-2-27B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-2-27b-it", DownloadSource.DEFAULT: "google/gemma-2-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
}, },
...@@ -620,17 +629,22 @@ register_model_group( ...@@ -620,17 +629,22 @@ register_model_group(
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b", DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
}, },
}, "InternLM2.5-1.8B": {
template="intern2", DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b",
) DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b",
},
register_model_group(
models={
"InternLM2.5-7B": { "InternLM2.5-7B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b", DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b",
}, },
"InternLM2.5-20B": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b",
},
"InternLM2.5-1.8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-1_8b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-1_8b-chat",
},
"InternLM2.5-7B-Chat": { "InternLM2.5-7B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat", DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
...@@ -639,6 +653,10 @@ register_model_group( ...@@ -639,6 +653,10 @@ register_model_group(
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m", DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
}, },
"InternLM2.5-20B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
},
}, },
template="intern2", template="intern2",
) )
...@@ -666,19 +684,19 @@ register_model_group( ...@@ -666,19 +684,19 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaMA-7B": { "Llama-7B": {
DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.DEFAULT: "huggyllama/llama-7b",
DownloadSource.MODELSCOPE: "skyline2006/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
}, },
"LLaMA-13B": { "Llama-13B": {
DownloadSource.DEFAULT: "huggyllama/llama-13b", DownloadSource.DEFAULT: "huggyllama/llama-13b",
DownloadSource.MODELSCOPE: "skyline2006/llama-13b", DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
}, },
"LLaMA-30B": { "Llama-30B": {
DownloadSource.DEFAULT: "huggyllama/llama-30b", DownloadSource.DEFAULT: "huggyllama/llama-30b",
DownloadSource.MODELSCOPE: "skyline2006/llama-30b", DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
}, },
"LLaMA-65B": { "Llama-65B": {
DownloadSource.DEFAULT: "huggyllama/llama-65b", DownloadSource.DEFAULT: "huggyllama/llama-65b",
DownloadSource.MODELSCOPE: "skyline2006/llama-65b", DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
}, },
...@@ -688,27 +706,27 @@ register_model_group( ...@@ -688,27 +706,27 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaMA2-7B": { "Llama-2-7B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
}, },
"LLaMA2-13B": { "Llama-2-13B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
}, },
"LLaMA2-70B": { "Llama-2-70B": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
}, },
"LLaMA2-7B-Chat": { "Llama-2-7B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
}, },
"LLaMA2-13B-Chat": { "Llama-2-13B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
}, },
"LLaMA2-70B-Chat": { "Llama-2-70B-Chat": {
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf", DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms", DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
}, },
...@@ -719,57 +737,76 @@ register_model_group( ...@@ -719,57 +737,76 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaMA3-8B": { "Llama-3-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
}, },
"LLaMA3-70B": { "Llama-3-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
}, },
"LLaMA3-8B-Chat": { "Llama-3-8B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
}, },
"LLaMA3-70B-Chat": { "Llama-3-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
}, },
"LLaMA3-8B-Chinese-Chat": { "Llama-3-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat", DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat", DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
}, },
"LLaMA3-70B-Chinese-Chat": { "Llama-3-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat", DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
}, },
}, "Llama-3.1-8B": {
template="llama3",
)
register_model_group(
models={
"LLaMA3.1-8B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B",
}, },
"LLaMA3.1-70B": { "Llama-3.1-70B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B",
}, },
"LLaMA3.1-405B": { "Llama-3.1-405B": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B",
}, },
"LLaMA3.1-8B-Chat": { "Llama-3.1-8B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-8B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-8B-Instruct",
}, },
"LLaMA3.1-70B-Chat": { "Llama-3.1-70B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-70B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-70B-Instruct",
}, },
"LLaMA3.1-405B-Chat": { "Llama-3.1-405B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct", DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3.1-405B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3.1-405B-Instruct",
},
"Llama-3.1-8B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-8B-Chinese-Chat",
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-8B-Chinese-Chat",
},
"Llama-3.1-70B-Chinese-Chat": {
DownloadSource.DEFAULT: "shenzhi-wang/Llama3.1-70B-Chinese-Chat",
DownloadSource.MODELSCOPE: "XD_AI/Llama3.1-70B-Chinese-Chat",
},
"Llama-3.2-1B": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B",
},
"Llama-3.2-3B": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B",
},
"Llama-3.2-1B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-1B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-1B-Instruct",
},
"Llama-3.2-3B-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
}, },
}, },
template="llama3", template="llama3",
...@@ -778,14 +815,127 @@ register_model_group( ...@@ -778,14 +815,127 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"LLaVA1.5-7B-Chat": { "LLaVA-1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf", DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-1.5-7b-hf",
}, },
"LLaVA1.5-13B-Chat": { "LLaVA-1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf", DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
DownloadSource.MODELSCOPE: "swift/llava-1.5-13b-hf",
}, },
}, },
template="vicuna", template="llava",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-7b-hf",
},
"LLaVA-NeXT-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-vicuna-13b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-vicuna-13b-hf",
},
},
template="llava_next",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Mistral-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-mistral-7b-hf",
DownloadSource.MODELSCOPE: "swift/llava-v1.6-mistral-7b-hf",
},
},
template="llava_next_mistral",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Llama3-8B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llama3-llava-next-8b-hf",
DownloadSource.MODELSCOPE: "swift/llama3-llava-next-8b-hf",
},
},
template="llava_next_llama3",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-34B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-v1.6-34b-hf",
DownloadSource.MODELSCOPE: "LLM-Research/llava-v1.6-34b-hf",
},
},
template="llava_next_yi",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-72B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-next-72b-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-72b-hf",
},
"LLaVA-NeXT-110B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-next-110b-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/llava-next-110b-hf",
},
},
template="llava_next_qwen",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-hf",
},
"LLaVA-NeXT-Video-7B-DPO-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-DPO-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-DPO-hf",
},
},
template="llava_next_video",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-7B-32k-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-7B-32K-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-7B-32K-hf",
},
},
template="llava_next_video_mistral",
vision=True,
)
register_model_group(
models={
"LLaVA-NeXT-Video-34B-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-hf",
DownloadSource.MODELSCOPE: "swift/LLaVA-NeXT-Video-34B-hf",
},
"LLaVA-NeXT-Video-34B-DPO-Chat": {
DownloadSource.DEFAULT: "llava-hf/LLaVA-NeXT-Video-34B-DPO-hf",
},
},
template="llava_next_video_yi",
vision=True, vision=True,
) )
...@@ -805,13 +955,24 @@ register_model_group( ...@@ -805,13 +955,24 @@ register_model_group(
) )
register_model_group(
models={
"MiniCPM3-4B-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B",
},
},
template="cpm3",
)
register_model_group( register_model_group(
models={ models={
"Mistral-7B-v0.1": { "Mistral-7B-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1", DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1", DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
}, },
"Mistral-7B-v0.1-Chat": { "Mistral-7B-Instruct-v0.1": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1", DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1", DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
}, },
...@@ -819,18 +980,18 @@ register_model_group( ...@@ -819,18 +980,18 @@ register_model_group(
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf", DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf", DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
}, },
"Mistral-7B-v0.2-Chat": { "Mistral-7B-Instruct-v0.2": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2", DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2", DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
}, },
"Mistral-7B-v0.3": { "Mistral-7B-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3", DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
}, },
"Mistral-7B-v0.3-Chat": { "Mistral-7B-Instruct-v0.3": {
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3", DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
DownloadSource.MODELSCOPE: "LLM-Research/Mistral-7B-Instruct-v0.3", DownloadSource.MODELSCOPE: "LLM-Research/Mistral-7B-Instruct-v0.3",
}, },
"Mistral-Nemo-Chat": { "Mistral-Nemo-Instruct-2407": {
DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407", DownloadSource.DEFAULT: "mistralai/Mistral-Nemo-Instruct-2407",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407", DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-Nemo-Instruct-2407",
}, },
@@ -845,7 +1006,7 @@ register_model_group(
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
},
"Mixtral-8x7B-v0.1-Instruct": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
},
@@ -853,7 +1014,7 @@ register_model_group(
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
},
"Mixtral-8x22B-v0.1-Instruct": {
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
},
@@ -930,27 +1091,28 @@ register_model_group(
register_model_group(
models={
"PaliGemma-3B-pt-224-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
},
"PaliGemma-3B-pt-448-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
},
"PaliGemma-3B-pt-896-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
},
"PaliGemma-3B-mix-224-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
},
"PaliGemma-3B-mix-448-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
},
},
template="paligemma",
vision=True,
)
@@ -971,27 +1133,27 @@ register_model_group(
register_model_group(
models={
"Phi-3-4B-4k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
},
"Phi-3-4B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
},
"Phi-3-7B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
},
"Phi-3-7B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
},
"Phi-3-14B-8k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
},
"Phi-3-14B-128k-Instruct": {
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
},
@@ -1034,35 +1196,35 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
},
"Qwen-1.8B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
},
"Qwen-1.8B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
},
"Qwen-7B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
},
"Qwen-7B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
},
"Qwen-14B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
},
"Qwen-14B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
},
"Qwen-72B-Chat-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
},
"Qwen-72B-Chat-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
},
@@ -1109,10 +1271,6 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
@@ -1149,71 +1307,75 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
},
"Qwen1.5-0.5B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
},
"Qwen1.5-0.5B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
},
"Qwen1.5-1.8B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
},
"Qwen1.5-1.8B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
},
"Qwen1.5-4B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
},
"Qwen1.5-4B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
},
"Qwen1.5-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
},
"Qwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
},
"Qwen1.5-14B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
},
"Qwen1.5-14B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
},
"Qwen1.5-32B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
},
"Qwen1.5-72B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
},
"Qwen1.5-72B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
},
"Qwen1.5-110B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
},
"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
},
"CodeQwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
},
"CodeQwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
},
"CodeQwen1.5-7B-Chat-AWQ": {
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
},
@@ -1240,90 +1402,106 @@ register_model_group(
DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
},
"Qwen2-MoE-57B-A14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
},
"Qwen2-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
},
"Qwen2-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
},
"Qwen2-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
},
"Qwen2-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
},
"Qwen2-MoE-57B-A14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
},
"Qwen2-0.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
},
"Qwen2-0.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
},
"Qwen2-0.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
},
"Qwen2-1.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
},
"Qwen2-1.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
},
"Qwen2-1.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
},
"Qwen2-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
},
"Qwen2-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int4",
},
"Qwen2-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
},
"Qwen2-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
},
"Qwen2-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int4",
},
"Qwen2-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
},
"Qwen2-57B-A14B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
},
"Qwen2-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B",
},
"Qwen2-Math-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B",
},
"Qwen2-Math-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B",
},
"Qwen2-Math-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B-Instruct",
},
"Qwen2-Math-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B-Instruct",
},
"Qwen2-Math-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B-Instruct",
},
},
template="qwen",
)
@@ -1331,10 +1509,253 @@ register_model_group(
register_model_group(
models={
"Qwen2.5-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B",
},
"Qwen2.5-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B",
},
"Qwen2.5-3B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B",
},
"Qwen2.5-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B",
},
"Qwen2.5-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B",
},
"Qwen2.5-32B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B",
},
"Qwen2.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B",
},
"Qwen2.5-0.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct",
},
"Qwen2.5-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct",
},
"Qwen2.5-3B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct",
},
"Qwen2.5-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct",
},
"Qwen2.5-14B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct",
},
"Qwen2.5-32B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct",
},
"Qwen2.5-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct",
},
"Qwen2.5-0.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
},
"Qwen2.5-0.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
},
"Qwen2.5-0.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct-AWQ",
},
"Qwen2.5-1.5B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
},
"Qwen2.5-1.5B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
},
"Qwen2.5-1.5B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct-AWQ",
},
"Qwen2.5-3B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
},
"Qwen2.5-3B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
},
"Qwen2.5-3B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct-AWQ",
},
"Qwen2.5-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
},
"Qwen2.5-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
},
"Qwen2.5-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct-AWQ",
},
"Qwen2.5-14B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
},
"Qwen2.5-14B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
},
"Qwen2.5-14B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct-AWQ",
},
"Qwen2.5-32B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
},
"Qwen2.5-32B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
},
"Qwen2.5-32B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct-AWQ",
},
"Qwen2.5-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
},
"Qwen2.5-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
},
"Qwen2.5-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-AWQ",
},
"Qwen2.5-Coder-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B",
},
"Qwen2.5-Coder-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B",
},
"Qwen2.5-Coder-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct",
},
"Qwen2.5-Coder-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct",
},
"Qwen2.5-Math-1.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-1.5B",
},
"Qwen2.5-Math-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-7B",
},
"Qwen2.5-Math-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-72B",
},
"Qwen2.5-Math-1.5B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct",
},
"Qwen2.5-Math-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct",
},
"Qwen2.5-Math-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-72B-Instruct",
},
},
template="qwen",
)
register_model_group(
models={
"Qwen2-VL-2B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
},
"Qwen2-VL-7B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
},
"Qwen2-VL-72B-Instruct": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct",
},
"Qwen2-VL-2B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
},
"Qwen2-VL-2B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
},
"Qwen2-VL-2B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
},
"Qwen2-VL-7B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
},
"Qwen2-VL-7B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
},
"Qwen2-VL-7B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
},
"Qwen2-VL-72B-Instruct-GPTQ-Int8": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
},
"Qwen2-VL-72B-Instruct-GPTQ-Int4": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
},
"Qwen2-VL-72B-Instruct-AWQ": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ",
},
},
template="qwen2_vl",
vision=True,
)
register_model_group(
models={
"SOLAR-10.7B-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0", DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
}, },
"SOLAR-10.7B-Chat": { "SOLAR-10.7B-Instruct-v1.0": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0", DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0", DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
}, },
@@ -1396,11 +1817,11 @@ register_model_group(
register_model_group(
models={
"Vicuna-v1.5-7B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
},
"Vicuna-v1.5-13B-Chat": {
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
},
@@ -1409,6 +1830,17 @@ register_model_group(
)
register_model_group(
models={
"Video-LLaVA-7B-Chat": {
DownloadSource.DEFAULT: "LanguageBind/Video-LLaVA-7B-hf",
},
},
template="video_llava",
vision=True,
)
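Downstream, a registered display name resolves to a concrete repository id; a hypothetical lookup helper (not part of the diff) could look like:

def get_model_path(name, use_modelscope=False):
    paths = SUPPORTED_MODELS[name]
    if use_modelscope and DownloadSource.MODELSCOPE in paths:
        return paths[DownloadSource.MODELSCOPE]  # mirror used when Hugging Face is unreachable
    return paths[DownloadSource.DEFAULT]  # Hugging Face Hub id

get_model_path("Video-LLaVA-7B-Chat")  # -> "LanguageBind/Video-LLaVA-7B-hf"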
register_model_group(
models={
"XuanYuan-6B": {
@@ -1419,7 +1851,7 @@ register_model_group(
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
},
"XuanYuan2-70B": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
},
@@ -1431,31 +1863,31 @@ register_model_group(
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
},
"XuanYuan2-70B-Chat": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
},
"XuanYuan-6B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
},
"XuanYuan-6B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
},
"XuanYuan-70B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
},
"XuanYuan-70B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
},
"XuanYuan2-70B-Chat-8bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
},
"XuanYuan2-70B-Chat-4bit": {
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
},
@@ -1498,23 +1930,23 @@ register_model_group(
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
},
"XVERSE-7B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
},
"XVERSE-7B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
},
"XVERSE-13B-Chat-GPTQ-Int8": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
},
"XVERSE-13B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
},
"XVERSE-65B-Chat-GPTQ-Int4": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
},
@@ -1560,19 +1992,19 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
},
"Yi-6B-Chat-8bits": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
},
"Yi-6B-Chat-4bits": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
},
"Yi-34B-Chat-8bits": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
},
"Yi-34B-Chat-4bits": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
},
@@ -1600,6 +2032,22 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
},
"Yi-Coder-1.5B": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-1.5B",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-1.5B",
},
"Yi-Coder-9B": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-9B",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-9B",
},
"Yi-Coder-1.5B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-1.5B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-1.5B-Chat",
},
"Yi-Coder-9B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-Coder-9B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-Coder-9B-Chat",
},
},
template="yi",
)
@@ -1607,10 +2055,10 @@ register_model_group(
register_model_group(
models={
"Yi-VL-6B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
},
"Yi-VL-34B-Chat": {
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
},
},
@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available

VERSION = "0.9.1.dev0"

def print_env() -> None:
# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,14 +18,21 @@
import logging
import os
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

from .constants import RUNNING_LOG
_thread_lock = threading.RLock()
_default_handler: Optional["logging.Handler"] = None
_default_log_level: "logging._Level" = logging.INFO
class LoggerHandler(logging.Handler):
r"""
Redirects the logging output to the logging file for LLaMA Board.
"""

def __init__(self, output_dir: str) -> None:
@@ -56,27 +66,56 @@ class LoggerHandler(logging.Handler):
return super().close()
def _get_default_logging_level() -> "logging._Level":
r"""
Returns the default logging level.
"""
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
if env_level_str:
if env_level_str.upper() in logging._nameToLevel:
return logging._nameToLevel[env_level_str.upper()]
else:
raise ValueError("Unknown logging level: {}.".format(env_level_str))
return _default_log_level

def _get_library_name() -> str:
return __name__.split(".")[0]

def _get_library_root_logger() -> "logging.Logger":
return logging.getLogger(_get_library_name())

def _configure_library_root_logger() -> None:
r"""
Configures root logger using a stdout stream handler with an explicit format.
"""
global _default_handler
with _thread_lock:
if _default_handler:
return
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(formatter)
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False

def get_logger(name: Optional[str] = None) -> "logging.Logger":
r"""
Returns a logger with the specified name. It is not supposed to be accessed externally.
"""
if name is None:
name = _get_library_name()
_configure_library_root_logger()
return logging.getLogger(name)
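A usage sketch (illustrative, not part of the diff): the verbosity of the library loggers can be steered through the `LLAMAFACTORY_VERBOSITY` environment variable, provided it is set before the first `get_logger` call configures the root handler.

import os

os.environ["LLAMAFACTORY_VERBOSITY"] = "WARNING"  # must be set before the root logger is configured

logger = get_logger(__name__)        # child logger inherits the stdout handler and level
logger.warning("visible at WARNING")
logger.info("suppressed at WARNING verbosity")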
@@ -79,9 +79,9 @@ def check_dependencies() -> None:
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
@@ -156,6 +156,18 @@ def get_logits_processor() -> "LogitsProcessorList":
return logits_processor
def get_peak_memory() -> Tuple[int, int]:
r"""
Gets the peak memory usage for the current device (in Bytes).
"""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
else:
return 0, 0
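A small usage sketch (hypothetical helper, not in the diff): the byte counts returned by `get_peak_memory` convert naturally to GiB for logging.

def log_peak_memory() -> None:
    # get_peak_memory() returns (allocated, reserved) in bytes, or (0, 0) without an accelerator
    allocated, reserved = get_peak_memory()
    print(f"peak allocated: {allocated / 1024 ** 3:.2f} GiB, peak reserved: {reserved / 1024 ** 3:.2f} GiB")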
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.
@@ -183,6 +195,9 @@ def is_gpu_or_npu_available() -> bool:
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
r"""
Casts a torch tensor or a numpy array to a numpy array.
"""
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu()
if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4
@@ -194,6 +209,9 @@ def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
def skip_check_imports() -> None:
r"""
Avoids flash attention import error in custom model files.
"""
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
transformers.dynamic_module_utils.check_imports = get_relative_imports
@@ -38,6 +38,10 @@ def _get_package_version(name: str) -> "Version":
return version.parse("0.0.0")
def is_pyav_available():
return _is_package_available("av")
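`_is_package_available` is defined earlier in this file, outside the hunk; a minimal sketch of the usual pattern (assumed implementation) is:

import importlib.util

def _is_package_available(name: str) -> bool:
    # True when the module can be located on the import path without importing it
    return importlib.util.find_spec(name) is not None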
def is_fastapi_available():
return _is_package_available("fastapi")
@@ -81,13 +85,3 @@ def is_uvicorn_available():
def is_vllm_available():
return _is_package_available("vllm")
@@ -70,7 +70,7 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
return fig

def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""
@@ -73,6 +73,10 @@ class DataArguments:
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."},
)
preprocessing_batch_size: int = field(
default=1000,
metadata={"help": "The number of examples in one group in pre-processing."},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the pre-processing."},
)
@@ -15,23 +15,141 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union

import torch
from typing_extensions import Self

@dataclass
class QuantizationArguments:
r"""
Arguments pertaining to the quantization method.
"""
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
)
quantization_type: Literal["fp4", "nf4"] = field(
default="nf4",
metadata={"help": "Quantization data type to use in bitsandbytes int4 training."},
)
double_quantization: bool = field(
default=True,
metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
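A quick instantiation sketch (illustrative values): these argument groups are plain dataclasses, so they can be constructed directly or filled from the command line by `transformers.HfArgumentParser`.

from transformers import HfArgumentParser

args = QuantizationArguments(quantization_bit=4, quantization_type="nf4")
assert args.quantization_method == "bitsandbytes"  # default on-the-fly backend

# the same dataclass can be populated from CLI-style arguments
(cli_args,) = HfArgumentParser(QuantizationArguments).parse_args_into_dataclasses(args=["--quantization_bit", "4"])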
@dataclass
class ProcessorArguments:
r"""
Arguments pertaining to the image processor.
"""
image_resolution: int = field(
default=512,
metadata={"help": "Keeps the height or width of image below this resolution."},
)
video_resolution: int = field(
default=128,
metadata={"help": "Keeps the height or width of video below this resolution."},
)
video_fps: float = field(
default=2.0,
metadata={"help": "The frames to sample per second for video inputs."},
)
video_maxlen: int = field(
default=64,
metadata={"help": "The maximum number of sampled frames for video inputs."},
)
@dataclass
class ExportArguments:
r"""
Arguments pertaining to the model export.
"""
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: int = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: int = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: bool = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
@dataclass
class VllmArguments:
r"""
Arguments pertaining to the vLLM worker.
"""
vllm_maxlen: int = field(
default=2048,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
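The argument groups above are merged into `ModelArguments` below through dataclass inheritance; a minimal sketch of that pattern with generic names (not from the diff) shows how fields from every base land in one flat namespace:

from dataclasses import dataclass

@dataclass
class GroupA:
    x: int = 1

@dataclass
class GroupB:
    y: int = 2

@dataclass
class Combined(GroupA, GroupB):
    z: int = 3

print(Combined())  # Combined(y=2, x=1, z=3): bases are collected in reverse MRO order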
@dataclass
class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, VllmArguments):
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
@@ -77,26 +195,6 @@ class ModelArguments:
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
@@ -117,9 +215,13 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
use_unsloth_gc: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
)
enable_liger_kernel: bool = field(
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
@@ -145,22 +247,6 @@ class ModelArguments:
default="huggingface",
metadata={"help": "Backend engine used at inference."},
)
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
@@ -181,59 +267,38 @@ class ModelArguments:
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
print_param_status: bool = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
init=False,
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
)
device_map: Optional[Union[str, Dict[str, Any]]] = field(
default=None,
init=False,
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
)
model_max_length: Optional[int] = field(
default=None,
init=False,
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
)
block_diag_attn: bool = field(
default=False,
init=False,
metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
)
def __post_init__(self): def __post_init__(self):
self.compute_dtype: Optional["torch.dtype"] = None
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
self.model_max_length: Optional[int] = None
self.block_diag_attn: bool = False
if self.model_name_or_path is None:
    raise ValueError("Please provide `model_name_or_path`.")
if self.split_special_tokens and self.use_fast_tokenizer: if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
...@@ -243,16 +308,18 @@ class ModelArguments: ...@@ -243,16 +308,18 @@ class ModelArguments:
if self.export_quantization_bit is not None and self.export_quantization_dataset is None: if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.") raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@classmethod @classmethod
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
    arg_dict = old_arg.to_dict()
    arg_dict.update(**kwargs)
    new_arg = cls(**arg_dict)
    new_arg.compute_dtype = old_arg.compute_dtype
    new_arg.device_map = old_arg.device_map
    new_arg.model_max_length = old_arg.model_max_length
    new_arg.block_diag_attn = old_arg.block_diag_attn
    return new_arg
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
    init_args, lazy_args = {}, {}
    for attr in fields(source):
        if attr.init:
            init_args[attr.name] = getattr(source, attr.name)
        else:
            lazy_args[attr.name] = getattr(source, attr.name)

    init_args.update(kwargs)
    result = cls(**init_args)
    for name, value in lazy_args.items():
        setattr(result, name, value)

    return result
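A minimal, runnable sketch of the new copy pattern, for illustration only: dataclass fields declared with init=False are skipped by the generated __init__, so copyfrom collects them separately and restores them after construction. The Args class below is a hypothetical stand-in for ModelArguments.

from dataclasses import dataclass, field, fields

@dataclass
class Args:  # hypothetical stand-in for ModelArguments
    name: str = "demo"
    derived: int = field(default=0, init=False)  # excluded from __init__

    @classmethod
    def copyfrom(cls, source: "Args", **kwargs) -> "Args":
        init_args, lazy_args = {}, {}
        for attr in fields(source):  # split by whether __init__ accepts the field
            (init_args if attr.init else lazy_args)[attr.name] = getattr(source, attr.name)
        init_args.update(kwargs)  # keyword arguments override copied values
        result = cls(**init_args)
        for key, value in lazy_args.items():  # restore derived attributes
            setattr(result, key, value)
        return result

old = Args(name="base")
old.derived = 42
new = Args.copyfrom(old, name="override")
assert (new.name, new.derived) == ("override", 42)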
...@@ -26,7 +26,7 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments ...@@ -26,7 +26,7 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ..extras.constants import CHECKPOINT_NAMES from ..extras.constants import CHECKPOINT_NAMES
...@@ -57,7 +57,7 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non ...@@ -57,7 +57,7 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
if args is not None: if args is not None:
return parser.parse_dict(args) return parser.parse_dict(args)
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
...@@ -116,11 +116,14 @@ def _check_extra_dependencies( ...@@ -116,11 +116,14 @@ def _check_extra_dependencies(
if model_args.use_unsloth: if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth") require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel")
if model_args.mixture_of_depths is not None: if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3") require_version("vllm>=0.4.3,<=0.6.2", "To fix: pip install vllm>=0.4.3,<=0.6.2")
if finetuning_args.use_galore: if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch") require_version("galore_torch", "To fix: pip install galore_torch")
...@@ -212,11 +215,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ...@@ -212,11 +215,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
): ):
raise ValueError("Please specify dataset for evaluation.") raise ValueError("Please specify dataset for evaluation.")
if training_args.predict_with_generate and data_args.eval_dataset is None:
    raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if training_args.predict_with_generate and finetuning_args.compute_accuracy:
    raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if training_args.predict_with_generate:
    if is_deepspeed_zero3_enabled():
        raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")

    if data_args.eval_dataset is None:
        raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")

    if finetuning_args.compute_accuracy:
        raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if training_args.do_train and model_args.quantization_device_map == "auto": if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.") raise ValueError("Cannot use device map for quantized models in training.")
...@@ -225,7 +232,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ...@@ -225,7 +232,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.") raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16: if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available(): if not (is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())):
raise ValueError("This device does not support `pure_bf16`.") raise ValueError("This device does not support `pure_bf16`.")
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
...@@ -250,9 +257,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ...@@ -250,9 +257,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.infer_backend == "vllm": if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.") raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled(): if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
...@@ -381,9 +385,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: ...@@ -381,9 +385,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.") raise ValueError("vLLM only accepts a single adapter. Merge them first.")
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, data_args, finetuning_args) _verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args) _check_extra_dependencies(model_args, finetuning_args)
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from llamafactory.train.tuner import run_exp from llamafactory.train.tuner import run_exp # use absolute import
def launch(): def launch():
......
...@@ -24,6 +24,7 @@ from ..extras.logging import get_logger ...@@ -24,6 +24,7 @@ from ..extras.logging import get_logger
from .model_utils.misc import find_all_linear_modules, find_expanded_modules from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .model_utils.visual import get_forbidden_modules, patch_target_modules
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -37,7 +38,6 @@ logger = get_logger(__name__) ...@@ -37,7 +38,6 @@ logger = get_logger(__name__)
def _setup_full_tuning( def _setup_full_tuning(
model: "PreTrainedModel", model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: bool, is_trainable: bool,
cast_trainable_params_to_fp32: bool, cast_trainable_params_to_fp32: bool,
...@@ -46,13 +46,7 @@ def _setup_full_tuning( ...@@ -46,13 +46,7 @@ def _setup_full_tuning(
return return
logger.info("Fine-tuning method: Full") logger.info("Fine-tuning method: Full")
forbidden_modules = set() forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules): if not any(forbidden_module in name for forbidden_module in forbidden_modules):
if cast_trainable_params_to_fp32: if cast_trainable_params_to_fp32:
...@@ -63,7 +57,6 @@ def _setup_full_tuning( ...@@ -63,7 +57,6 @@ def _setup_full_tuning(
def _setup_freeze_tuning( def _setup_freeze_tuning(
model: "PreTrainedModel", model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: bool, is_trainable: bool,
cast_trainable_params_to_fp32: bool, cast_trainable_params_to_fp32: bool,
...@@ -72,8 +65,8 @@ def _setup_freeze_tuning( ...@@ -72,8 +65,8 @@ def _setup_freeze_tuning(
return return
logger.info("Fine-tuning method: Freeze") logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs: if hasattr(model.config, "text_config"): # composite models
config = model.config.text_config config = getattr(model.config, "text_config")
else: else:
config = model.config config = model.config
...@@ -130,10 +123,7 @@ def _setup_freeze_tuning( ...@@ -130,10 +123,7 @@ def _setup_freeze_tuning(
trainable_layers.append(module_name) trainable_layers.append(module_name)
forbidden_modules = set() forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any( if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules forbidden_module in name for forbidden_module in forbidden_modules
...@@ -211,8 +201,7 @@ def _setup_lora_tuning( ...@@ -211,8 +201,7 @@ def _setup_lora_tuning(
if finetuning_args.use_llama_pro: if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers) target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
    target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
if ( if (
finetuning_args.use_dora finetuning_args.use_dora
...@@ -303,9 +292,9 @@ def init_adapter( ...@@ -303,9 +292,9 @@ def init_adapter(
cast_trainable_params_to_fp32 = True cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full": if finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze": elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora": elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning( model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32 config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
......
...@@ -21,11 +21,12 @@ from trl import AutoModelForCausalLMWithValueHead ...@@ -21,11 +21,12 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from .adapter import init_adapter from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -60,11 +61,12 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: ...@@ -60,11 +61,12 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r""" r"""
Loads pretrained tokenizer. Loads pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args. Note: including inplace operation of model_args.
""" """
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
try: try:
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, model_args.model_name_or_path,
...@@ -80,6 +82,8 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": ...@@ -80,6 +82,8 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
padding_side="right", padding_side="right",
**init_kwargs, **init_kwargs,
) )
except Exception as e:
raise OSError("Failed to load tokenizer.") from e
if model_args.new_special_tokens is not None: if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens( num_added_tokens = tokenizer.add_special_tokens(
...@@ -92,18 +96,16 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": ...@@ -92,18 +96,16 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
logger.warning("New tokens have been added, changed `resize_vocab` to True.") logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer) patch_tokenizer(tokenizer)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)
except Exception as e:
logger.warning("Processor was not found: {}.".format(e))
processor = None
if model_args.visual_inputs:
    try:
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
        setattr(processor, "tokenizer", tokenizer)
    except Exception:
        raise ValueError(
            "This multimodal LLM is not supported.\n"
            "Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
            "Download Yi-VL models from: https://huggingface.co/BUAADreamer"
        )
else:
    processor = None
# Avoid loading the tokenizer as the processor, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
if processor is not None and "Processor" not in processor.__class__.__name__:
    processor = None
return {"tokenizer": tokenizer, "processor": processor} return {"tokenizer": tokenizer, "processor": processor}
...@@ -130,6 +132,7 @@ def load_model( ...@@ -130,6 +132,7 @@ def load_model(
init_kwargs = _get_init_kwargs(model_args) init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args) config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None model = None
lazy_load = False lazy_load = False
...@@ -145,12 +148,15 @@ def load_model( ...@@ -145,12 +148,15 @@ def load_model(
if model_args.mixture_of_depths == "load": if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs) model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
    model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
elif model_args.train_from_scratch:
    model = AutoModelForCausalLM.from_config(config)
else:
    model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
else:
    if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
        load_class = AutoModelForVision2Seq
    else:
        load_class = AutoModelForCausalLM

    if model_args.train_from_scratch:
        model = load_class.from_config(config)
    else:
        model = load_class.from_pretrained(**init_kwargs)
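A sketch of the new dispatch, assuming the private transformers attribute _model_mapping keeps its current shape: configs registered for vision-to-sequence models load through AutoModelForVision2Seq, everything else through AutoModelForCausalLM. The helper name is illustrative.

from transformers import AutoModelForCausalLM, AutoModelForVision2Seq

def pick_load_class(config):  # illustrative helper
    if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # built-in MLLMs
        return AutoModelForVision2Seq
    return AutoModelForCausalLM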
if model_args.mixture_of_depths == "convert": if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args) model = convert_pretrained_model_to_mod(model, config, model_args)
......
...@@ -37,10 +37,11 @@ def configure_attn_implementation( ...@@ -37,10 +37,11 @@ def configure_attn_implementation(
if is_flash_attn_2_available(): if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4") require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3") require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") if model_args.flash_attn != "fa2":
model_args.flash_attn = "fa2" logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else: else:
logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.") logger.warning("FlashAttention-2 is not installed, use eager attention.")
model_args.flash_attn = "disabled" model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa": elif model_args.flash_attn == "sdpa":
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.") logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's Transformers and PEFT library. # This code is inspired by the HuggingFace's Transformers and PEFT library,
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
# and the Unsloth library.
# https://github.com/unslothai/unsloth/blob/July-2024/unsloth/models/_utils.py
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -17,9 +19,9 @@ ...@@ -17,9 +19,9 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from functools import partial from functools import partial, wraps
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
...@@ -36,8 +38,70 @@ if TYPE_CHECKING: ...@@ -36,8 +38,70 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(
ctx: "torch.autograd.Function",
forward_function: "torch.Module",
hidden_states: "torch.Tensor",
*args: Union["torch.Tensor", Any],
) -> "torch.Tensor":
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, grad_output)
return (None, hidden_states.grad) + (None,) * len(ctx.args)
return UnslothGradientCheckpointing.apply
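A hedged usage sketch of the function defined above: the returned callable mirrors torch.utils.checkpoint.checkpoint but parks activations in CPU RAM during the forward pass and fetches them back for backward. Names below are illustrative, and a CUDA device is assumed.

def checkpointed_forward(decoder_layer, hidden_states):
    # drop-in for torch.utils.checkpoint.checkpoint, with CPU offloading;
    # the wrapped forward is assumed to return a 1-tuple, as decoder layers do here
    checkpoint_fn = get_unsloth_gradient_checkpointing_func()
    return checkpoint_fn(decoder_layer.forward, hidden_states)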
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r"""
Only applies gradient checkpointing to trainable layers.
"""
@wraps(gradient_checkpointing_func)
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func
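Why the wrapper forces requires_grad: with reentrant checkpointing, a block whose inputs all have requires_grad=False is skipped during backward, which would silently starve trainable adapters sitting behind frozen layers (the LoRA case). A self-contained sketch of the effect, with an illustrative helper name:

import torch

def ensure_grad_path(module: torch.nn.Module, *tensors: torch.Tensor):
    # mark floating-point inputs of trainable blocks so backward reaches them
    if any(param.requires_grad for param in module.parameters()):
        for tensor in tensors:
            if torch.is_tensor(tensor) and torch.is_floating_point(tensor):
                tensor.requires_grad_(True)
    return tensors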
def _gradient_checkpointing_enable( def _gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
use_unsloth_gc: bool = False,
) -> None: ) -> None:
r""" r"""
Activates gradient checkpointing for the current model. Activates gradient checkpointing for the current model.
...@@ -52,24 +116,18 @@ def _gradient_checkpointing_enable( ...@@ -52,24 +116,18 @@ def _gradient_checkpointing_enable(
if gradient_checkpointing_kwargs is None: if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True} gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

def custom_gradient_checkpointing_func(func, *args, **kwargs):
    module: "torch.nn.Module" = func.__self__
    if any(param.requires_grad for param in module.parameters()):
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                arg.requires_grad_(True)

    return gradient_checkpointing_func(func, *args, **kwargs)
if use_unsloth_gc:
    gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
else:
    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True)) self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads() self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.") logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func) self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
def _fp32_forward_post_hook( def _fp32_forward_post_hook(
...@@ -97,7 +155,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum ...@@ -97,7 +155,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
else: else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet) # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339 # According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
gradient_checkpointing_enable = partial(
    _gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
)
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.") logger.info("Gradient checkpointing enabled.")
......
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def apply_liger_kernel(
config: "PretrainedConfig",
model_args: "ModelArguments",
is_trainable: bool,
require_logits: bool,
) -> None:
if not is_trainable or not model_args.enable_liger_kernel:
return
model_type = getattr(config, "model_type", None)
if model_type == "gemma":
from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
elif model_type == "gemma2":
from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "mistral":
from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
elif model_type == "mixtral":
from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
elif model_type == "phi3":
from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
elif model_type == "qwen2":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
else:
logger.warning("Current model does not support liger kernel.")
return
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False}
else:
kwargs = {}
apply_liger_kernel(**kwargs)
logger.info("Liger kernel has been applied to the model.")
...@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward( ...@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None: def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4") require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
LlamaAttention.forward = llama_attention_forward LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward LlamaSdpaAttention.forward = llama_sdpa_attention_forward
......