Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8560a5b2
Unverified
Commit
8560a5b2
authored
Jul 23, 2025
by
Christian Pinto
Committed by
GitHub
Jul 23, 2025
Browse files
[Core][Model] PrithviMAE Enablement on vLLM v1 engine (#20577)
Signed-off-by:
Christian Pinto
<
christian.pinto@ibm.com
>
parent
316b1bf7
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
704 additions
and
238 deletions
+704
-238
examples/offline_inference/prithvi_geospatial_mae.py
examples/offline_inference/prithvi_geospatial_mae.py
+70
-175
requirements/test.in
requirements/test.in
+1
-0
requirements/test.txt
requirements/test.txt
+360
-14
tests/models/multimodal/pooling/test_prithvi_mae.py
tests/models/multimodal/pooling/test_prithvi_mae.py
+63
-0
vllm/config.py
vllm/config.py
+4
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+5
-5
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+34
-0
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+57
-17
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+11
-2
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+1
-1
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+12
-5
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+8
-5
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+10
-8
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+8
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+60
-0
No files found.
examples/offline_inference/prithvi_geospatial_mae.py
View file @
8560a5b2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
The requirements for running this script are:
- Installing [terratorch, albumentations, rasterio] in your python environment
- downloading the model weights in a 'model' folder local to the script
(temporary measure until the proper config.json file is uploaded to HF)
- download an input example image (India_900498_S2Hand.tif) and place it in
the same folder with the script (or specify with the --data_file argument)
Run the example:
python prithvi_geospatial_mae.py
"""
# noqa: E501
import
argparse
import
datetime
import
os
import
re
from
typing
import
Union
import
albumentations
import
numpy
as
np
import
rasterio
import
regex
as
re
import
torch
from
einops
import
rearrange
from
terratorch.datamodules
import
Sen1Floods11NonGeoDataModule
from
vllm
import
LLM
torch
.
set_default_dtype
(
torch
.
float16
)
NO_DATA
=
-
9999
NO_DATA_FLOAT
=
0.0001
OFFSET
=
0
PERCENTILE
=
99
model_config
=
"""{
"architectures": ["PrithviGeoSpatialMAE"],
"num_classes": 0,
"pretrained_cfg": {
"task_args": {
"task": "SemanticSegmentationTask",
"model_factory": "EncoderDecoderFactory",
"loss": "ce",
"ignore_index": -1,
"lr": 0.001,
"freeze_backbone": false,
"freeze_decoder": false,
"plot_on_val": 10,
"optimizer": "AdamW",
"scheduler": "CosineAnnealingLR"
},
"model_args": {
"backbone_pretrained": false,
"backbone": "prithvi_eo_v2_300_tl",
"decoder": "UperNetDecoder",
"decoder_channels": 256,
"decoder_scale_modules": true,
"num_classes": 2,
"rescale": true,
"backbone_bands": [
"BLUE",
"GREEN",
"RED",
"NIR_NARROW",
"SWIR_1",
"SWIR_2"
],
"head_dropout": 0.1,
"necks": [
{
"name": "SelectIndices",
"indices": [
5,
11,
17,
23
]
},
{
"name": "ReshapeTokensToImage"
}
]
},
"optimizer_params" : {
"lr": 5.0e-05,
"betas": [0.9, 0.999],
"eps": [1.0e-08],
"weight_decay": 0.05,
"amsgrad": false,
"maximize": false,
"capturable": false,
"differentiable": false
},
"scheduler_params" : {
"T_max": 50,
"eta_min": 0,
"last_epoch": -1,
"verbose": "deprecated"
}
},
"torch_dtype": "float32"
}
"""
# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"./model/config.json"
),
"w"
)
as
config_file
:
config_file
.
write
(
model_config
)
datamodule_config
=
{
"bands"
:
[
"BLUE"
,
"GREEN"
,
"RED"
,
"NIR_NARROW"
,
"SWIR_1"
,
"SWIR_2"
],
"batch_size"
:
16
,
...
...
@@ -138,28 +43,24 @@ datamodule_config = {
class
PrithviMAE
:
def
__init__
(
self
):
print
(
"Initializing PrithviMAE model"
)
self
.
llm
=
LLM
(
model
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"./model"
),
skip_tokenizer_init
=
True
,
dtype
=
"float32"
,
def
__init__
(
self
,
model
):
self
.
model
=
LLM
(
model
=
model
,
skip_tokenizer_init
=
True
,
dtype
=
"float16"
,
enforce_eager
=
True
)
def
run
(
self
,
input_data
,
location_coords
):
print
(
"################ Running inference on vLLM ##############"
)
# merge the inputs into one data structure
if
input_data
is
not
None
and
input_data
.
dtype
==
torch
.
float32
:
input_data
=
input_data
.
to
(
torch
.
float16
)
input_data
=
input_data
[
0
]
mm_data
=
{
"pixel_values"
:
torch
.
empty
(
0
)
if
input_data
is
None
else
input_data
,
"location_coords"
:
torch
.
empty
(
0
)
if
location_coords
is
None
else
location_coords
,
"pixel_values"
:
input_data
,
"location_coords"
:
location_coords
,
}
prompt
=
{
"prompt_token_ids"
:
[
1
],
"multi_modal_data"
:
mm_data
}
outputs
=
self
.
llm
.
encode
(
prompt
,
use_tqdm
=
False
)
print
(
"################ Inference done (it took seconds) ##############"
)
outputs
=
self
.
model
.
encode
(
prompt
,
use_tqdm
=
False
)
return
outputs
[
0
].
outputs
.
data
...
...
@@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
"""
Args:
orig_img: torch.Tensor representing original image (reference)
with shape = (bands, H, W).
with shape = (bands, H, W).
channels: list of indices representing RGB channels.
Returns:
torch.Tensor with shape (num_channels, height, width) for original image
torch.Tensor with shape (num_channels, height, width)
for original image
"""
orig_img
=
orig_img
[
channels
,
...]
...
...
@@ -260,10 +162,10 @@ def load_example(
Args:
file_paths: list of file paths .
mean: list containing mean values for each band in the
images
in *file_paths*.
std: list containing std values for each band in the
images
in *file_paths*.
mean: list containing mean values for each band in the
images
in *file_paths*.
std: list containing std values for each band in the
images
in *file_paths*.
Returns:
np.array containing created example
...
...
@@ -308,7 +210,7 @@ def load_example(
print
(
f
"Could not extract timestamp for
{
file
}
(
{
e
}
)"
)
imgs
=
np
.
stack
(
imgs
,
axis
=
0
)
# num_frames, H, W, C
imgs
=
np
.
moveaxis
(
imgs
,
-
1
,
0
).
astype
(
"float32"
)
imgs
=
np
.
moveaxis
(
imgs
,
-
1
,
0
).
astype
(
"float32"
)
# C, num_frames, H, W
imgs
=
np
.
expand_dims
(
imgs
,
axis
=
0
)
# add batch di
return
imgs
,
temporal_coords
,
location_coords
,
metas
...
...
@@ -332,8 +234,10 @@ def run_model(
)
# Build sliding window
batch_size
=
1
batch
=
torch
.
tensor
(
input_data
,
device
=
"cpu"
)
# batch = torch.tensor(input_data, device="cpu")
batch
=
torch
.
tensor
(
input_data
)
windows
=
batch
.
unfold
(
3
,
img_size
,
img_size
).
unfold
(
4
,
img_size
,
img_size
)
h1
,
w1
=
windows
.
shape
[
3
:
5
]
windows
=
rearrange
(
...
...
@@ -344,18 +248,16 @@ def run_model(
num_batches
=
windows
.
shape
[
0
]
//
batch_size
if
windows
.
shape
[
0
]
>
batch_size
else
1
windows
=
torch
.
tensor_split
(
windows
,
num_batches
,
dim
=
0
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
if
temporal_coords
:
temporal_coords
=
torch
.
tensor
(
temporal_coords
,
device
=
device
).
unsqueeze
(
0
)
temporal_coords
=
torch
.
tensor
(
temporal_coords
).
unsqueeze
(
0
)
else
:
temporal_coords
=
None
if
location_coords
:
location_coords
=
torch
.
tensor
(
location_coords
[
0
]
,
device
=
device
).
unsqueeze
(
0
)
location_coords
=
torch
.
tensor
(
location_coords
[
0
]).
unsqueeze
(
0
)
else
:
location_coords
=
None
# Run
model
# Run
Prithvi-EO-V2-300M-TL-Sen1Floods11
pred_imgs
=
[]
for
x
in
windows
:
# Apply standardization
...
...
@@ -363,15 +265,7 @@ def run_model(
x
=
datamodule
.
aug
(
x
)[
"image"
]
with
torch
.
no_grad
():
x
=
x
.
to
(
device
)
pred
=
model
.
run
(
x
,
location_coords
=
location_coords
)
if
lightning_model
:
pred_lightning
=
lightning_model
(
x
,
temporal_coords
=
temporal_coords
,
location_coords
=
location_coords
)
pred_lightning
=
pred_lightning
.
output
.
detach
().
cpu
()
if
not
torch
.
equal
(
pred
,
pred_lightning
):
print
(
"Inference output is not equal"
)
y_hat
=
pred
.
argmax
(
dim
=
1
)
y_hat
=
torch
.
nn
.
functional
.
interpolate
(
...
...
@@ -403,52 +297,18 @@ def run_model(
return
pred_imgs
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"MAE run inference"
,
add_help
=
False
)
parser
.
add_argument
(
"--data_file"
,
type
=
str
,
default
=
"./India_900498_S2Hand.tif"
,
help
=
"Path to the file."
,
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"output"
,
help
=
"Path to the directory where to save outputs."
,
)
parser
.
add_argument
(
"--input_indices"
,
default
=
[
1
,
2
,
3
,
8
,
11
,
12
],
type
=
int
,
nargs
=
"+"
,
help
=
"0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data."
,
)
parser
.
add_argument
(
"--rgb_outputs"
,
action
=
"store_true"
,
help
=
"If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved."
,
)
def
main
(
data_file
:
str
,
model
:
str
,
output_dir
:
str
,
rgb_outputs
:
bool
,
input_indices
:
list
[
int
]
=
None
,
):
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
# Load model ---------------------------------------------------------------
model_obj
=
PrithviMAE
()
model_obj
=
PrithviMAE
(
model
=
model
)
datamodule
=
generate_datamodule
()
img_size
=
256
# Size of Sen1Floods11
# Loading data -------------------------------------------------------------
img_size
=
512
# Size of Sen1Floods11
input_data
,
temporal_coords
,
location_coords
,
meta_data
=
load_example
(
file_paths
=
[
data_file
],
...
...
@@ -460,8 +320,6 @@ def main(
if
input_data
.
mean
()
>
1
:
input_data
=
input_data
/
10000
# Convert to range 0-1
# Running model ------------------------------------------------------------
channels
=
[
datamodule_config
[
"bands"
].
index
(
b
)
for
b
in
[
"RED"
,
"GREEN"
,
"BLUE"
]
]
# BGR -> RGB
...
...
@@ -469,7 +327,6 @@ def main(
pred
=
run_model
(
input_data
,
temporal_coords
,
location_coords
,
model_obj
,
datamodule
,
img_size
)
# Save pred
meta_data
.
update
(
count
=
1
,
dtype
=
"uint8"
,
compress
=
"lzw"
,
nodata
=
0
)
pred_file
=
os
.
path
.
join
(
...
...
@@ -487,6 +344,7 @@ def main(
orig_img
=
torch
.
Tensor
(
input_data
[
0
,
:,
0
,
...]),
channels
=
channels
,
)
rgb_orig
=
rgb_orig
.
to
(
torch
.
float32
)
pred
[
pred
==
0.0
]
=
np
.
nan
img_pred
=
rgb_orig
*
0.7
+
pred
*
0.3
...
...
@@ -503,9 +361,10 @@ def main(
# Save image rgb
if
rgb_outputs
:
name_suffix
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
data_file
))[
0
]
rgb_file
=
os
.
path
.
join
(
output_dir
,
f
"original_rgb_
{
os
.
path
.
splitext
(
os
.
path
.
basename
(
data_file
))[
0
]
}
.tiff"
,
f
"original_rgb_
{
name_suffix
}
.tiff"
,
)
save_geotiff
(
image
=
_convert_np_uint8
(
rgb_orig
),
...
...
@@ -515,6 +374,42 @@ def main(
if
__name__
==
"__main__"
:
args
=
parse_args
()
parser
=
argparse
.
ArgumentParser
(
"MAE run inference"
,
add_help
=
False
)
parser
.
add_argument
(
"--data_file"
,
type
=
str
,
default
=
"./India_900498_S2Hand.tif"
,
help
=
"Path to the file."
,
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
,
help
=
"Path to a checkpoint file to load from."
,
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"output"
,
help
=
"Path to the directory where to save outputs."
,
)
parser
.
add_argument
(
"--input_indices"
,
default
=
[
1
,
2
,
3
,
8
,
11
,
12
],
type
=
int
,
nargs
=
"+"
,
help
=
"""
0-based indices of the six Prithvi channels to be selected from the input.
By default selects [1,2,3,8,11,12] for S2L1C data.
"""
,
)
parser
.
add_argument
(
"--rgb_outputs"
,
action
=
"store_true"
,
help
=
"If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved."
,
)
args
=
parser
.
parse_args
()
main
(
**
vars
(
args
))
requirements/test.in
View file @
8560a5b2
...
...
@@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test
\ No newline at end of file
requirements/test.txt
View file @
8560a5b2
...
...
@@ -6,6 +6,10 @@ accelerate==1.0.1
# via
# lm-eval
# peft
aenum==3.1.16
# via lightly
affine==2.4.0
# via rasterio
aiohappyeyeballs==2.4.3
# via aiohttp
aiohttp==3.10.11
...
...
@@ -21,8 +25,18 @@ aiosignal==1.3.1
# via
# aiohttp
# ray
albucore==0.0.16
# via terratorch
albumentations==1.4.6
# via terratorch
alembic==1.16.4
# via mlflow
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.6.2.post1
# via
# httpx
...
...
@@ -34,10 +48,12 @@ arrow==1.3.0
attrs==24.2.0
# via
# aiohttp
# fiona
# hypothesis
# jsonlines
# jsonschema
# pytest-subtests
# rasterio
# referencing
audioread==3.0.1
# via librosa
...
...
@@ -46,9 +62,13 @@ backoff==2.2.1
# -r requirements/test.in
# schemathesis
bitsandbytes==0.46.1
# via -r requirements/test.in
# via
# -r requirements/test.in
# lightning
black==24.10.0
# via datamodel-code-generator
blinker==1.9.0
# via flask
blobfile==3.0.0
# via -r requirements/test.in
bm25s==0.2.13
...
...
@@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3
buildkite-test-collector==0.1.9
# via -r requirements/test.in
cachetools==5.5.2
# via google-auth
# via
# google-auth
# mlflow-skinny
certifi==2024.8.30
# via
# fiona
# httpcore
# httpx
# lightly
# pyogrio
# pyproj
# rasterio
# requests
cffi==1.17.1
# via soundfile
...
...
@@ -79,11 +106,28 @@ charset-normalizer==3.4.0
click==8.1.7
# via
# black
# click-plugins
# cligj
# fiona
# flask
# jiwer
# mlflow-skinny
# nltk
# rasterio
# ray
# schemathesis
# typer
# uvicorn
click-plugins==1.1.1.2
# via
# fiona
# rasterio
cligj==0.7.2
# via
# fiona
# rasterio
cloudpickle==3.1.1
# via mlflow-skinny
colorama==0.4.6
# via
# sacrebleu
...
...
@@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0
# via ray
cycler==0.12.1
# via matplotlib
databricks-sdk==0.59.0
# via mlflow-skinny
datamodel-code-generator==0.26.3
# via -r requirements/test.in
dataproperty==1.0.1
...
...
@@ -122,13 +168,21 @@ distlib==0.3.9
# via virtualenv
dnspython==2.7.0
# via email-validator
docker==7.1.0
# via mlflow
docopt==0.6.2
# via num2words
einops==0.8.0
docstring-parser==0.17.0
# via jsonargparse
efficientnet-pytorch==0.7.1
# via segmentation-models-pytorch
einops==0.8.1
# via
# -r requirements/test.in
# encodec
# mamba-ssm
# terratorch
# torchgeo
# vector-quantize-pytorch
# vocos
einx==0.3.0
...
...
@@ -141,6 +195,8 @@ eval-type-backport==0.2.2
# via mteb
evaluate==0.4.3
# via lm-eval
fastapi==0.116.1
# via mlflow-skinny
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
...
...
@@ -156,6 +212,10 @@ filelock==3.16.1
# torch
# transformers
# virtualenv
fiona==1.10.1
# via torchgeo
flask==3.1.1
# via mlflow
fonttools==4.54.1
# via matplotlib
fqdn==1.5.1
...
...
@@ -173,6 +233,8 @@ fsspec==2024.9.0
# evaluate
# fastparquet
# huggingface-hub
# lightning
# pytorch-lightning
# torch
ftfy==6.3.1
# via open-clip-torch
...
...
@@ -180,18 +242,41 @@ genai-perf==0.0.8
# via -r requirements/test.in
genson==1.3.0
# via datamodel-code-generator
geopandas==1.0.1
# via terratorch
gitdb==4.0.12
# via gitpython
gitpython==3.1.44
# via mlflow-skinny
google-api-core==2.24.2
# via opencensus
google-auth==2.40.2
# via google-api-core
# via
# databricks-sdk
# google-api-core
googleapis-common-protos==1.70.0
# via google-api-core
graphene==3.4.3
# via mlflow
graphql-core==3.2.6
# via hypothesis-graphql
# via
# graphene
# graphql-relay
# hypothesis-graphql
graphql-relay==3.2.0
# via graphene
greenlet==3.2.3
# via sqlalchemy
grpcio==1.71.0
# via ray
gunicorn==23.0.0
# via mlflow
h11==0.14.0
# via httpcore
# via
# httpcore
# uvicorn
h5py==3.13.0
# via terratorch
harfile==0.3.0
# via schemathesis
hf-xet==1.1.3
...
...
@@ -204,7 +289,7 @@ httpx==0.27.2
# via
# -r requirements/test.in
# schemathesis
huggingface-hub==0.33.
0
huggingface-hub==0.33.
1
# via
# -r requirements/test.in
# accelerate
...
...
@@ -212,13 +297,19 @@ huggingface-hub==0.33.0
# evaluate
# open-clip-torch
# peft
# segmentation-models-pytorch
# sentence-transformers
# terratorch
# timm
# tokenizers
# transformers
# vocos
humanize==4.11.0
# via runai-model-streamer
hydra-core==1.3.2
# via
# lightly
# lightning
hypothesis==6.131.0
# via
# hypothesis-graphql
...
...
@@ -236,6 +327,14 @@ idna==3.10
# jsonschema
# requests
# yarl
imageio==2.37.0
# via scikit-image
importlib-metadata==8.7.0
# via
# mlflow-skinny
# opentelemetry-api
importlib-resources==6.5.2
# via typeshed-client
inflect==5.6.2
# via datamodel-code-generator
iniconfig==2.0.0
...
...
@@ -244,9 +343,13 @@ isoduration==20.11.0
# via jsonschema
isort==5.13.2
# via datamodel-code-generator
itsdangerous==2.2.0
# via flask
jinja2==3.1.6
# via
# datamodel-code-generator
# flask
# mlflow
# torch
jiwer==3.0.5
# via -r requirements/test.in
...
...
@@ -259,6 +362,10 @@ joblib==1.4.2
# librosa
# nltk
# scikit-learn
jsonargparse==4.35.0
# via
# lightning
# terratorch
jsonlines==4.0.0
# via lm-eval
jsonpointer==3.0.0
...
...
@@ -277,12 +384,33 @@ kaleido==0.2.1
# via genai-perf
kiwisolver==1.4.7
# via matplotlib
kornia==0.8.1
# via torchgeo
kornia-rs==0.1.9
# via kornia
lazy-loader==0.4
# via librosa
# via
# librosa
# scikit-image
libnacl==2.1.0
# via tensorizer
librosa==0.10.2.post1
# via -r requirements/test.in
lightly==1.5.20
# via
# terratorch
# torchgeo
lightly-utils==0.0.2
# via lightly
lightning==2.5.1.post0
# via
# terratorch
# torchgeo
lightning-utilities==0.14.3
# via
# lightning
# pytorch-lightning
# torchmetrics
llvmlite==0.44.0
# via numba
lm-eval==0.4.8
...
...
@@ -291,16 +419,27 @@ lxml==5.3.0
# via
# blobfile
# sacrebleu
mako==1.3.10
# via alembic
mamba-ssm==2.2.4
# via -r requirements/test.in
markdown==3.8.2
# via mlflow
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.1
# via
# flask
# jinja2
# mako
# werkzeug
matplotlib==3.9.2
# via -r requirements/test.in
# via
# -r requirements/test.in
# lightning
# mlflow
# pycocotools
# torchgeo
mbstrdecoder==1.1.3
# via
# dataproperty
...
...
@@ -310,6 +449,10 @@ mdurl==0.1.2
# via markdown-it-py
mistral-common==1.8.0
# via -r requirements/test.in
mlflow==2.22.0
# via terratorch
mlflow-skinny==2.22.0
# via mlflow
more-itertools==10.5.0
# via lm-eval
mpmath==1.3.0
...
...
@@ -328,10 +471,14 @@ multiprocess==0.70.16
# via
# datasets
# evaluate
munch==4.0.0
# via pretrainedmodels
mypy-extensions==1.0.0
# via black
networkx==3.2.1
# via torch
# via
# scikit-image
# torch
ninja==1.11.1.3
# via mamba-ssm
nltk==3.9.1
...
...
@@ -348,6 +495,8 @@ numpy==1.26.4
# via
# -r requirements/test.in
# accelerate
# albucore
# albumentations
# bitsandbytes
# bm25s
# contourpy
...
...
@@ -358,9 +507,15 @@ numpy==1.26.4
# evaluate
# fastparquet
# genai-perf
# geopandas
# h5py
# imageio
# librosa
# lightly
# lightly-utils
# matplotlib
# mistral-common
# mlflow
# mteb
# numba
# numexpr
...
...
@@ -368,18 +523,30 @@ numpy==1.26.4
# pandas
# patsy
# peft
# pycocotools
# pyogrio
# rasterio
# rioxarray
# rouge-score
# runai-model-streamer
# sacrebleu
# scikit-image
# scikit-learn
# scipy
# segmentation-models-pytorch
# shapely
# soxr
# statsmodels
# tensorboardx
# tensorizer
# tifffile
# torchgeo
# torchmetrics
# torchvision
# transformers
# tritonclient
# vocos
# xarray
nvidia-cublas-cu12==12.8.3.14
# via
# nvidia-cudnn-cu12
...
...
@@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61
# torch
nvidia-nvtx-cu12==12.8.55
# via torch
omegaconf==2.3.0
# via
# hydra-core
# lightning
open-clip-torch==2.32.0
# via -r requirements/test.in
opencensus==0.11.4
...
...
@@ -426,7 +597,18 @@ opencensus-context==0.1.3
opencv-python-headless==4.11.0.86
# via
# -r requirements/test.in
# albucore
# albumentations
# mistral-common
opentelemetry-api==1.35.0
# via
# mlflow-skinny
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-sdk==1.35.0
# via mlflow-skinny
opentelemetry-semantic-conventions==0.56b0
# via opentelemetry-sdk
packaging==24.2
# via
# accelerate
...
...
@@ -435,26 +617,44 @@ packaging==24.2
# datasets
# evaluate
# fastparquet
# geopandas
# gunicorn
# huggingface-hub
# hydra-core
# kornia
# lazy-loader
# lightning
# lightning-utilities
# mamba-ssm
# matplotlib
# mlflow-skinny
# peft
# plotly
# pooch
# pyogrio
# pytest
# pytest-rerunfailures
# pytorch-lightning
# ray
# rioxarray
# scikit-image
# statsmodels
# tensorboardx
# torchmetrics
# transformers
# typepy
# xarray
pandas==2.2.3
# via
# datasets
# evaluate
# fastparquet
# genai-perf
# geopandas
# mlflow
# statsmodels
# torchgeo
# xarray
pathspec==0.12.1
# via black
pathvalidate==3.2.1
...
...
@@ -468,9 +668,14 @@ peft==0.13.2
pillow==10.4.0
# via
# genai-perf
# imageio
# lightly-utils
# matplotlib
# mistral-common
# scikit-image
# segmentation-models-pytorch
# sentence-transformers
# torchgeo
# torchvision
platformdirs==4.3.6
# via
...
...
@@ -489,6 +694,8 @@ portalocker==2.10.1
# via sacrebleu
pqdm==0.2.0
# via -r requirements/test.in
pretrainedmodels==0.7.4
# via segmentation-models-pytorch
prometheus-client==0.22.0
# via ray
propcache==0.2.0
...
...
@@ -499,8 +706,10 @@ protobuf==5.28.3
# via
# google-api-core
# googleapis-common-protos
# mlflow-skinny
# proto-plus
# ray
# tensorboardx
# tensorizer
psutil==6.1.0
# via
...
...
@@ -515,6 +724,7 @@ pyarrow==18.0.0
# via
# datasets
# genai-perf
# mlflow
pyasn1==0.6.1
# via
# pyasn1-modules
...
...
@@ -523,6 +733,8 @@ pyasn1-modules==0.4.2
# via google-auth
pybind11==2.13.6
# via lm-eval
pycocotools==2.0.8
# via terratorch
pycountry==24.6.1
# via pydantic-extra-types
pycparser==2.22
...
...
@@ -532,8 +744,12 @@ pycryptodomex==3.22.0
pydantic==2.11.5
# via
# -r requirements/test.in
# albumentations
# datamodel-code-generator
# fastapi
# lightly
# mistral-common
# mlflow-skinny
# mteb
# pydantic-extra-types
# ray
...
...
@@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5
# via mistral-common
pygments==2.18.0
# via rich
pyogrio==0.11.0
# via geopandas
pyparsing==3.2.0
# via matplotlib
# via
# matplotlib
# rasterio
pyproj==3.7.1
# via
# geopandas
# rioxarray
# torchgeo
pyrate-limiter==3.7.0
# via schemathesis
pystemmer==3.0.0
# via mteb
pytablewriter==1.2.0
# via lm-eval
pytest==8.3.
3
pytest==8.3.
5
# via
# -r requirements/test.in
# buildkite-test-collector
...
...
@@ -564,6 +789,7 @@ pytest==8.3.3
# pytest-subtests
# pytest-timeout
# schemathesis
# terratorch
pytest-asyncio==0.24.0
# via -r requirements/test.in
pytest-forked==1.6.0
...
...
@@ -578,15 +804,23 @@ pytest-subtests==0.14.1
# via schemathesis
pytest-timeout==2.3.1
# via -r requirements/test.in
python-box==7.3.2
# via terratorch
python-dateutil==2.9.0.post0
# via
# arrow
# botocore
# graphene
# lightly
# matplotlib
# pandas
# typepy
python-rapidjson==1.20
# via tritonclient
pytorch-lightning==2.5.2
# via
# lightly
# lightning
pytrec-eval-terrier==0.5.7
# via mteb
pytz==2024.2
...
...
@@ -596,11 +830,17 @@ pytz==2024.2
pyyaml==6.0.2
# via
# accelerate
# albumentations
# datamodel-code-generator
# datasets
# genai-perf
# huggingface-hub
# jsonargparse
# lightning
# mlflow-skinny
# omegaconf
# peft
# pytorch-lightning
# ray
# responses
# schemathesis
...
...
@@ -609,6 +849,11 @@ pyyaml==6.0.2
# vocos
rapidfuzz==3.12.1
# via jiwer
rasterio==1.4.3
# via
# rioxarray
# terratorch
# torchgeo
ray==2.43.0
# via -r requirements/test.in
redis==5.2.0
...
...
@@ -627,12 +872,16 @@ regex==2024.9.11
requests==2.32.3
# via
# buildkite-test-collector
# databricks-sdk
# datasets
# docker
# evaluate
# google-api-core
# huggingface-hub
# lightly
# lm-eval
# mistral-common
# mlflow-skinny
# mteb
# pooch
# ray
...
...
@@ -650,8 +899,11 @@ rfc3987==1.3.8
rich==13.9.4
# via
# genai-perf
# lightning
# mteb
# typer
rioxarray==0.19.0
# via terratorch
rouge-score==0.1.2
# via lm-eval
rpds-py==0.20.1
...
...
@@ -660,6 +912,8 @@ rpds-py==0.20.1
# referencing
rsa==4.9.1
# via google-auth
rtree==1.4.0
# via torchgeo
runai-model-streamer==0.11.0
# via -r requirements/test.in
runai-model-streamer-s3==0.11.0
...
...
@@ -677,21 +931,32 @@ safetensors==0.4.5
# transformers
schemathesis==3.39.15
# via -r requirements/test.in
scikit-image==0.25.2
# via albumentations
scikit-learn==1.5.2
# via
# albumentations
# librosa
# lm-eval
# mlflow
# mteb
# sentence-transformers
scipy==1.13.1
# via
# albumentations
# bm25s
# librosa
# mlflow
# mteb
# scikit-image
# scikit-learn
# sentence-transformers
# statsmodels
# vocos
segmentation-models-pytorch==0.4.0
# via
# terratorch
# torchgeo
sentence-transformers==3.2.1
# via
# -r requirements/test.in
...
...
@@ -700,21 +965,30 @@ sentencepiece==0.2.0
# via mistral-common
setuptools==77.0.3
# via
# lightning-utilities
# mamba-ssm
# pytablewriter
# torch
# triton
shapely==2.1.1
# via
# geopandas
# torchgeo
shellingham==1.5.4
# via typer
six==1.16.0
# via
# junit-xml
# lightly
# opencensus
# python-dateutil
# rfc3339-validator
# rouge-score
# segmentation-models-pytorch
smart-open==7.1.0
# via ray
smmap==5.0.2
# via gitdb
sniffio==1.3.1
# via
# anyio
...
...
@@ -727,10 +1001,17 @@ soundfile==0.12.1
# librosa
soxr==0.5.0.post1
# via librosa
sqlalchemy==2.0.41
# via
# alembic
# mlflow
sqlitedict==2.1.0
# via lm-eval
sqlparse==0.5.3
# via mlflow-skinny
starlette==0.46.2
# via
# fastapi
# schemathesis
# starlette-testclient
starlette-testclient==0.4.1
...
...
@@ -751,18 +1032,29 @@ tenacity==9.0.0
# via
# lm-eval
# plotly
tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
# via -r requirements/test.in
terratorch==1.1rc2
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn
tifffile==2025.3.30
# via
# scikit-image
# terratorch
tiktoken==0.7.0
# via
# lm-eval
# mistral-common
timm==1.0.1
1
timm==1.0.1
5
# via
# -r requirements/test.in
# open-clip-torch
# segmentation-models-pytorch
# terratorch
# torchgeo
tokenizers==0.21.1
# via
# -r requirements/test.in
...
...
@@ -776,18 +1068,28 @@ torch==2.7.1+cu128
# -r requirements/test.in
# accelerate
# bitsandbytes
# efficientnet-pytorch
# encodec
# fastsafetensors
# kornia
# lightly
# lightning
# lm-eval
# mamba-ssm
# mteb
# open-clip-torch
# peft
# pretrainedmodels
# pytorch-lightning
# runai-model-streamer
# segmentation-models-pytorch
# sentence-transformers
# tensorizer
# terratorch
# timm
# torchaudio
# torchgeo
# torchmetrics
# torchvision
# vector-quantize-pytorch
# vocos
...
...
@@ -796,22 +1098,40 @@ torchaudio==2.7.1+cu128
# -r requirements/test.in
# encodec
# vocos
torchgeo==0.7.0
# via terratorch
torchmetrics==1.7.4
# via
# lightning
# pytorch-lightning
# terratorch
# torchgeo
torchvision==0.22.1+cu128
# via
# -r requirements/test.in
# lightly
# open-clip-torch
# pretrainedmodels
# segmentation-models-pytorch
# terratorch
# timm
# torchgeo
tqdm==4.66.6
# via
# datasets
# evaluate
# huggingface-hub
# lightly
# lightning
# lm-eval
# mteb
# nltk
# open-clip-torch
# peft
# pqdm
# pretrainedmodels
# pytorch-lightning
# segmentation-models-pytorch
# sentence-transformers
# tqdm-multiprocess
# transformers
...
...
@@ -843,18 +1163,34 @@ typer==0.15.2
# via fastsafetensors
types-python-dateutil==2.9.0.20241206
# via arrow
typeshed-client==2.8.2
# via jsonargparse
typing-extensions==4.12.2
# via
# albumentations
# alembic
# fastapi
# graphene
# huggingface-hub
# librosa
# lightning
# lightning-utilities
# mistral-common
# mlflow-skinny
# mteb
# opentelemetry-api
# opentelemetry-sdk
# opentelemetry-semantic-conventions
# pqdm
# pydantic
# pydantic-core
# pydantic-extra-types
# pytorch-lightning
# sqlalchemy
# torch
# torchgeo
# typer
# typeshed-client
# typing-inspection
typing-inspection==0.4.1
# via pydantic
...
...
@@ -866,9 +1202,13 @@ urllib3==2.2.3
# via
# blobfile
# botocore
# docker
# lightly
# requests
# responses
# tritonclient
uvicorn==0.35.0
# via mlflow-skinny
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2
...
...
@@ -880,11 +1220,15 @@ wcwidth==0.2.13
webcolors==24.11.1
# via jsonschema
werkzeug==3.1.3
# via schemathesis
# via
# flask
# schemathesis
word2number==1.1
# via lm-eval
wrapt==1.17.2
# via smart-open
xarray==2025.7.1
# via rioxarray
xxhash==3.5.0
# via
# datasets
...
...
@@ -893,5 +1237,7 @@ yarl==1.17.1
# via
# aiohttp
# schemathesis
zipp==3.23.0
# via importlib-metadata
zstandard==0.23.0
# via lm-eval
tests/models/multimodal/pooling/test_prithvi_mae.py
0 → 100644
View file @
8560a5b2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.utils
import
set_default_torch_num_threads
from
....conftest
import
VllmRunner
def
generate_test_mm_data
():
mm_data
=
{
"pixel_values"
:
torch
.
full
((
6
,
512
,
512
),
1.0
,
dtype
=
torch
.
float16
),
"location_coords"
:
torch
.
full
((
1
,
2
),
1.0
,
dtype
=
torch
.
float16
),
}
return
mm_data
def
_run_test
(
vllm_runner
:
type
[
VllmRunner
],
model
:
str
,
)
->
None
:
prompt
=
[
{
# This model deals with no text input
"prompt_token_ids"
:
[
1
],
"multi_modal_data"
:
generate_test_mm_data
(),
}
for
_
in
range
(
10
)
]
with
(
set_default_torch_num_threads
(
1
),
vllm_runner
(
model
,
task
=
"embed"
,
dtype
=
torch
.
float16
,
enforce_eager
=
True
,
skip_tokenizer_init
=
True
,
# Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run
max_num_seqs
=
32
,
)
as
vllm_model
,
):
vllm_model
.
encode
(
prompt
)
MODELS
=
[
"christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
]
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
def
test_models_image
(
hf_runner
,
vllm_runner
,
image_assets
,
model
:
str
,
)
->
None
:
_run_test
(
vllm_runner
,
model
,
)
vllm/config.py
View file @
8560a5b2
...
...
@@ -651,6 +651,8 @@ class ModelConfig:
self
.
original_max_model_len
=
self
.
max_model_len
self
.
max_model_len
=
self
.
get_and_verify_max_len
(
self
.
max_model_len
)
self
.
multimodal_config
=
self
.
_init_multimodal_config
()
self
.
model_supports_multimodal_raw_input
=
(
self
.
registry
.
supports_multimodal_raw_input
(
self
.
architectures
))
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
...
...
@@ -1243,10 +1245,10 @@ class ModelConfig:
return
self
.
get_hf_config_sliding_window
()
def
get_vocab_size
(
self
)
->
int
:
return
self
.
hf_text_config
.
vocab_size
return
getattr
(
self
.
hf_text_config
,
"
vocab_size
"
,
0
)
def
get_hidden_size
(
self
)
->
int
:
return
self
.
hf_text_config
.
hidden_size
return
getattr
(
self
.
hf_text_config
,
"
hidden_size
"
,
0
)
@
property
def
is_deepseek_mla
(
self
)
->
bool
:
...
...
vllm/engine/llm_engine.py
View file @
8560a5b2
...
...
@@ -238,14 +238,14 @@ class LLMEngine:
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
else
:
if
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
None
self
.
detokenizer
=
None
tokenizer_group
=
None
else
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
...
...
vllm/model_executor/models/interfaces.py
View file @
8560a5b2
...
...
@@ -136,6 +136,40 @@ def supports_multimodal(
return
getattr
(
model
,
"supports_multimodal"
,
False
)
@
runtime_checkable
class
SupportsMultiModalWithRawInput
(
SupportsMultiModal
,
Protocol
):
"""The interface required for all multi-modal models."""
supports_multimodal_raw_input
:
ClassVar
[
Literal
[
True
]]
=
True
"""
A flag that indicates this model supports multi-modal inputs and processes
them in their raw form and not embeddings.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
@
overload
def
supports_multimodal_raw_input
(
model
:
object
)
->
TypeIs
[
SupportsMultiModalWithRawInput
]:
...
@
overload
def
supports_multimodal_raw_input
(
model
:
type
[
object
])
->
TypeIs
[
type
[
SupportsMultiModalWithRawInput
]]:
...
def
supports_multimodal_raw_input
(
model
:
Union
[
type
[
object
],
object
]
)
->
Union
[
TypeIs
[
type
[
SupportsMultiModalWithRawInput
]],
TypeIs
[
SupportsMultiModalWithRawInput
]]:
return
getattr
(
model
,
"supports_multimodal_raw_input"
,
False
)
@
runtime_checkable
class
SupportsScoreTemplate
(
Protocol
):
"""The interface required for all models that support score template."""
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
8560a5b2
...
...
@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Optional
,
Union
...
...
@@ -27,13 +28,14 @@ from vllm.config import VllmConfig
from
vllm.model_executor.layers.pooler
import
(
AllPool
,
PoolerHead
,
PoolerIdentity
,
SimplePooler
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
SupportsMultiModal
,
SupportsV0Only
)
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
MultiModalEmbeddings
,
SupportsMultiModalWithRawInput
)
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
)
MultiModalFieldElem
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalSharedField
,
PlaceholderRange
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptUpdate
)
...
...
@@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
return
{
"pixel_values"
:
torch
.
full
((
1
,
6
,
512
,
512
),
1.0
),
"location_coords"
:
torch
.
full
((
1
,
2
),
1.0
),
"pixel_values"
:
torch
.
full
((
6
,
512
,
512
),
1.0
,
dtype
=
torch
.
float16
),
"location_coords"
:
torch
.
full
((
1
,
2
),
1.0
,
dtype
=
torch
.
float16
),
}
...
...
@@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
location_coords
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values
=
MultiModalFieldConfig
.
shared
(
batch_size
=
1
,
modality
=
"image"
),
location_coords
=
MultiModalFieldConfig
.
shared
(
batch_size
=
1
,
modality
=
"image"
),
)
def
_get_prompt_updates
(
...
...
@@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
for
k
,
v
in
mm_data
.
items
():
mm_kwargs
[
k
]
=
v
mm_placeholders
=
{
"image"
:
[
PlaceholderRange
(
offset
=
0
,
length
=
0
)]}
# This model receives in input a multi-dimensional tensor representing
# a single image patch and therefore it is not to be split
# into multiple elements, but rather to be considered a single one.
# Hence, the decision of using a MultiModalSharedField.
# The expected shape is (num_channels, width, height).
# This model however allows the user to also submit multiple image
# patches as a batch, adding a further dimension to the above shape.
# At this stage we only support submitting one patch per request and
# batching is achieved via vLLM batching.
# TODO (christian-pinto): enable support for multi patch requests
# in tandem with vLLM batching.
multimodal_kwargs_items
=
[
MultiModalKwargsItem
.
from_elems
([
MultiModalFieldElem
(
modality
=
"image"
,
key
=
key
,
data
=
data
,
field
=
MultiModalSharedField
(
1
),
)
for
key
,
data
in
mm_kwargs
.
items
()
])
]
return
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
prompt
,
prompt_token_ids
=
[
1
],
mm_kwargs
=
MultiModalKwargs
(
mm_kwarg
s
),
mm_kwargs
=
MultiModalKwargs
.
from_items
(
multimodal_kwargs_item
s
),
mm_hashes
=
None
,
mm_placeholders
=
{}
,
mm_placeholders
=
mm_placeholders
,
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
PrithviGeoSpatialMAEMultiModalProcessor
,
info
=
PrithviGeoSpatialMAEProcessingInfo
,
dummy_inputs
=
PrithviGeoSpatialMAEInputBuilder
)
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
,
SupportsV0Only
):
dummy_inputs
=
PrithviGeoSpatialMAEInputBuilder
,
)
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModalWithRawInput
):
"""Prithvi Masked Autoencoder"""
is_pooling_model
=
True
...
...
@@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
raise
ValueError
(
"Only image modality is supported"
)
def
_instantiate_model
(
self
,
config
:
dict
)
->
Optional
[
nn
.
Module
]:
# We might be able/need to support different tasks with this same model
if
config
[
"task_args"
][
"task"
]
==
"SemanticSegmentationTask"
:
from
terratorch.cli_tools
import
SemanticSegmentationTask
task
=
SemanticSegmentationTask
(
config
[
"model_args"
],
config
[
"task_args"
][
"model_factory"
],
...
...
@@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
scheduler_hparams
=
config
[
"scheduler_params"
],
plot_on_val
=
config
[
"task_args"
][
"plot_on_val"
],
freeze_decoder
=
config
[
"task_args"
][
"freeze_decoder"
],
freeze_backbone
=
config
[
"task_args"
][
"freeze_backbone"
])
freeze_backbone
=
config
[
"task_args"
][
"freeze_backbone"
],
)
return
task
.
model
else
:
...
...
@@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
def
_parse_and_validate_multimodal_data
(
self
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
f
"Incorrect type of pixel_values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
pixel_values
=
torch
.
unbind
(
pixel_values
,
dim
=
0
)[
0
]
location_coords
=
kwargs
.
pop
(
"location_coords"
,
None
)
if
not
isinstance
(
location_coords
,
torch
.
Tensor
):
...
...
@@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return
pixel_values
,
location_coords
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
# We do not really use any input tokens and therefore no embeddings
# to be calculated. However, due to the mandatory token ids in
# the input prompt we pass one token and the size of the dummy
# embedding tensors must reflect that.
return
torch
.
empty
((
input_ids
.
shape
[
0
],
0
))
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
...
...
vllm/model_executor/models/registry.py
View file @
8560a5b2
...
...
@@ -22,8 +22,8 @@ from vllm.logger import init_logger
from
.interfaces
import
(
has_inner_state
,
has_noops
,
is_attention_free
,
is_hybrid
,
supports_cross_encoding
,
supports_multimodal
,
supports_
pp
,
supports_transcription
,
supports_v0_only
)
supports_multimodal
,
supports_
multimodal_raw_input
,
supports_pp
,
supports_transcription
,
supports_v0_only
)
from
.interfaces_base
import
is_text_generation_model
logger
=
init_logger
(
__name__
)
...
...
@@ -287,6 +287,7 @@ class _ModelInfo:
is_pooling_model
:
bool
supports_cross_encoding
:
bool
supports_multimodal
:
bool
supports_multimodal_raw_input
:
bool
supports_pp
:
bool
has_inner_state
:
bool
is_attention_free
:
bool
...
...
@@ -304,6 +305,7 @@ class _ModelInfo:
is_pooling_model
=
True
,
# Can convert any model into a pooling model
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_multimodal_raw_input
=
supports_multimodal_raw_input
(
model
),
supports_pp
=
supports_pp
(
model
),
has_inner_state
=
has_inner_state
(
model
),
is_attention_free
=
is_attention_free
(
model
),
...
...
@@ -573,6 +575,13 @@ class _ModelRegistry:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
supports_multimodal
def
supports_multimodal_raw_input
(
self
,
architectures
:
Union
[
str
,
list
[
str
]],
)
->
bool
:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
supports_multimodal_raw_input
def
is_pp_supported_model
(
self
,
architectures
:
Union
[
str
,
list
[
str
]],
...
...
vllm/multimodal/registry.py
View file @
8560a5b2
...
...
@@ -266,7 +266,7 @@ class MultiModalRegistry:
if
not
model_config
.
is_multimodal_model
:
raise
ValueError
(
f
"
{
model_config
.
model
}
is not a multimodal model"
)
if
tokenizer
is
None
:
if
tokenizer
is
None
and
not
model_config
.
skip_tokenizer_init
:
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
if
disable_cache
is
None
:
mm_config
=
model_config
.
get_multimodal_config
()
...
...
vllm/v1/engine/async_llm.py
View file @
8560a5b2
...
...
@@ -94,11 +94,14 @@ class AsyncLLM(EngineClient):
self
.
log_requests
=
log_requests
self
.
log_stats
=
log_stats
# Tokenizer (+ ensure liveness if running in another process).
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
if
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
None
else
:
# Tokenizer (+ ensure liveness if running in another process).
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
# Processor (converts Inputs --> EngineCoreRequests).
self
.
processor
=
Processor
(
...
...
@@ -525,6 +528,10 @@ class AsyncLLM(EngineClient):
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
"Unable to get tokenizer because "
"skip_tokenizer_init is True"
)
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
async
def
is_tracing_enabled
(
self
)
->
bool
:
...
...
vllm/v1/engine/llm_engine.py
View file @
8560a5b2
...
...
@@ -82,11 +82,14 @@ class LLMEngine:
self
.
dp_group
=
None
self
.
should_execute_dummy_batch
=
False
# Tokenizer (+ ensure liveness if running in another process).
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
if
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
None
else
:
# Tokenizer (+ ensure liveness if running in another process).
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
# Processor (convert Inputs --> EngineCoreRequests)
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
...
...
vllm/v1/engine/output_processor.py
View file @
8560a5b2
...
...
@@ -327,14 +327,16 @@ class OutputProcessor:
if
request_id
in
self
.
request_states
:
raise
ValueError
(
f
"Request id
{
request_id
}
already running."
)
req_state
=
RequestState
.
from_new_request
(
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
request
.
lora_request
),
request
=
request
,
prompt
=
prompt
,
parent_req
=
parent_req
,
request_index
=
request_index
,
queue
=
queue
,
log_stats
=
self
.
log_stats
)
tokenizer
=
None
if
not
self
.
tokenizer
else
\
self
.
tokenizer
.
get_lora_tokenizer
(
request
.
lora_request
)
req_state
=
RequestState
.
from_new_request
(
tokenizer
=
tokenizer
,
request
=
request
,
prompt
=
prompt
,
parent_req
=
parent_req
,
request_index
=
request_index
,
queue
=
queue
,
log_stats
=
self
.
log_stats
)
self
.
request_states
[
request_id
]
=
req_state
self
.
lora_states
.
add_request
(
req_state
)
if
parent_req
:
...
...
vllm/v1/engine/processor.py
View file @
8560a5b2
...
...
@@ -380,7 +380,6 @@ class Processor:
prompt_type
:
Literal
[
"encoder"
,
"decoder"
],
):
model_config
=
self
.
model_config
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
if
not
prompt_ids
:
...
...
@@ -389,9 +388,14 @@ class Processor:
else
:
raise
ValueError
(
f
"The
{
prompt_type
}
prompt cannot be empty"
)
max_input_id
=
max
(
prompt_ids
,
default
=
0
)
if
max_input_id
>
tokenizer
.
max_token_id
:
raise
ValueError
(
f
"Token id
{
max_input_id
}
is out of vocabulary"
)
if
self
.
model_config
.
skip_tokenizer_init
:
tokenizer
=
None
else
:
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
max_input_id
=
max
(
prompt_ids
,
default
=
0
)
if
max_input_id
>
tokenizer
.
max_token_id
:
raise
ValueError
(
f
"Token id
{
max_input_id
}
is out of vocabulary"
)
max_prompt_len
=
self
.
model_config
.
max_model_len
if
len
(
prompt_ids
)
>
max_prompt_len
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
8560a5b2
...
...
@@ -126,6 +126,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
is_pooling_model
=
model_config
.
pooler_config
is
not
None
self
.
model_supports_multimodal_raw_input
=
(
model_config
.
model_supports_multimodal_raw_input
)
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
...
...
@@ -328,6 +330,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Args:
scheduler_output: The scheduler output.
"""
# Attention free models have zero kv_cache_goups, however models
# like Mamba are also attention free but use the kv_cache for
# keeping its internal state. This is why we check the number
# of kv_cache groups instead of solely checking
# for self.model_config.is_attention_free.
if
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
0
:
return
self
.
attn_metadata_builders
[
0
].
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
...
...
@@ -565,6 +575,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates.
self
.
input_batch
.
refresh_metadata
()
def
_init_model_kwargs_for_multimodal_model
(
self
,
scheduler_output
:
Optional
[
"SchedulerOutput"
]
=
None
,
num_reqs
:
int
=
-
1
,
)
->
dict
[
str
,
Any
]:
model_kwargs
:
dict
[
str
,
Any
]
=
{}
if
self
.
model_supports_multimodal_raw_input
:
# This model requires the raw multimodal data in input.
if
scheduler_output
:
multi_modal_kwargs_list
=
[]
for
req
in
scheduler_output
.
scheduled_new_reqs
:
req_mm_inputs
=
req
.
mm_inputs
if
not
isinstance
(
req_mm_inputs
,
list
):
req_mm_inputs
=
list
(
req_mm_inputs
)
multi_modal_kwargs_list
.
extend
(
req_mm_inputs
)
multi_modal_kwargs
=
MultiModalKwargs
.
batch
(
multi_modal_kwargs_list
)
else
:
# The only case where SchedulerOutput is None is for
# a dummy run let's get some dummy data.
dummy_data
=
[
self
.
mm_registry
.
get_decoder_dummy_data
(
model_config
=
self
.
model_config
,
seq_len
=
1
).
multi_modal_data
for
i
in
range
(
num_reqs
)
]
multi_modal_kwargs
=
MultiModalKwargs
.
batch
(
dummy_data
)
model_kwargs
.
update
(
multi_modal_kwargs
)
return
model_kwargs
def
_get_cumsum_and_arange
(
self
,
num_tokens
:
np
.
ndarray
,
...
...
@@ -1359,10 +1401,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
model_kwargs
=
self
.
_init_model_kwargs_for_multimodal_model
(
scheduler_output
=
scheduler_output
)
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
=
input_ids
,
multimodal_embeddings
=
mm_embeds
or
None
,
)
# TODO(woosuk): Avoid the copy. Optimize.
self
.
inputs_embeds
[:
num_scheduled_tokens
].
copy_
(
inputs_embeds
)
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
...
...
@@ -1374,6 +1420,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
model_kwargs
=
{}
if
self
.
uses_mrope
:
positions
=
self
.
mrope_positions
[:,
:
num_input_tokens
]
else
:
...
...
@@ -1406,6 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
**
MultiModalKwargs
.
as_kwargs
(
model_kwargs
,
device
=
self
.
device
,
),
)
self
.
maybe_wait_for_kv_save
()
...
...
@@ -2084,11 +2135,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens
):
model
=
self
.
model
if
self
.
is_multimodal_model
:
model_kwargs
=
self
.
_init_model_kwargs_for_multimodal_model
(
num_reqs
=
num_reqs
)
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
else
:
input_ids
=
self
.
input_ids
[:
num_tokens
]
inputs_embeds
=
None
model_kwargs
=
{}
if
self
.
uses_mrope
:
positions
=
self
.
mrope_positions
[:,
:
num_tokens
]
else
:
...
...
@@ -2117,7 +2172,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
**
MultiModalKwargs
.
as_kwargs
(
model_kwargs
,
device
=
self
.
device
,
),
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
_
=
outputs
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment