Unverified commit b1d92053, authored by zhou fan and committed by GitHub
Browse files

[Model]: Add support for Aria model (#10514)


Signed-off-by: xffxff <1247714429@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent 452a4e80
......@@ -476,6 +476,12 @@ Text Generation
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`AriaForConditionalGeneration`
- Aria
- T + I
- :code:`rhymes-ai/Aria`
-
- ✅︎
* - :code:`Blip2ForConditionalGeneration`
- BLIP-2
- T + I\ :sup:`E`
......
......@@ -402,6 +402,23 @@ def run_idefics3(question: str, modality: str):
return llm, prompt, stop_token_ids
# Aria
def run_aria(question: str, modality: str):
    """Set up the Aria example: engine, chat prompt, and stop token ids.

    Args:
        question: Text inserted into the user turn, right after the image
            placeholder line.
        modality: Requested input modality; only ``"image"`` is accepted.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.

    Raises:
        AssertionError: If ``modality`` is not ``"image"``.
    """
    assert modality == "image"

    engine_kwargs = dict(
        model="rhymes-ai/Aria",
        tokenizer_mode="slow",
        trust_remote_code=True,
        dtype="bfloat16",
    )
    llm = LLM(**engine_kwargs)

    # The user turn holds one image placeholder line followed by the
    # question; the prompt ends with the opening of the assistant turn.
    image_slot = "<fim_prefix><|img|><fim_suffix>\n"
    user_turn = f"<|im_start|>user\n{image_slot}{question}<|im_end|>\n"
    prompt = f"{user_turn}<|im_start|>assistant\n"

    # Ids are reproduced verbatim; note that 93653 is listed twice.
    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]

    return llm, prompt, stop_token_ids
model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
......@@ -423,6 +440,7 @@ model_example_map = {
"molmo": run_molmo,
"glm4v": run_glm4v,
"idefics3": run_idefics3,
"aria": run_aria,
}
......
......@@ -321,6 +321,25 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
)
def load_aria(question, image_urls: List[str]) -> ModelRequestData:
    """Prepare a multi-image Aria request.

    Args:
        question: Text appended after the image placeholder lines in the
            user turn.
        image_urls: URLs of the input images. One placeholder line is
            emitted per URL, and the per-prompt image limit is set to the
            number of URLs.

    Returns:
        A ``ModelRequestData`` carrying the engine, the chat-formatted
        prompt, the stop token ids, the result of ``fetch_image`` for every
        URL, and ``chat_template=None``.
    """
    num_images = len(image_urls)

    llm = LLM(
        model="rhymes-ai/Aria",
        tokenizer_mode="slow",
        trust_remote_code=True,
        dtype="bfloat16",
        limit_mm_per_prompt={"image": num_images},
    )

    # One placeholder line per image, all placed ahead of the question.
    image_slots = "<fim_prefix><|img|><fim_suffix>\n" * num_images
    user_turn = f"<|im_start|>user\n{image_slots}{question}<|im_end|>\n"
    prompt = f"{user_turn}<|im_start|>assistant\n"

    # Ids are reproduced verbatim; note that 93653 is listed twice.
    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]

    # Images are fetched only after the engine has been constructed.
    images = [fetch_image(url) for url in image_urls]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=images,
        chat_template=None,
    )
model_example_map = {
"phi3_v": load_phi3v,
"h2ovl_chat": load_h2onvl,
......@@ -330,6 +349,7 @@ model_example_map = {
"qwen_vl_chat": load_qwenvl_chat,
"mllama": load_mllama,
"idefics3": load_idefics3,
"aria": load_aria,
}
......
......@@ -43,6 +43,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
......
......@@ -412,6 +412,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return ""
if model_type == "idefics3":
return "<image>"
if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>"
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
......
This diff is collapsed.
......@@ -133,6 +133,7 @@ _CROSS_ENCODER_MODELS = {
_MULTIMODAL_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
......
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2VisionConfig)
from transformers.models.llama.configuration_llama import LlamaConfig
class AriaVisionConfig(Idefics2VisionConfig):
    """Vision configuration for Aria.

    Identical to ``Idefics2VisionConfig`` except for the ``model_type``
    identifier set below.
    """

    model_type = "aria_vision_model"
class AriaMoELMConfig(LlamaConfig):
    """Language-model configuration for Aria's Mixture-of-Experts decoder.

    Extends ``LlamaConfig`` with four MoE-specific settings; every other
    keyword argument is forwarded unchanged to ``LlamaConfig``.

    Args:
        moe_intermediate_size (int): The intermediate size for MoE layers.
            Default is 4096.
        moe_num_experts (int): The number of experts in the MoE layer.
            Default is 8.
        moe_topk (int): The number of top experts to route to for each
            token. Default is 2.
        moe_num_shared_experts (int): The number of shared experts.
            Default is 2.
        **kwargs: Additional keyword arguments passed to the parent
            ``LlamaConfig``.
    """

    model_type = "aria_moe_lm"

    def __init__(
        self,
        moe_intermediate_size: int = 4096,
        moe_num_experts: int = 8,
        moe_topk: int = 2,
        moe_num_shared_experts: int = 2,
        **kwargs,
    ):
        """Initialize the parent config, then record the MoE settings."""
        super().__init__(**kwargs)

        # Assigned after the parent initializer, in declaration order.
        moe_settings = (
            ("moe_intermediate_size", moe_intermediate_size),
            ("moe_num_experts", moe_num_experts),
            ("moe_topk", moe_topk),
            ("moe_num_shared_experts", moe_num_shared_experts),
        )
        for attr_name, attr_value in moe_settings:
            setattr(self, attr_name, attr_value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment