Unverified Commit 38327cf4 authored by Jennifer Zhao's avatar Jennifer Zhao Committed by GitHub
Browse files

[Model] Aya Vision (#15441)


Signed-off-by: default avatarJennifer Zhao <ai.jenniferzhao@gmail.com>
Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent dfa82e2a
......@@ -753,6 +753,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `AyaVisionForConditionalGeneration`
* Aya Vision
* T + I<sup>+</sup>
* `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc.
*
* ✅︎
* ✅︎
- * `Blip2ForConditionalGeneration`
* BLIP-2
* T + I<sup>E</sup>
......
......@@ -60,6 +60,28 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
)
# Aya Vision
def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "CohereForAI/aya-vision-8b"
engine_args = EngineArgs(
model=model_name,
max_model_len=2048,
max_num_seqs=2,
mm_processor_kwargs={"crop_to_patches": True},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompts = [
f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|><image>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# BLIP-2
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -865,6 +887,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
model_example_map = {
"aria": run_aria,
"aya_vision": run_aya_vision,
"blip-2": run_blip2,
"chameleon": run_chameleon,
"deepseek_vl_v2": run_deepseek_vl2,
......
......@@ -61,6 +61,41 @@ def load_aria(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "CohereForAI/aya-vision-8b"
engine_args = EngineArgs(
model=model_name,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_deepseek_vl2(question: str,
image_urls: list[str]) -> ModelRequestData:
model_name = "deepseek-ai/deepseek-vl2-tiny"
......@@ -526,6 +561,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_example_map = {
"aria": load_aria,
"aya_vision": load_aya_vision,
"deepseek_vl_v2": load_deepseek_vl2,
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
......
......@@ -158,6 +158,20 @@ VLM_TEST_SETTINGS = {
max_tokens=64,
marks=[large_gpu_mark(min_gb=64)],
),
"aya_vision": VLMTestInfo(
models=["CohereForAI/aya-vision-8b"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<image>What's the content in the center of the image?", # noqa: E501
"cherry_blossom": "<image>What is the season?", # noqa: E501
}),
multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}
),
"blip2": VLMTestInfo(
# TODO: Change back to 2.7b once head_dim = 80 is supported
models=["Salesforce/blip2-opt-6.7b"],
......
......@@ -246,6 +246,7 @@ def _test_processing_correctness_mistral(
# yapf: disable
@pytest.mark.parametrize("model_id", [
"rhymes-ai/Aria",
"CohereForAI/aya-vision-8b",
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
......
......@@ -259,6 +259,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
......
......@@ -2716,6 +2716,10 @@ def _get_and_verify_max_len(
max_len_key = key if max_len < derived_max_model_len \
else max_len_key
derived_max_model_len = min(derived_max_model_len, max_len)
# For Command-R / Cohere, Cohere2 / Aya Vision models
if tmp_max_len := getattr(hf_config, "model_max_length", None):
max_len_key = "model_max_length"
derived_max_model_len = tmp_max_len
# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
......
......@@ -496,8 +496,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
"skywork_chat", "NVLM_D", "h2ovl_chat"):
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
......
This diff is collapsed.
......@@ -161,6 +161,7 @@ _CROSS_ENCODER_MODELS = {
_MULTIMODAL_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"),
"AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"), # noqa: E501
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment