Unverified Commit 3badb021 authored by Shinichi Hemmi's avatar Shinichi Hemmi Committed by GitHub
Browse files

[Model] Add PLaMo2 (#14323)


Signed-off-by: default avatarShinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Signed-off-by: default avatarshemmi <shemmi@preferred.jp>
Co-authored-by: default avatarKento Nozawa <nzw0301@preferred.jp>
Co-authored-by: default avatarHiroaki Mikami <mhiroaki@preferred.jp>
Co-authored-by: default avatarCalvin Metzger <metzger@preferred.jp>
parent fdcb850f
......@@ -400,8 +400,9 @@ steps:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
- label: Language Models Test (Standard) # 32min
#mirror_hardwares: [amd]
......@@ -411,6 +412,8 @@ steps:
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model
......@@ -422,6 +425,8 @@ steps:
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'
......
......@@ -497,6 +497,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc.
*
* ✅︎
- * `Plamo2ForCausalLM`
* PLaMo2
* `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc.
*
*
- * `QWenLMHeadModel`
* Qwen
* `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.
......
......@@ -27,6 +27,7 @@ torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.5.4 # required for pixtral test
num2words # required for smolvlm test
......
......@@ -111,6 +111,7 @@ einops==0.8.0
# via
# -r requirements/test.in
# encodec
# mamba-ssm
# vector-quantize-pytorch
# vocos
einx==0.3.0
......@@ -233,6 +234,8 @@ lxml==5.3.0
# via
# blobfile
# sacrebleu
mamba-ssm==2.2.4
# via -r requirements/test.in
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
......@@ -268,6 +271,8 @@ mypy-extensions==1.0.0
# via black
networkx==3.2.1
# via torch
ninja==1.11.1.3
# via mamba-ssm
nltk==3.9.1
# via rouge-score
num2words==0.5.14
......@@ -360,6 +365,7 @@ packaging==24.1
# fastparquet
# huggingface-hub
# lazy-loader
# mamba-ssm
# matplotlib
# peft
# plotly
......@@ -571,6 +577,7 @@ sentencepiece==0.2.0
# via mistral-common
setuptools==75.8.0
# via
# mamba-ssm
# pytablewriter
# torch
shellingham==1.5.4
......@@ -627,6 +634,7 @@ torch==2.6.0
# encodec
# fastsafetensors
# lm-eval
# mamba-ssm
# peft
# runai-model-streamer
# sentence-transformers
......@@ -664,6 +672,7 @@ transformers==4.51.1
# -r requirements/test.in
# genai-perf
# lm-eval
# mamba-ssm
# peft
# sentence-transformers
# transformers-stream-generator
......
......@@ -9,9 +9,15 @@ from vllm.sampling_params import SamplingParams
from ...utils import check_outputs_equal
# This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
MODELS = [
"ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
"pfnet/plamo-2-1b"
]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
# Note: Running Plamo2 in transformers implementation requires to install
# causal-conv1d package, which is not listed as a test dependency as it's
# not compatible with pip-compile.
@pytest.mark.parametrize("model", MODELS)
......@@ -25,21 +31,11 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
# numeric error produces different generation
if "Bamba" in model:
example_prompts.pop(3)
model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
......@@ -94,6 +90,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
# correctly for n > 1 decoding steps inside a
# chunked prefill forward pass (where we have both prefills
# and decoding together )
if 'plamo-2' in model:
dtype = "float" # use a different dtype for plamo
sampling_params = SamplingParams(n=3,
temperature=1,
seed=0,
......@@ -125,20 +125,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
example_prompts.pop(3)
example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba
elif "Zamba2" in model:
example_prompts.pop(7)
dtype = "half"
elif "plamo-2-1b" in model:
example_prompts.pop(7)
model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
with hf_runner(model, dtype=dtype) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model,
......@@ -208,7 +202,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
vllm_config = EngineArgs(model=model).create_engine_config()
vllm_config = EngineArgs(model=model,
trust_remote_code=True).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])
......
......@@ -204,6 +204,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct",
......
......@@ -2838,6 +2838,13 @@ def _get_and_verify_dtype(
else:
torch_dtype = config_dtype
if config.model_type == "plamo2":
logger.info(
"For PLaMo2, we cast models to bfloat16 instead of using "
"float16 by default. This is because float16 does not work."
)
torch_dtype = torch.bfloat16
from vllm.platforms import current_platform
if (current_platform.is_cpu()
and current_platform.get_cpu_architecture()
......@@ -2867,6 +2874,11 @@ def _get_and_verify_dtype(
"using float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
elif dtype == "float16" and config.model_type == "plamo2":
logger.warning(
"For PLaMo2, using float16 is unstable and might cause "
"unexpected behavior. Please use bfloat16 or float32 instead.")
torch_dtype = torch.float16
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
......
This diff is collapsed.
......@@ -99,6 +99,7 @@ _TEXT_GENERATION_MODELS = {
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment