Unverified commit 3badb021, authored by Shinichi Hemmi and committed by GitHub
Browse files

[Model] Add PLaMo2 (#14323)


Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Signed-off-by: shemmi <shemmi@preferred.jp>
Co-authored-by: Kento Nozawa <nzw0301@preferred.jp>
Co-authored-by: Hiroaki Mikami <mhiroaki@preferred.jp>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
parent fdcb850f
...@@ -400,8 +400,9 @@ steps: ...@@ -400,8 +400,9 @@ steps:
- pytest -v -s models/test_transformers.py - pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py - pytest -v -s models/test_registry.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531 # V1 Test: https://github.com/vllm-project/vllm/issues/14531
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
- label: Language Models Test (Standard) # 32min - label: Language Models Test (Standard) # 32min
#mirror_hardwares: [amd] #mirror_hardwares: [amd]
...@@ -411,6 +412,8 @@ steps: ...@@ -411,6 +412,8 @@ steps:
- tests/models/embedding/language - tests/models/embedding/language
- tests/models/encoder_decoder/language - tests/models/encoder_decoder/language
commands: commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model - pytest -v -s models/embedding/language -m core_model
...@@ -422,6 +425,8 @@ steps: ...@@ -422,6 +425,8 @@ steps:
- tests/models/embedding/language - tests/models/embedding/language
- tests/models/encoder_decoder/language - tests/models/encoder_decoder/language
commands: commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model' - pytest -v -s models/embedding/language -m 'not core_model'
......
...@@ -497,6 +497,11 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -497,6 +497,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. * `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc.
* *
* ✅︎ * ✅︎
- * `Plamo2ForCausalLM`
* PLaMo2
* `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc.
*
*
- * `QWenLMHeadModel` - * `QWenLMHeadModel`
* Qwen * Qwen
* `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. * `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.
......
...@@ -27,6 +27,7 @@ torch==2.6.0 ...@@ -27,6 +27,7 @@ torch==2.6.0
torchaudio==2.6.0 torchaudio==2.6.0
torchvision==0.21.0 torchvision==0.21.0
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.5.4 # required for pixtral test mistral_common[opencv] >= 1.5.4 # required for pixtral test
num2words # required for smolvlm test num2words # required for smolvlm test
......
...@@ -111,6 +111,7 @@ einops==0.8.0 ...@@ -111,6 +111,7 @@ einops==0.8.0
# via # via
# -r requirements/test.in # -r requirements/test.in
# encodec # encodec
# mamba-ssm
# vector-quantize-pytorch # vector-quantize-pytorch
# vocos # vocos
einx==0.3.0 einx==0.3.0
...@@ -233,6 +234,8 @@ lxml==5.3.0 ...@@ -233,6 +234,8 @@ lxml==5.3.0
# via # via
# blobfile # blobfile
# sacrebleu # sacrebleu
mamba-ssm==2.2.4
# via -r requirements/test.in
markdown-it-py==3.0.0 markdown-it-py==3.0.0
# via rich # via rich
markupsafe==3.0.2 markupsafe==3.0.2
...@@ -268,6 +271,8 @@ mypy-extensions==1.0.0 ...@@ -268,6 +271,8 @@ mypy-extensions==1.0.0
# via black # via black
networkx==3.2.1 networkx==3.2.1
# via torch # via torch
ninja==1.11.1.3
# via mamba-ssm
nltk==3.9.1 nltk==3.9.1
# via rouge-score # via rouge-score
num2words==0.5.14 num2words==0.5.14
...@@ -360,6 +365,7 @@ packaging==24.1 ...@@ -360,6 +365,7 @@ packaging==24.1
# fastparquet # fastparquet
# huggingface-hub # huggingface-hub
# lazy-loader # lazy-loader
# mamba-ssm
# matplotlib # matplotlib
# peft # peft
# plotly # plotly
...@@ -571,6 +577,7 @@ sentencepiece==0.2.0 ...@@ -571,6 +577,7 @@ sentencepiece==0.2.0
# via mistral-common # via mistral-common
setuptools==75.8.0 setuptools==75.8.0
# via # via
# mamba-ssm
# pytablewriter # pytablewriter
# torch # torch
shellingham==1.5.4 shellingham==1.5.4
...@@ -627,6 +634,7 @@ torch==2.6.0 ...@@ -627,6 +634,7 @@ torch==2.6.0
# encodec # encodec
# fastsafetensors # fastsafetensors
# lm-eval # lm-eval
# mamba-ssm
# peft # peft
# runai-model-streamer # runai-model-streamer
# sentence-transformers # sentence-transformers
...@@ -664,6 +672,7 @@ transformers==4.51.1 ...@@ -664,6 +672,7 @@ transformers==4.51.1
# -r requirements/test.in # -r requirements/test.in
# genai-perf # genai-perf
# lm-eval # lm-eval
# mamba-ssm
# peft # peft
# sentence-transformers # sentence-transformers
# transformers-stream-generator # transformers-stream-generator
......
...@@ -9,9 +9,15 @@ from vllm.sampling_params import SamplingParams ...@@ -9,9 +9,15 @@ from vllm.sampling_params import SamplingParams
from ...utils import check_outputs_equal from ...utils import check_outputs_equal
# This test is for the hybrid models # This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"] MODELS = [
"ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
"pfnet/plamo-2-1b"
]
# Bamba at Fp32 is too big for the CI (L4 GPU). # Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
# Note: Running PLaMo2 with the transformers implementation requires
# installing the causal-conv1d package, which is not listed as a test
# dependency because it is not compatible with pip-compile.
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
...@@ -25,21 +31,11 @@ def test_models( ...@@ -25,21 +31,11 @@ def test_models(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
) -> None: ) -> None:
# numeric error produces different generation # numeric error produces different generation
if "Bamba" in model: if "Bamba" in model:
example_prompts.pop(3) example_prompts.pop(3)
model_kwargs = { with hf_runner(model, dtype=dtype) as hf_model:
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model: with vllm_runner(model, dtype=dtype) as vllm_model:
...@@ -94,6 +90,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling( ...@@ -94,6 +90,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
# correctly for n > 1 decoding steps inside a # correctly for n > 1 decoding steps inside a
# chunked prefill forward pass (where we have both prefills # chunked prefill forward pass (where we have both prefills
# and decoding together ) # and decoding together )
if 'plamo-2' in model:
dtype = "float" # use a different dtype for plamo
sampling_params = SamplingParams(n=3, sampling_params = SamplingParams(n=3,
temperature=1, temperature=1,
seed=0, seed=0,
...@@ -125,20 +125,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, ...@@ -125,20 +125,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
example_prompts.pop(3) example_prompts.pop(3)
example_prompts.pop(2) example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba dtype = "half" # use a different dtype for Bamba
elif "Zamba2" in model: elif "Zamba2" in model:
example_prompts.pop(7) example_prompts.pop(7)
dtype = "half" dtype = "half"
elif "plamo-2-1b" in model:
example_prompts.pop(7)
model_kwargs = { with hf_runner(model, dtype=dtype) as hf_model:
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, with vllm_runner(model,
...@@ -208,7 +202,8 @@ def test_mamba_cache_cg_padding( ...@@ -208,7 +202,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured # This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because # batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible # tensor dimensions aren't compatible
vllm_config = EngineArgs(model=model).create_engine_config() vllm_config = EngineArgs(model=model,
trust_remote_code=True).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph( while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)): len(example_prompts)):
example_prompts.append(example_prompts[0]) example_prompts.append(example_prompts[0])
......
...@@ -204,6 +204,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -204,6 +204,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True), trust_remote_code=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True), trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
trust_remote_code=True), trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct", "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct",
......
...@@ -2838,6 +2838,13 @@ def _get_and_verify_dtype( ...@@ -2838,6 +2838,13 @@ def _get_and_verify_dtype(
else: else:
torch_dtype = config_dtype torch_dtype = config_dtype
if config.model_type == "plamo2":
logger.info(
"For PLaMo2, we cast models to bfloat16 instead of using "
"float16 by default. This is because float16 does not work."
)
torch_dtype = torch.bfloat16
from vllm.platforms import current_platform from vllm.platforms import current_platform
if (current_platform.is_cpu() if (current_platform.is_cpu()
and current_platform.get_cpu_architecture() and current_platform.get_cpu_architecture()
...@@ -2867,6 +2874,11 @@ def _get_and_verify_dtype( ...@@ -2867,6 +2874,11 @@ def _get_and_verify_dtype(
"using float16 by default. Please specify `dtype` if you " "using float16 by default. Please specify `dtype` if you "
"want to use float16.") "want to use float16.")
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
elif dtype == "float16" and config.model_type == "plamo2":
logger.warning(
"For PLaMo2, using float16 is unstable and might cause "
"unexpected behavior. Please use bfloat16 or float32 instead.")
torch_dtype = torch.float16
else: else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}") raise ValueError(f"Unknown dtype: {dtype}")
......
This diff is collapsed.
...@@ -99,6 +99,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -99,6 +99,7 @@ _TEXT_GENERATION_MODELS = {
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment