Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8560a5b2
Unverified
Commit
8560a5b2
authored
Jul 23, 2025
by
Christian Pinto
Committed by
GitHub
Jul 23, 2025
Browse files
[Core][Model] PrithviMAE Enablement on vLLM v1 engine (#20577)
Signed-off-by:
Christian Pinto
<
christian.pinto@ibm.com
>
parent
316b1bf7
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
704 additions
and
238 deletions
+704
-238
examples/offline_inference/prithvi_geospatial_mae.py
examples/offline_inference/prithvi_geospatial_mae.py
+70
-175
requirements/test.in
requirements/test.in
+1
-0
requirements/test.txt
requirements/test.txt
+360
-14
tests/models/multimodal/pooling/test_prithvi_mae.py
tests/models/multimodal/pooling/test_prithvi_mae.py
+63
-0
vllm/config.py
vllm/config.py
+4
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+5
-5
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+34
-0
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+57
-17
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+11
-2
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+1
-1
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+12
-5
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+8
-5
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+10
-8
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+8
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+60
-0
No files found.
examples/offline_inference/prithvi_geospatial_mae.py
View file @
8560a5b2
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
The requirements for running this script are:
- Installing [terratorch, albumentations, rasterio] in your python environment
- downloading the model weights in a 'model' folder local to the script
(temporary measure until the proper config.json file is uploaded to HF)
- download an input example image (India_900498_S2Hand.tif) and place it in
the same folder with the script (or specify with the --data_file argument)
Run the example:
python prithvi_geospatial_mae.py
"""
# noqa: E501
import
argparse
import
argparse
import
datetime
import
datetime
import
os
import
os
import
re
from
typing
import
Union
from
typing
import
Union
import
albumentations
import
albumentations
import
numpy
as
np
import
numpy
as
np
import
rasterio
import
rasterio
import
regex
as
re
import
torch
import
torch
from
einops
import
rearrange
from
einops
import
rearrange
from
terratorch.datamodules
import
Sen1Floods11NonGeoDataModule
from
terratorch.datamodules
import
Sen1Floods11NonGeoDataModule
from
vllm
import
LLM
from
vllm
import
LLM
torch
.
set_default_dtype
(
torch
.
float16
)
NO_DATA
=
-
9999
NO_DATA
=
-
9999
NO_DATA_FLOAT
=
0.0001
NO_DATA_FLOAT
=
0.0001
OFFSET
=
0
OFFSET
=
0
PERCENTILE
=
99
PERCENTILE
=
99
model_config
=
"""{
"architectures": ["PrithviGeoSpatialMAE"],
"num_classes": 0,
"pretrained_cfg": {
"task_args": {
"task": "SemanticSegmentationTask",
"model_factory": "EncoderDecoderFactory",
"loss": "ce",
"ignore_index": -1,
"lr": 0.001,
"freeze_backbone": false,
"freeze_decoder": false,
"plot_on_val": 10,
"optimizer": "AdamW",
"scheduler": "CosineAnnealingLR"
},
"model_args": {
"backbone_pretrained": false,
"backbone": "prithvi_eo_v2_300_tl",
"decoder": "UperNetDecoder",
"decoder_channels": 256,
"decoder_scale_modules": true,
"num_classes": 2,
"rescale": true,
"backbone_bands": [
"BLUE",
"GREEN",
"RED",
"NIR_NARROW",
"SWIR_1",
"SWIR_2"
],
"head_dropout": 0.1,
"necks": [
{
"name": "SelectIndices",
"indices": [
5,
11,
17,
23
]
},
{
"name": "ReshapeTokensToImage"
}
]
},
"optimizer_params" : {
"lr": 5.0e-05,
"betas": [0.9, 0.999],
"eps": [1.0e-08],
"weight_decay": 0.05,
"amsgrad": false,
"maximize": false,
"capturable": false,
"differentiable": false
},
"scheduler_params" : {
"T_max": 50,
"eta_min": 0,
"last_epoch": -1,
"verbose": "deprecated"
}
},
"torch_dtype": "float32"
}
"""
# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"./model/config.json"
),
"w"
)
as
config_file
:
config_file
.
write
(
model_config
)
datamodule_config
=
{
datamodule_config
=
{
"bands"
:
[
"BLUE"
,
"GREEN"
,
"RED"
,
"NIR_NARROW"
,
"SWIR_1"
,
"SWIR_2"
],
"bands"
:
[
"BLUE"
,
"GREEN"
,
"RED"
,
"NIR_NARROW"
,
"SWIR_1"
,
"SWIR_2"
],
"batch_size"
:
16
,
"batch_size"
:
16
,
...
@@ -138,28 +43,24 @@ datamodule_config = {
...
@@ -138,28 +43,24 @@ datamodule_config = {
class
PrithviMAE
:
class
PrithviMAE
:
def
__init__
(
self
):
def
__init__
(
self
,
model
):
print
(
"Initializing PrithviMAE model"
)
self
.
model
=
LLM
(
self
.
llm
=
LLM
(
model
=
model
,
skip_tokenizer_init
=
True
,
dtype
=
"float16"
,
enforce_eager
=
True
model
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"./model"
),
skip_tokenizer_init
=
True
,
dtype
=
"float32"
,
)
)
def
run
(
self
,
input_data
,
location_coords
):
def
run
(
self
,
input_data
,
location_coords
):
print
(
"################ Running inference on vLLM ##############"
)
# merge the inputs into one data structure
# merge the inputs into one data structure
if
input_data
is
not
None
and
input_data
.
dtype
==
torch
.
float32
:
input_data
=
input_data
.
to
(
torch
.
float16
)
input_data
=
input_data
[
0
]
mm_data
=
{
mm_data
=
{
"pixel_values"
:
torch
.
empty
(
0
)
if
input_data
is
None
else
input_data
,
"pixel_values"
:
input_data
,
"location_coords"
:
torch
.
empty
(
0
)
"location_coords"
:
location_coords
,
if
location_coords
is
None
else
location_coords
,
}
}
prompt
=
{
"prompt_token_ids"
:
[
1
],
"multi_modal_data"
:
mm_data
}
prompt
=
{
"prompt_token_ids"
:
[
1
],
"multi_modal_data"
:
mm_data
}
outputs
=
self
.
model
.
encode
(
prompt
,
use_tqdm
=
False
)
outputs
=
self
.
llm
.
encode
(
prompt
,
use_tqdm
=
False
)
print
(
"################ Inference done (it took seconds) ##############"
)
return
outputs
[
0
].
outputs
.
data
return
outputs
[
0
].
outputs
.
data
...
@@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
...
@@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
"""
"""
Args:
Args:
orig_img: torch.Tensor representing original image (reference)
orig_img: torch.Tensor representing original image (reference)
with shape = (bands, H, W).
with shape = (bands, H, W).
channels: list of indices representing RGB channels.
channels: list of indices representing RGB channels.
Returns:
Returns:
torch.Tensor with shape (num_channels, height, width) for original image
torch.Tensor with shape (num_channels, height, width)
for original image
"""
"""
orig_img
=
orig_img
[
channels
,
...]
orig_img
=
orig_img
[
channels
,
...]
...
@@ -260,10 +162,10 @@ def load_example(
...
@@ -260,10 +162,10 @@ def load_example(
Args:
Args:
file_paths: list of file paths .
file_paths: list of file paths .
mean: list containing mean values for each band in the
images
mean: list containing mean values for each band in the
in *file_paths*.
images
in *file_paths*.
std: list containing std values for each band in the
images
std: list containing std values for each band in the
in *file_paths*.
images
in *file_paths*.
Returns:
Returns:
np.array containing created example
np.array containing created example
...
@@ -308,7 +210,7 @@ def load_example(
...
@@ -308,7 +210,7 @@ def load_example(
print
(
f
"Could not extract timestamp for
{
file
}
(
{
e
}
)"
)
print
(
f
"Could not extract timestamp for
{
file
}
(
{
e
}
)"
)
imgs
=
np
.
stack
(
imgs
,
axis
=
0
)
# num_frames, H, W, C
imgs
=
np
.
stack
(
imgs
,
axis
=
0
)
# num_frames, H, W, C
imgs
=
np
.
moveaxis
(
imgs
,
-
1
,
0
).
astype
(
"float32"
)
imgs
=
np
.
moveaxis
(
imgs
,
-
1
,
0
).
astype
(
"float32"
)
# C, num_frames, H, W
imgs
=
np
.
expand_dims
(
imgs
,
axis
=
0
)
# add batch di
imgs
=
np
.
expand_dims
(
imgs
,
axis
=
0
)
# add batch di
return
imgs
,
temporal_coords
,
location_coords
,
metas
return
imgs
,
temporal_coords
,
location_coords
,
metas
...
@@ -332,8 +234,10 @@ def run_model(
...
@@ -332,8 +234,10 @@ def run_model(
)
)
# Build sliding window
# Build sliding window
batch_size
=
1
batch_size
=
1
batch
=
torch
.
tensor
(
input_data
,
device
=
"cpu"
)
# batch = torch.tensor(input_data, device="cpu")
batch
=
torch
.
tensor
(
input_data
)
windows
=
batch
.
unfold
(
3
,
img_size
,
img_size
).
unfold
(
4
,
img_size
,
img_size
)
windows
=
batch
.
unfold
(
3
,
img_size
,
img_size
).
unfold
(
4
,
img_size
,
img_size
)
h1
,
w1
=
windows
.
shape
[
3
:
5
]
h1
,
w1
=
windows
.
shape
[
3
:
5
]
windows
=
rearrange
(
windows
=
rearrange
(
...
@@ -344,18 +248,16 @@ def run_model(
...
@@ -344,18 +248,16 @@ def run_model(
num_batches
=
windows
.
shape
[
0
]
//
batch_size
if
windows
.
shape
[
0
]
>
batch_size
else
1
num_batches
=
windows
.
shape
[
0
]
//
batch_size
if
windows
.
shape
[
0
]
>
batch_size
else
1
windows
=
torch
.
tensor_split
(
windows
,
num_batches
,
dim
=
0
)
windows
=
torch
.
tensor_split
(
windows
,
num_batches
,
dim
=
0
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
if
temporal_coords
:
if
temporal_coords
:
temporal_coords
=
torch
.
tensor
(
temporal_coords
,
device
=
device
).
unsqueeze
(
0
)
temporal_coords
=
torch
.
tensor
(
temporal_coords
).
unsqueeze
(
0
)
else
:
else
:
temporal_coords
=
None
temporal_coords
=
None
if
location_coords
:
if
location_coords
:
location_coords
=
torch
.
tensor
(
location_coords
[
0
]
,
device
=
device
).
unsqueeze
(
0
)
location_coords
=
torch
.
tensor
(
location_coords
[
0
]).
unsqueeze
(
0
)
else
:
else
:
location_coords
=
None
location_coords
=
None
# Run
model
# Run
Prithvi-EO-V2-300M-TL-Sen1Floods11
pred_imgs
=
[]
pred_imgs
=
[]
for
x
in
windows
:
for
x
in
windows
:
# Apply standardization
# Apply standardization
...
@@ -363,15 +265,7 @@ def run_model(
...
@@ -363,15 +265,7 @@ def run_model(
x
=
datamodule
.
aug
(
x
)[
"image"
]
x
=
datamodule
.
aug
(
x
)[
"image"
]
with
torch
.
no_grad
():
with
torch
.
no_grad
():
x
=
x
.
to
(
device
)
pred
=
model
.
run
(
x
,
location_coords
=
location_coords
)
pred
=
model
.
run
(
x
,
location_coords
=
location_coords
)
if
lightning_model
:
pred_lightning
=
lightning_model
(
x
,
temporal_coords
=
temporal_coords
,
location_coords
=
location_coords
)
pred_lightning
=
pred_lightning
.
output
.
detach
().
cpu
()
if
not
torch
.
equal
(
pred
,
pred_lightning
):
print
(
"Inference output is not equal"
)
y_hat
=
pred
.
argmax
(
dim
=
1
)
y_hat
=
pred
.
argmax
(
dim
=
1
)
y_hat
=
torch
.
nn
.
functional
.
interpolate
(
y_hat
=
torch
.
nn
.
functional
.
interpolate
(
...
@@ -403,52 +297,18 @@ def run_model(
...
@@ -403,52 +297,18 @@ def run_model(
return
pred_imgs
return
pred_imgs
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"MAE run inference"
,
add_help
=
False
)
parser
.
add_argument
(
"--data_file"
,
type
=
str
,
default
=
"./India_900498_S2Hand.tif"
,
help
=
"Path to the file."
,
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"output"
,
help
=
"Path to the directory where to save outputs."
,
)
parser
.
add_argument
(
"--input_indices"
,
default
=
[
1
,
2
,
3
,
8
,
11
,
12
],
type
=
int
,
nargs
=
"+"
,
help
=
"0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data."
,
)
parser
.
add_argument
(
"--rgb_outputs"
,
action
=
"store_true"
,
help
=
"If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved."
,
)
def
main
(
def
main
(
data_file
:
str
,
data_file
:
str
,
model
:
str
,
output_dir
:
str
,
output_dir
:
str
,
rgb_outputs
:
bool
,
rgb_outputs
:
bool
,
input_indices
:
list
[
int
]
=
None
,
input_indices
:
list
[
int
]
=
None
,
):
):
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
# Load model ---------------------------------------------------------------
model_obj
=
PrithviMAE
(
model
=
model
)
model_obj
=
PrithviMAE
()
datamodule
=
generate_datamodule
()
datamodule
=
generate_datamodule
()
img_size
=
256
# Size of Sen1Floods11
img_size
=
512
# Size of Sen1Floods11
# Loading data -------------------------------------------------------------
input_data
,
temporal_coords
,
location_coords
,
meta_data
=
load_example
(
input_data
,
temporal_coords
,
location_coords
,
meta_data
=
load_example
(
file_paths
=
[
data_file
],
file_paths
=
[
data_file
],
...
@@ -460,8 +320,6 @@ def main(
...
@@ -460,8 +320,6 @@ def main(
if
input_data
.
mean
()
>
1
:
if
input_data
.
mean
()
>
1
:
input_data
=
input_data
/
10000
# Convert to range 0-1
input_data
=
input_data
/
10000
# Convert to range 0-1
# Running model ------------------------------------------------------------
channels
=
[
channels
=
[
datamodule_config
[
"bands"
].
index
(
b
)
for
b
in
[
"RED"
,
"GREEN"
,
"BLUE"
]
datamodule_config
[
"bands"
].
index
(
b
)
for
b
in
[
"RED"
,
"GREEN"
,
"BLUE"
]
]
# BGR -> RGB
]
# BGR -> RGB
...
@@ -469,7 +327,6 @@ def main(
...
@@ -469,7 +327,6 @@ def main(
pred
=
run_model
(
pred
=
run_model
(
input_data
,
temporal_coords
,
location_coords
,
model_obj
,
datamodule
,
img_size
input_data
,
temporal_coords
,
location_coords
,
model_obj
,
datamodule
,
img_size
)
)
# Save pred
# Save pred
meta_data
.
update
(
count
=
1
,
dtype
=
"uint8"
,
compress
=
"lzw"
,
nodata
=
0
)
meta_data
.
update
(
count
=
1
,
dtype
=
"uint8"
,
compress
=
"lzw"
,
nodata
=
0
)
pred_file
=
os
.
path
.
join
(
pred_file
=
os
.
path
.
join
(
...
@@ -487,6 +344,7 @@ def main(
...
@@ -487,6 +344,7 @@ def main(
orig_img
=
torch
.
Tensor
(
input_data
[
0
,
:,
0
,
...]),
orig_img
=
torch
.
Tensor
(
input_data
[
0
,
:,
0
,
...]),
channels
=
channels
,
channels
=
channels
,
)
)
rgb_orig
=
rgb_orig
.
to
(
torch
.
float32
)
pred
[
pred
==
0.0
]
=
np
.
nan
pred
[
pred
==
0.0
]
=
np
.
nan
img_pred
=
rgb_orig
*
0.7
+
pred
*
0.3
img_pred
=
rgb_orig
*
0.7
+
pred
*
0.3
...
@@ -503,9 +361,10 @@ def main(
...
@@ -503,9 +361,10 @@ def main(
# Save image rgb
# Save image rgb
if
rgb_outputs
:
if
rgb_outputs
:
name_suffix
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
data_file
))[
0
]
rgb_file
=
os
.
path
.
join
(
rgb_file
=
os
.
path
.
join
(
output_dir
,
output_dir
,
f
"original_rgb_
{
os
.
path
.
splitext
(
os
.
path
.
basename
(
data_file
))[
0
]
}
.tiff"
,
f
"original_rgb_
{
name_suffix
}
.tiff"
,
)
)
save_geotiff
(
save_geotiff
(
image
=
_convert_np_uint8
(
rgb_orig
),
image
=
_convert_np_uint8
(
rgb_orig
),
...
@@ -515,6 +374,42 @@ def main(
...
@@ -515,6 +374,42 @@ def main(
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
parse_args
()
parser
=
argparse
.
ArgumentParser
(
"MAE run inference"
,
add_help
=
False
)
parser
.
add_argument
(
"--data_file"
,
type
=
str
,
default
=
"./India_900498_S2Hand.tif"
,
help
=
"Path to the file."
,
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
,
help
=
"Path to a checkpoint file to load from."
,
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"output"
,
help
=
"Path to the directory where to save outputs."
,
)
parser
.
add_argument
(
"--input_indices"
,
default
=
[
1
,
2
,
3
,
8
,
11
,
12
],
type
=
int
,
nargs
=
"+"
,
help
=
"""
0-based indices of the six Prithvi channels to be selected from the input.
By default selects [1,2,3,8,11,12] for S2L1C data.
"""
,
)
parser
.
add_argument
(
"--rgb_outputs"
,
action
=
"store_true"
,
help
=
"If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved."
,
)
args
=
parser
.
parse_args
()
main
(
**
vars
(
args
))
main
(
**
vars
(
args
))
requirements/test.in
View file @
8560a5b2
...
@@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
...
@@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test
\ No newline at end of file
requirements/test.txt
View file @
8560a5b2
...
@@ -6,6 +6,10 @@ accelerate==1.0.1
...
@@ -6,6 +6,10 @@ accelerate==1.0.1
# via
# via
# lm-eval
# lm-eval
# peft
# peft
aenum==3.1.16
# via lightly
affine==2.4.0
# via rasterio
aiohappyeyeballs==2.4.3
aiohappyeyeballs==2.4.3
# via aiohttp
# via aiohttp
aiohttp==3.10.11
aiohttp==3.10.11
...
@@ -21,8 +25,18 @@ aiosignal==1.3.1
...
@@ -21,8 +25,18 @@ aiosignal==1.3.1
# via
# via
# aiohttp
# aiohttp
# ray
# ray
albucore==0.0.16
# via terratorch
albumentations==1.4.6
# via terratorch
alembic==1.16.4
# via mlflow
annotated-types==0.7.0
annotated-types==0.7.0
# via pydantic
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.6.2.post1
anyio==4.6.2.post1
# via
# via
# httpx
# httpx
...
@@ -34,10 +48,12 @@ arrow==1.3.0
...
@@ -34,10 +48,12 @@ arrow==1.3.0
attrs==24.2.0
attrs==24.2.0
# via
# via
# aiohttp
# aiohttp
# fiona
# hypothesis
# hypothesis
# jsonlines
# jsonlines
# jsonschema
# jsonschema
# pytest-subtests
# pytest-subtests
# rasterio
# referencing
# referencing
audioread==3.0.1
audioread==3.0.1
# via librosa
# via librosa
...
@@ -46,9 +62,13 @@ backoff==2.2.1
...
@@ -46,9 +62,13 @@ backoff==2.2.1
# -r requirements/test.in
# -r requirements/test.in
# schemathesis
# schemathesis
bitsandbytes==0.46.1
bitsandbytes==0.46.1
# via -r requirements/test.in
# via
# -r requirements/test.in
# lightning
black==24.10.0
black==24.10.0
# via datamodel-code-generator
# via datamodel-code-generator
blinker==1.9.0
# via flask
blobfile==3.0.0
blobfile==3.0.0
# via -r requirements/test.in
# via -r requirements/test.in
bm25s==0.2.13
bm25s==0.2.13
...
@@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3
...
@@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3
buildkite-test-collector==0.1.9
buildkite-test-collector==0.1.9
# via -r requirements/test.in
# via -r requirements/test.in
cachetools==5.5.2
cachetools==5.5.2
# via google-auth
# via
# google-auth
# mlflow-skinny
certifi==2024.8.30
certifi==2024.8.30
# via
# via
# fiona
# httpcore
# httpcore
# httpx
# httpx
# lightly
# pyogrio
# pyproj
# rasterio
# requests
# requests
cffi==1.17.1
cffi==1.17.1
# via soundfile
# via soundfile
...
@@ -79,11 +106,28 @@ charset-normalizer==3.4.0
...
@@ -79,11 +106,28 @@ charset-normalizer==3.4.0
click==8.1.7
click==8.1.7
# via
# via
# black
# black
# click-plugins
# cligj
# fiona
# flask
# jiwer
# jiwer
# mlflow-skinny
# nltk
# nltk
# rasterio
# ray
# ray
# schemathesis
# schemathesis
# typer
# typer
# uvicorn
click-plugins==1.1.1.2
# via
# fiona
# rasterio
cligj==0.7.2
# via
# fiona
# rasterio
cloudpickle==3.1.1
# via mlflow-skinny
colorama==0.4.6
colorama==0.4.6
# via
# via
# sacrebleu
# sacrebleu
...
@@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0
...
@@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0
# via ray
# via ray
cycler==0.12.1
cycler==0.12.1
# via matplotlib
# via matplotlib
databricks-sdk==0.59.0
# via mlflow-skinny
datamodel-code-generator==0.26.3
datamodel-code-generator==0.26.3
# via -r requirements/test.in
# via -r requirements/test.in
dataproperty==1.0.1
dataproperty==1.0.1
...
@@ -122,13 +168,21 @@ distlib==0.3.9
...
@@ -122,13 +168,21 @@ distlib==0.3.9
# via virtualenv
# via virtualenv
dnspython==2.7.0
dnspython==2.7.0
# via email-validator
# via email-validator
docker==7.1.0
# via mlflow
docopt==0.6.2
docopt==0.6.2
# via num2words
# via num2words
einops==0.8.0
docstring-parser==0.17.0
# via jsonargparse
efficientnet-pytorch==0.7.1
# via segmentation-models-pytorch
einops==0.8.1
# via
# via
# -r requirements/test.in
# -r requirements/test.in
# encodec
# encodec
# mamba-ssm
# mamba-ssm
# terratorch
# torchgeo
# vector-quantize-pytorch
# vector-quantize-pytorch
# vocos
# vocos
einx==0.3.0
einx==0.3.0
...
@@ -141,6 +195,8 @@ eval-type-backport==0.2.2
...
@@ -141,6 +195,8 @@ eval-type-backport==0.2.2
# via mteb
# via mteb
evaluate==0.4.3
evaluate==0.4.3
# via lm-eval
# via lm-eval
fastapi==0.116.1
# via mlflow-skinny
fastparquet==2024.11.0
fastparquet==2024.11.0
# via genai-perf
# via genai-perf
fastrlock==0.8.2
fastrlock==0.8.2
...
@@ -156,6 +212,10 @@ filelock==3.16.1
...
@@ -156,6 +212,10 @@ filelock==3.16.1
# torch
# torch
# transformers
# transformers
# virtualenv
# virtualenv
fiona==1.10.1
# via torchgeo
flask==3.1.1
# via mlflow
fonttools==4.54.1
fonttools==4.54.1
# via matplotlib
# via matplotlib
fqdn==1.5.1
fqdn==1.5.1
...
@@ -173,6 +233,8 @@ fsspec==2024.9.0
...
@@ -173,6 +233,8 @@ fsspec==2024.9.0
# evaluate
# evaluate
# fastparquet
# fastparquet
# huggingface-hub
# huggingface-hub
# lightning
# pytorch-lightning
# torch
# torch
ftfy==6.3.1
ftfy==6.3.1
# via open-clip-torch
# via open-clip-torch
...
@@ -180,18 +242,41 @@ genai-perf==0.0.8
...
@@ -180,18 +242,41 @@ genai-perf==0.0.8
# via -r requirements/test.in
# via -r requirements/test.in
genson==1.3.0
genson==1.3.0
# via datamodel-code-generator
# via datamodel-code-generator
geopandas==1.0.1
# via terratorch
gitdb==4.0.12
# via gitpython
gitpython==3.1.44
# via mlflow-skinny
google-api-core==2.24.2
google-api-core==2.24.2
# via opencensus
# via opencensus
google-auth==2.40.2
google-auth==2.40.2
# via google-api-core
# via
# databricks-sdk
# google-api-core
googleapis-common-protos==1.70.0
googleapis-common-protos==1.70.0
# via google-api-core
# via google-api-core
graphene==3.4.3
# via mlflow
graphql-core==3.2.6
graphql-core==3.2.6
# via hypothesis-graphql
# via
# graphene
# graphql-relay
# hypothesis-graphql
graphql-relay==3.2.0
# via graphene
greenlet==3.2.3
# via sqlalchemy
grpcio==1.71.0
grpcio==1.71.0
# via ray
# via ray
gunicorn==23.0.0
# via mlflow
h11==0.14.0
h11==0.14.0
# via httpcore
# via
# httpcore
# uvicorn
h5py==3.13.0
# via terratorch
harfile==0.3.0
harfile==0.3.0
# via schemathesis
# via schemathesis
hf-xet==1.1.3
hf-xet==1.1.3
...
@@ -204,7 +289,7 @@ httpx==0.27.2
...
@@ -204,7 +289,7 @@ httpx==0.27.2
# via
# via
# -r requirements/test.in
# -r requirements/test.in
# schemathesis
# schemathesis
huggingface-hub==0.33.
0
huggingface-hub==0.33.
1
# via
# via
# -r requirements/test.in
# -r requirements/test.in
# accelerate
# accelerate
...
@@ -212,13 +297,19 @@ huggingface-hub==0.33.0
...
@@ -212,13 +297,19 @@ huggingface-hub==0.33.0
# evaluate
# evaluate
# open-clip-torch
# open-clip-torch
# peft
# peft
# segmentation-models-pytorch
# sentence-transformers
# sentence-transformers
# terratorch
# timm
# timm
# tokenizers
# tokenizers
# transformers
# transformers
# vocos
# vocos
humanize==4.11.0
humanize==4.11.0
# via runai-model-streamer
# via runai-model-streamer
hydra-core==1.3.2
# via
# lightly
# lightning
hypothesis==6.131.0
hypothesis==6.131.0
# via
# via
# hypothesis-graphql
# hypothesis-graphql
...
@@ -236,6 +327,14 @@ idna==3.10
...
@@ -236,6 +327,14 @@ idna==3.10
# jsonschema
# jsonschema
# requests
# requests
# yarl
# yarl
imageio==2.37.0
# via scikit-image
importlib-metadata==8.7.0
# via
# mlflow-skinny
# opentelemetry-api
importlib-resources==6.5.2
# via typeshed-client
inflect==5.6.2
inflect==5.6.2
# via datamodel-code-generator
# via datamodel-code-generator
iniconfig==2.0.0
iniconfig==2.0.0
...
@@ -244,9 +343,13 @@ isoduration==20.11.0
...
@@ -244,9 +343,13 @@ isoduration==20.11.0
# via jsonschema
# via jsonschema
isort==5.13.2
isort==5.13.2
# via datamodel-code-generator
# via datamodel-code-generator
itsdangerous==2.2.0
# via flask
jinja2==3.1.6
jinja2==3.1.6
# via
# via
# datamodel-code-generator
# datamodel-code-generator
# flask
# mlflow
# torch
# torch
jiwer==3.0.5
jiwer==3.0.5
# via -r requirements/test.in
# via -r requirements/test.in
...
@@ -259,6 +362,10 @@ joblib==1.4.2
...
@@ -259,6 +362,10 @@ joblib==1.4.2
# librosa
# librosa
# nltk
# nltk
# scikit-learn
# scikit-learn
jsonargparse==4.35.0
# via
# lightning
# terratorch
jsonlines==4.0.0
jsonlines==4.0.0
# via lm-eval
# via lm-eval
jsonpointer==3.0.0
jsonpointer==3.0.0
...
@@ -277,12 +384,33 @@ kaleido==0.2.1
...
@@ -277,12 +384,33 @@ kaleido==0.2.1
# via genai-perf
# via genai-perf
kiwisolver==1.4.7
kiwisolver==1.4.7
# via matplotlib
# via matplotlib
kornia==0.8.1
# via torchgeo
kornia-rs==0.1.9
# via kornia
lazy-loader==0.4
lazy-loader==0.4
# via librosa
# via
# librosa
# scikit-image
libnacl==2.1.0
libnacl==2.1.0
# via tensorizer
# via tensorizer
librosa==0.10.2.post1
librosa==0.10.2.post1
# via -r requirements/test.in
# via -r requirements/test.in
lightly==1.5.20
# via
# terratorch
# torchgeo
lightly-utils==0.0.2
# via lightly
lightning==2.5.1.post0
# via
# terratorch
# torchgeo
lightning-utilities==0.14.3
# via
# lightning
# pytorch-lightning
# torchmetrics
llvmlite==0.44.0
llvmlite==0.44.0
# via numba
# via numba
lm-eval==0.4.8
lm-eval==0.4.8
...
@@ -291,16 +419,27 @@ lxml==5.3.0
...
@@ -291,16 +419,27 @@ lxml==5.3.0
# via
# via
# blobfile
# blobfile
# sacrebleu
# sacrebleu
mako==1.3.10
# via alembic
mamba-ssm==2.2.4
mamba-ssm==2.2.4
# via -r requirements/test.in
# via -r requirements/test.in
markdown==3.8.2
# via mlflow
markdown-it-py==3.0.0
markdown-it-py==3.0.0
# via rich
# via rich
markupsafe==3.0.1
markupsafe==3.0.1
# via
# via
# flask
# jinja2
# jinja2
# mako
# werkzeug
# werkzeug
matplotlib==3.9.2
matplotlib==3.9.2
# via -r requirements/test.in
# via
# -r requirements/test.in
# lightning
# mlflow
# pycocotools
# torchgeo
mbstrdecoder==1.1.3
mbstrdecoder==1.1.3
# via
# via
# dataproperty
# dataproperty
...
@@ -310,6 +449,10 @@ mdurl==0.1.2
...
@@ -310,6 +449,10 @@ mdurl==0.1.2
# via markdown-it-py
# via markdown-it-py
mistral-common==1.8.0
mistral-common==1.8.0
# via -r requirements/test.in
# via -r requirements/test.in
mlflow==2.22.0
# via terratorch
mlflow-skinny==2.22.0
# via mlflow
more-itertools==10.5.0
more-itertools==10.5.0
# via lm-eval
# via lm-eval
mpmath==1.3.0
mpmath==1.3.0
...
@@ -328,10 +471,14 @@ multiprocess==0.70.16
...
@@ -328,10 +471,14 @@ multiprocess==0.70.16
# via
# via
# datasets
# datasets
# evaluate
# evaluate
munch==4.0.0
# via pretrainedmodels
mypy-extensions==1.0.0
mypy-extensions==1.0.0
# via black
# via black
networkx==3.2.1
networkx==3.2.1
# via torch
# via
# scikit-image
# torch
ninja==1.11.1.3
ninja==1.11.1.3
# via mamba-ssm
# via mamba-ssm
nltk==3.9.1
nltk==3.9.1
...
@@ -348,6 +495,8 @@ numpy==1.26.4
...
@@ -348,6 +495,8 @@ numpy==1.26.4
# via
# via
# -r requirements/test.in
# -r requirements/test.in
# accelerate
# accelerate
# albucore
# albumentations
# bitsandbytes
# bitsandbytes
# bm25s
# bm25s
# contourpy
# contourpy
...
@@ -358,9 +507,15 @@ numpy==1.26.4
...
@@ -358,9 +507,15 @@ numpy==1.26.4
# evaluate
# evaluate
# fastparquet
# fastparquet
# genai-perf
# genai-perf
# geopandas
# h5py
# imageio
# librosa
# librosa
# lightly
# lightly-utils
# matplotlib
# matplotlib
# mistral-common
# mistral-common
# mlflow
# mteb
# mteb
# numba
# numba
# numexpr
# numexpr
...
@@ -368,18 +523,30 @@ numpy==1.26.4
...
@@ -368,18 +523,30 @@ numpy==1.26.4
# pandas
# pandas
# patsy
# patsy
# peft
# peft
# pycocotools
# pyogrio
# rasterio
# rioxarray
# rouge-score
# rouge-score
# runai-model-streamer
# runai-model-streamer
# sacrebleu
# sacrebleu
# scikit-image
# scikit-learn
# scikit-learn
# scipy
# scipy
# segmentation-models-pytorch
# shapely
# soxr
# soxr
# statsmodels
# statsmodels
# tensorboardx
# tensorizer
# tensorizer
# tifffile
# torchgeo
# torchmetrics
# torchvision
# torchvision
# transformers
# transformers
# tritonclient
# tritonclient
# vocos
# vocos
# xarray
nvidia-cublas-cu12==12.8.3.14
nvidia-cublas-cu12==12.8.3.14
# via
# via
# nvidia-cudnn-cu12
# nvidia-cudnn-cu12
...
@@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61
...
@@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61
# torch
# torch
nvidia-nvtx-cu12==12.8.55
nvidia-nvtx-cu12==12.8.55
# via torch
# via torch
omegaconf==2.3.0
# via
# hydra-core
# lightning
open-clip-torch==2.32.0
open-clip-torch==2.32.0
# via -r requirements/test.in
# via -r requirements/test.in
opencensus==0.11.4
opencensus==0.11.4
...
@@ -426,7 +597,18 @@ opencensus-context==0.1.3
...
@@ -426,7 +597,18 @@ opencensus-context==0.1.3
opencv-python-headless==4.11.0.86
opencv-python-headless==4.11.0.86
# via
# via
# -r requirements/test.in
# -r requirements/test.in
# albucore
# albumentations
# mistral-common
# mistral-common
opentelemetry-api==1.35.0
# via
# mlflow-skinny
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-sdk==1.35.0
# via mlflow-skinny
opentelemetry-semantic-conventions==0.56b0
# via opentelemetry-sdk
packaging==24.2
packaging==24.2
# via
# via
# accelerate
# accelerate
...
@@ -435,26 +617,44 @@ packaging==24.2
...
@@ -435,26 +617,44 @@ packaging==24.2
# datasets
# datasets
# evaluate
# evaluate
# fastparquet
# fastparquet
# geopandas
# gunicorn
# huggingface-hub
# huggingface-hub
# hydra-core
# kornia
# lazy-loader
# lazy-loader
# lightning
# lightning-utilities
# mamba-ssm
# mamba-ssm
# matplotlib
# matplotlib
# mlflow-skinny
# peft
# peft
# plotly
# plotly
# pooch
# pooch
# pyogrio
# pytest
# pytest
# pytest-rerunfailures
# pytest-rerunfailures
# pytorch-lightning
# ray
# ray
# rioxarray
# scikit-image
# statsmodels
# statsmodels
# tensorboardx
# torchmetrics
# transformers
# transformers
# typepy
# typepy
# xarray
pandas==2.2.3
pandas==2.2.3
# via
# via
# datasets
# datasets
# evaluate
# evaluate
# fastparquet
# fastparquet
# genai-perf
# genai-perf
# geopandas
# mlflow
# statsmodels
# statsmodels
# torchgeo
# xarray
pathspec==0.12.1
pathspec==0.12.1
# via black
# via black
pathvalidate==3.2.1
pathvalidate==3.2.1
...
@@ -468,9 +668,14 @@ peft==0.13.2
...
@@ -468,9 +668,14 @@ peft==0.13.2
pillow==10.4.0
pillow==10.4.0
# via
# via
# genai-perf
# genai-perf
# imageio
# lightly-utils
# matplotlib
# matplotlib
# mistral-common
# mistral-common
# scikit-image
# segmentation-models-pytorch
# sentence-transformers
# sentence-transformers
# torchgeo
# torchvision
# torchvision
platformdirs==4.3.6
platformdirs==4.3.6
# via
# via
...
@@ -489,6 +694,8 @@ portalocker==2.10.1
...
@@ -489,6 +694,8 @@ portalocker==2.10.1
# via sacrebleu
# via sacrebleu
pqdm==0.2.0
pqdm==0.2.0
# via -r requirements/test.in
# via -r requirements/test.in
pretrainedmodels==0.7.4
# via segmentation-models-pytorch
prometheus-client==0.22.0
prometheus-client==0.22.0
# via ray
# via ray
propcache==0.2.0
propcache==0.2.0
...
@@ -499,8 +706,10 @@ protobuf==5.28.3
...
@@ -499,8 +706,10 @@ protobuf==5.28.3
# via
# via
# google-api-core
# google-api-core
# googleapis-common-protos
# googleapis-common-protos
# mlflow-skinny
# proto-plus
# proto-plus
# ray
# ray
# tensorboardx
# tensorizer
# tensorizer
psutil==6.1.0
psutil==6.1.0
# via
# via
...
@@ -515,6 +724,7 @@ pyarrow==18.0.0
...
@@ -515,6 +724,7 @@ pyarrow==18.0.0
# via
# via
# datasets
# datasets
# genai-perf
# genai-perf
# mlflow
pyasn1==0.6.1
pyasn1==0.6.1
# via
# via
# pyasn1-modules
# pyasn1-modules
...
@@ -523,6 +733,8 @@ pyasn1-modules==0.4.2
...
@@ -523,6 +733,8 @@ pyasn1-modules==0.4.2
# via google-auth
# via google-auth
pybind11==2.13.6
pybind11==2.13.6
# via lm-eval
# via lm-eval
pycocotools==2.0.8
# via terratorch
pycountry==24.6.1
pycountry==24.6.1
# via pydantic-extra-types
# via pydantic-extra-types
pycparser==2.22
pycparser==2.22
...
@@ -532,8 +744,12 @@ pycryptodomex==3.22.0
...
@@ -532,8 +744,12 @@ pycryptodomex==3.22.0
pydantic==2.11.5
pydantic==2.11.5
# via
# via
# -r requirements/test.in
# -r requirements/test.in
# albumentations
# datamodel-code-generator
# datamodel-code-generator
# fastapi
# lightly
# mistral-common
# mistral-common
# mlflow-skinny
# mteb
# mteb
# pydantic-extra-types
# pydantic-extra-types
# ray
# ray
...
@@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5
...
@@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5
# via mistral-common
# via mistral-common
pygments==2.18.0
pygments==2.18.0
# via rich
# via rich
pyogrio==0.11.0
# via geopandas
pyparsing==3.2.0
pyparsing==3.2.0
# via matplotlib
# via
# matplotlib
# rasterio
pyproj==3.7.1
# via
# geopandas
# rioxarray
# torchgeo
pyrate-limiter==3.7.0
pyrate-limiter==3.7.0
# via schemathesis
# via schemathesis
pystemmer==3.0.0
pystemmer==3.0.0
# via mteb
# via mteb
pytablewriter==1.2.0
pytablewriter==1.2.0
# via lm-eval
# via lm-eval
pytest==8.3.
3
pytest==8.3.
5
# via
# via
# -r requirements/test.in
# -r requirements/test.in
# buildkite-test-collector
# buildkite-test-collector
...
@@ -564,6 +789,7 @@ pytest==8.3.3
...
@@ -564,6 +789,7 @@ pytest==8.3.3
# pytest-subtests
# pytest-subtests
# pytest-timeout
# pytest-timeout
# schemathesis
# schemathesis
# terratorch
pytest-asyncio==0.24.0
pytest-asyncio==0.24.0
# via -r requirements/test.in
# via -r requirements/test.in
pytest-forked==1.6.0
pytest-forked==1.6.0
...
@@ -578,15 +804,23 @@ pytest-subtests==0.14.1
...
@@ -578,15 +804,23 @@ pytest-subtests==0.14.1
# via schemathesis
# via schemathesis
pytest-timeout==2.3.1
pytest-timeout==2.3.1
# via -r requirements/test.in
# via -r requirements/test.in
python-box==7.3.2
# via terratorch
python-dateutil==2.9.0.post0
python-dateutil==2.9.0.post0
# via
# via
# arrow
# arrow
# botocore
# botocore
# graphene
# lightly
# matplotlib
# matplotlib
# pandas
# pandas
# typepy
# typepy
python-rapidjson==1.20
python-rapidjson==1.20
# via tritonclient
# via tritonclient
pytorch-lightning==2.5.2
# via
# lightly
# lightning
pytrec-eval-terrier==0.5.7
pytrec-eval-terrier==0.5.7
# via mteb
# via mteb
pytz==2024.2
pytz==2024.2
...
@@ -596,11 +830,17 @@ pytz==2024.2
...
@@ -596,11 +830,17 @@ pytz==2024.2
pyyaml==6.0.2
pyyaml==6.0.2
# via
# via
# accelerate
# accelerate
# albumentations
# datamodel-code-generator
# datamodel-code-generator
# datasets
# datasets
# genai-perf
# genai-perf
# huggingface-hub
# huggingface-hub
# jsonargparse
# lightning
# mlflow-skinny
# omegaconf
# peft
# peft
# pytorch-lightning
# ray
# ray
# responses
# responses
# schemathesis
# schemathesis
...
@@ -609,6 +849,11 @@ pyyaml==6.0.2
...
@@ -609,6 +849,11 @@ pyyaml==6.0.2
# vocos
# vocos
rapidfuzz==3.12.1
rapidfuzz==3.12.1
# via jiwer
# via jiwer
rasterio==1.4.3
# via
# rioxarray
# terratorch
# torchgeo
ray==2.43.0
ray==2.43.0
# via -r requirements/test.in
# via -r requirements/test.in
redis==5.2.0
redis==5.2.0
...
@@ -627,12 +872,16 @@ regex==2024.9.11
...
@@ -627,12 +872,16 @@ regex==2024.9.11
requests==2.32.3
requests==2.32.3
# via
# via
# buildkite-test-collector
# buildkite-test-collector
# databricks-sdk
# datasets
# datasets
# docker
# evaluate
# evaluate
# google-api-core
# google-api-core
# huggingface-hub
# huggingface-hub
# lightly
# lm-eval
# lm-eval
# mistral-common
# mistral-common
# mlflow-skinny
# mteb
# mteb
# pooch
# pooch
# ray
# ray
...
@@ -650,8 +899,11 @@ rfc3987==1.3.8
...
@@ -650,8 +899,11 @@ rfc3987==1.3.8
rich==13.9.4
rich==13.9.4
# via
# via
# genai-perf
# genai-perf
# lightning
# mteb
# mteb
# typer
# typer
rioxarray==0.19.0
# via terratorch
rouge-score==0.1.2
rouge-score==0.1.2
# via lm-eval
# via lm-eval
rpds-py==0.20.1
rpds-py==0.20.1
...
@@ -660,6 +912,8 @@ rpds-py==0.20.1
...
@@ -660,6 +912,8 @@ rpds-py==0.20.1
# referencing
# referencing
rsa==4.9.1
rsa==4.9.1
# via google-auth
# via google-auth
rtree==1.4.0
# via torchgeo
runai-model-streamer==0.11.0
runai-model-streamer==0.11.0
# via -r requirements/test.in
# via -r requirements/test.in
runai-model-streamer-s3==0.11.0
runai-model-streamer-s3==0.11.0
...
@@ -677,21 +931,32 @@ safetensors==0.4.5
...
@@ -677,21 +931,32 @@ safetensors==0.4.5
# transformers
# transformers
schemathesis==3.39.15
schemathesis==3.39.15
# via -r requirements/test.in
# via -r requirements/test.in
scikit-image==0.25.2
# via albumentations
scikit-learn==1.5.2
scikit-learn==1.5.2
# via
# via
# albumentations
# librosa
# librosa
# lm-eval
# lm-eval
# mlflow
# mteb
# mteb
# sentence-transformers
# sentence-transformers
scipy==1.13.1
scipy==1.13.1
# via
# via
# albumentations
# bm25s
# bm25s
# librosa
# librosa
# mlflow
# mteb
# mteb
# scikit-image
# scikit-learn
# scikit-learn
# sentence-transformers
# sentence-transformers
# statsmodels
# statsmodels
# vocos
# vocos
segmentation-models-pytorch==0.4.0
# via
# terratorch
# torchgeo
sentence-transformers==3.2.1
sentence-transformers==3.2.1
# via
# via
# -r requirements/test.in
# -r requirements/test.in
...
@@ -700,21 +965,30 @@ sentencepiece==0.2.0
...
@@ -700,21 +965,30 @@ sentencepiece==0.2.0
# via mistral-common
# via mistral-common
setuptools==77.0.3
setuptools==77.0.3
# via
# via
# lightning-utilities
# mamba-ssm
# mamba-ssm
# pytablewriter
# pytablewriter
# torch
# torch
# triton
# triton
shapely==2.1.1
# via
# geopandas
# torchgeo
shellingham==1.5.4
shellingham==1.5.4
# via typer
# via typer
six==1.16.0
six==1.16.0
# via
# via
# junit-xml
# junit-xml
# lightly
# opencensus
# opencensus
# python-dateutil
# python-dateutil
# rfc3339-validator
# rfc3339-validator
# rouge-score
# rouge-score
# segmentation-models-pytorch
smart-open==7.1.0
smart-open==7.1.0
# via ray
# via ray
smmap==5.0.2
# via gitdb
sniffio==1.3.1
sniffio==1.3.1
# via
# via
# anyio
# anyio
...
@@ -727,10 +1001,17 @@ soundfile==0.12.1
...
@@ -727,10 +1001,17 @@ soundfile==0.12.1
# librosa
# librosa
soxr==0.5.0.post1
soxr==0.5.0.post1
# via librosa
# via librosa
sqlalchemy==2.0.41
# via
# alembic
# mlflow
sqlitedict==2.1.0
sqlitedict==2.1.0
# via lm-eval
# via lm-eval
sqlparse==0.5.3
# via mlflow-skinny
starlette==0.46.2
starlette==0.46.2
# via
# via
# fastapi
# schemathesis
# schemathesis
# starlette-testclient
# starlette-testclient
starlette-testclient==0.4.1
starlette-testclient==0.4.1
...
@@ -751,18 +1032,29 @@ tenacity==9.0.0
...
@@ -751,18 +1032,29 @@ tenacity==9.0.0
# via
# via
# lm-eval
# lm-eval
# plotly
# plotly
tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
tensorizer==2.10.1
# via -r requirements/test.in
# via -r requirements/test.in
terratorch==1.1rc2
# via -r requirements/test.in
threadpoolctl==3.5.0
threadpoolctl==3.5.0
# via scikit-learn
# via scikit-learn
tifffile==2025.3.30
# via
# scikit-image
# terratorch
tiktoken==0.7.0
tiktoken==0.7.0
# via
# via
# lm-eval
# lm-eval
# mistral-common
# mistral-common
timm==1.0.1
1
timm==1.0.1
5
# via
# via
# -r requirements/test.in
# -r requirements/test.in
# open-clip-torch
# open-clip-torch
# segmentation-models-pytorch
# terratorch
# torchgeo
tokenizers==0.21.1
tokenizers==0.21.1
# via
# via
# -r requirements/test.in
# -r requirements/test.in
...
@@ -776,18 +1068,28 @@ torch==2.7.1+cu128
...
@@ -776,18 +1068,28 @@ torch==2.7.1+cu128
# -r requirements/test.in
# -r requirements/test.in
# accelerate
# accelerate
# bitsandbytes
# bitsandbytes
# efficientnet-pytorch
# encodec
# encodec
# fastsafetensors
# fastsafetensors
# kornia
# lightly
# lightning
# lm-eval
# lm-eval
# mamba-ssm
# mamba-ssm
# mteb
# mteb
# open-clip-torch
# open-clip-torch
# peft
# peft
# pretrainedmodels
# pytorch-lightning
# runai-model-streamer
# runai-model-streamer
# segmentation-models-pytorch
# sentence-transformers
# sentence-transformers
# tensorizer
# tensorizer
# terratorch
# timm
# timm
# torchaudio
# torchaudio
# torchgeo
# torchmetrics
# torchvision
# torchvision
# vector-quantize-pytorch
# vector-quantize-pytorch
# vocos
# vocos
...
@@ -796,22 +1098,40 @@ torchaudio==2.7.1+cu128
...
@@ -796,22 +1098,40 @@ torchaudio==2.7.1+cu128
# -r requirements/test.in
# -r requirements/test.in
# encodec
# encodec
# vocos
# vocos
torchgeo==0.7.0
# via terratorch
torchmetrics==1.7.4
# via
# lightning
# pytorch-lightning
# terratorch
# torchgeo
torchvision==0.22.1+cu128
torchvision==0.22.1+cu128
# via
# via
# -r requirements/test.in
# -r requirements/test.in
# lightly
# open-clip-torch
# open-clip-torch
# pretrainedmodels
# segmentation-models-pytorch
# terratorch
# timm
# timm
# torchgeo
tqdm==4.66.6
tqdm==4.66.6
# via
# via
# datasets
# datasets
# evaluate
# evaluate
# huggingface-hub
# huggingface-hub
# lightly
# lightning
# lm-eval
# lm-eval
# mteb
# mteb
# nltk
# nltk
# open-clip-torch
# open-clip-torch
# peft
# peft
# pqdm
# pqdm
# pretrainedmodels
# pytorch-lightning
# segmentation-models-pytorch
# sentence-transformers
# sentence-transformers
# tqdm-multiprocess
# tqdm-multiprocess
# transformers
# transformers
...
@@ -843,18 +1163,34 @@ typer==0.15.2
...
@@ -843,18 +1163,34 @@ typer==0.15.2
# via fastsafetensors
# via fastsafetensors
types-python-dateutil==2.9.0.20241206
types-python-dateutil==2.9.0.20241206
# via arrow
# via arrow
typeshed-client==2.8.2
# via jsonargparse
typing-extensions==4.12.2
typing-extensions==4.12.2
# via
# via
# albumentations
# alembic
# fastapi
# graphene
# huggingface-hub
# huggingface-hub
# librosa
# librosa
# lightning
# lightning-utilities
# mistral-common
# mistral-common
# mlflow-skinny
# mteb
# mteb
# opentelemetry-api
# opentelemetry-sdk
# opentelemetry-semantic-conventions
# pqdm
# pqdm
# pydantic
# pydantic
# pydantic-core
# pydantic-core
# pydantic-extra-types
# pydantic-extra-types
# pytorch-lightning
# sqlalchemy
# torch
# torch
# torchgeo
# typer
# typer
# typeshed-client
# typing-inspection
# typing-inspection
typing-inspection==0.4.1
typing-inspection==0.4.1
# via pydantic
# via pydantic
...
@@ -866,9 +1202,13 @@ urllib3==2.2.3
...
@@ -866,9 +1202,13 @@ urllib3==2.2.3
# via
# via
# blobfile
# blobfile
# botocore
# botocore
# docker
# lightly
# requests
# requests
# responses
# responses
# tritonclient
# tritonclient
uvicorn==0.35.0
# via mlflow-skinny
vector-quantize-pytorch==1.21.2
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
# via -r requirements/test.in
virtualenv==20.31.2
virtualenv==20.31.2
...
@@ -880,11 +1220,15 @@ wcwidth==0.2.13
...
@@ -880,11 +1220,15 @@ wcwidth==0.2.13
webcolors==24.11.1
webcolors==24.11.1
# via jsonschema
# via jsonschema
werkzeug==3.1.3
werkzeug==3.1.3
# via schemathesis
# via
# flask
# schemathesis
word2number==1.1
word2number==1.1
# via lm-eval
# via lm-eval
wrapt==1.17.2
wrapt==1.17.2
# via smart-open
# via smart-open
xarray==2025.7.1
# via rioxarray
xxhash==3.5.0
xxhash==3.5.0
# via
# via
# datasets
# datasets
...
@@ -893,5 +1237,7 @@ yarl==1.17.1
...
@@ -893,5 +1237,7 @@ yarl==1.17.1
# via
# via
# aiohttp
# aiohttp
# schemathesis
# schemathesis
zipp==3.23.0
# via importlib-metadata
zstandard==0.23.0
zstandard==0.23.0
# via lm-eval
# via lm-eval
tests/models/multimodal/pooling/test_prithvi_mae.py
0 → 100644
View file @
8560a5b2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.utils
import
set_default_torch_num_threads
from
....conftest
import
VllmRunner
def generate_test_mm_data():
    """Build the dummy multi-modal payload used for one test request.

    Returns a dict with two float16 tensors filled with 1.0:
    ``pixel_values`` of shape (6, 512, 512) and ``location_coords`` of
    shape (1, 2).
    """
    fill_value = 1.0
    tensor_dtype = torch.float16
    pixel_values = torch.full((6, 512, 512), fill_value, dtype=tensor_dtype)
    location_coords = torch.full((1, 2), fill_value, dtype=tensor_dtype)
    return {
        "pixel_values": pixel_values,
        "location_coords": location_coords,
    }
def _run_test(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:
    """Smoke-test ``model`` by encoding a batch of identical requests.

    Ten requests are built, each carrying a single placeholder token id and
    the dummy multi-modal data, and handed to ``encode`` on the runner. The
    returned value is not inspected; the test passes if nothing raises.
    """
    num_requests = 10
    requests = []
    for _ in range(num_requests):
        requests.append({
            # This model deals with no text input
            "prompt_token_ids": [1],
            "multi_modal_data": generate_test_mm_data(),
        })

    # The thread-count context is entered first, then the runner, exactly as
    # in a single combined ``with`` statement.
    with set_default_torch_num_threads(1):
        with vllm_runner(
                model,
                task="embed",
                dtype=torch.float16,
                enforce_eager=True,
                skip_tokenizer_init=True,
                # Limit the maximum number of sequences to avoid the
                # test going OOM during the warmup run
                max_num_seqs=32,
        ) as vllm_model:
            vllm_model.encode(requests)
# Model checkpoints exercised by the parametrized smoke test below.
MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]


@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
) -> None:
    """Run the encode smoke test once per entry in ``MODELS``.

    ``hf_runner`` and ``image_assets`` are requested as fixtures but are not
    used in the body; only ``vllm_runner`` and ``model`` are forwarded.
    """
    _run_test(
        vllm_runner,
        model,
    )
vllm/config.py
View file @
8560a5b2
...
@@ -651,6 +651,8 @@ class ModelConfig:
...
@@ -651,6 +651,8 @@ class ModelConfig:
self
.
original_max_model_len
=
self
.
max_model_len
self
.
original_max_model_len
=
self
.
max_model_len
self
.
max_model_len
=
self
.
get_and_verify_max_len
(
self
.
max_model_len
)
self
.
max_model_len
=
self
.
get_and_verify_max_len
(
self
.
max_model_len
)
self
.
multimodal_config
=
self
.
_init_multimodal_config
()
self
.
multimodal_config
=
self
.
_init_multimodal_config
()
self
.
model_supports_multimodal_raw_input
=
(
self
.
registry
.
supports_multimodal_raw_input
(
self
.
architectures
))
if
not
self
.
skip_tokenizer_init
:
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
self
.
_verify_tokenizer_mode
()
...
@@ -1243,10 +1245,10 @@ class ModelConfig:
...
@@ -1243,10 +1245,10 @@ class ModelConfig:
return
self
.
get_hf_config_sliding_window
()
return
self
.
get_hf_config_sliding_window
()
def
get_vocab_size
(
self
)
->
int
:
def
get_vocab_size
(
self
)
->
int
:
return
self
.
hf_text_config
.
vocab_size
return
getattr
(
self
.
hf_text_config
,
"
vocab_size
"
,
0
)
def
get_hidden_size
(
self
)
->
int
:
def
get_hidden_size
(
self
)
->
int
:
return
self
.
hf_text_config
.
hidden_size
return
getattr
(
self
.
hf_text_config
,
"
hidden_size
"
,
0
)
@
property
@
property
def
is_deepseek_mla
(
self
)
->
bool
:
def
is_deepseek_mla
(
self
)
->
bool
:
...
...
vllm/engine/llm_engine.py
View file @
8560a5b2
...
@@ -238,14 +238,14 @@ class LLMEngine:
...
@@ -238,14 +238,14 @@ class LLMEngine:
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
self
.
use_cached_outputs
=
use_cached_outputs
if
not
self
.
model_config
.
skip_tokenizer_init
:
if
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
else
:
self
.
tokenizer
=
None
self
.
tokenizer
=
None
self
.
detokenizer
=
None
self
.
detokenizer
=
None
tokenizer_group
=
None
tokenizer_group
=
None
else
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
# Ensure that the function doesn't contain a reference to self,
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
# to avoid engine GC issues
...
...
vllm/model_executor/models/interfaces.py
View file @
8560a5b2
...
@@ -136,6 +136,40 @@ def supports_multimodal(
...
@@ -136,6 +136,40 @@ def supports_multimodal(
return
getattr
(
model
,
"supports_multimodal"
,
False
)
return
getattr
(
model
,
"supports_multimodal"
,
False
)
@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
    """The interface for multi-modal models that consume raw inputs.

    Extends `SupportsMultiModal` with a class-level flag marking models that
    process multi-modal inputs in their raw form rather than as embeddings.
    """

    supports_multimodal_raw_input: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs and processes
    them in their raw form and not embeddings.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """
@overload
def supports_multimodal_raw_input(
        model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
    ...


@overload
def supports_multimodal_raw_input(
        model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
    ...


def supports_multimodal_raw_input(
    model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
           TypeIs[SupportsMultiModalWithRawInput]]:
    """Report whether a model class or instance takes raw multi-modal input.

    Returns the value of the ``supports_multimodal_raw_input`` attribute of
    ``model``, or ``False`` when the attribute is absent.
    """
    raw_input_flag = getattr(model, "supports_multimodal_raw_input", False)
    return raw_input_flag
@
runtime_checkable
@
runtime_checkable
class
SupportsScoreTemplate
(
Protocol
):
class
SupportsScoreTemplate
(
Protocol
):
"""The interface required for all models that support score template."""
"""The interface required for all models that support score template."""
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
8560a5b2
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
...
@@ -27,13 +28,14 @@ from vllm.config import VllmConfig
...
@@ -27,13 +28,14 @@ from vllm.config import VllmConfig
from
vllm.model_executor.layers.pooler
import
(
AllPool
,
PoolerHead
,
from
vllm.model_executor.layers.pooler
import
(
AllPool
,
PoolerHead
,
PoolerIdentity
,
SimplePooler
)
PoolerIdentity
,
SimplePooler
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
from
vllm.model_executor.models.interfaces
import
(
SupportsMultiModal
,
IsAttentionFree
,
MultiModalEmbeddings
,
SupportsMultiModalWithRawInput
)
SupportsV0Only
)
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
)
MultiModalFieldElem
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalSharedField
,
PlaceholderRange
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptUpdate
)
BaseProcessingInfo
,
PromptUpdate
)
...
@@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
...
@@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
# The size of pixel_values might change in the cases where we resize
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
# the input but never exceeds the dimensions below.
return
{
return
{
"pixel_values"
:
torch
.
full
((
1
,
6
,
512
,
512
),
1.0
),
"pixel_values"
:
torch
.
full
((
6
,
512
,
512
),
1.0
,
"location_coords"
:
torch
.
full
((
1
,
2
),
1.0
),
dtype
=
torch
.
float16
),
"location_coords"
:
torch
.
full
((
1
,
2
),
1.0
,
dtype
=
torch
.
float16
),
}
}
...
@@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values
=
MultiModalFieldConfig
.
shared
(
batch_size
=
1
,
location_coords
=
MultiModalFieldConfig
.
batched
(
"image"
),
modality
=
"image"
),
location_coords
=
MultiModalFieldConfig
.
shared
(
batch_size
=
1
,
modality
=
"image"
),
)
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
...
@@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
for
k
,
v
in
mm_data
.
items
():
for
k
,
v
in
mm_data
.
items
():
mm_kwargs
[
k
]
=
v
mm_kwargs
[
k
]
=
v
mm_placeholders
=
{
"image"
:
[
PlaceholderRange
(
offset
=
0
,
length
=
0
)]}
# This model receives as input a multi-dimensional tensor representing
# a single image patch; it must therefore not be split into multiple
# elements but treated as a single one.
# Hence the decision to use a MultiModalSharedField.
# The expected shape is (num_channels, width, height).
# The model also allows the user to submit multiple image patches as a
# batch, which adds a further dimension to the above shape.
# At this stage we only support submitting one patch per request, and
# batching is achieved via vLLM batching.
# TODO (christian-pinto): enable support for multi-patch requests
# in tandem with vLLM batching.
multimodal_kwargs_items
=
[
MultiModalKwargsItem
.
from_elems
([
MultiModalFieldElem
(
modality
=
"image"
,
key
=
key
,
data
=
data
,
field
=
MultiModalSharedField
(
1
),
)
for
key
,
data
in
mm_kwargs
.
items
()
])
]
return
MultiModalInputs
(
return
MultiModalInputs
(
type
=
"multimodal"
,
type
=
"multimodal"
,
prompt
=
prompt
,
prompt
=
prompt
,
prompt_token_ids
=
[
1
],
prompt_token_ids
=
[
1
],
mm_kwargs
=
MultiModalKwargs
(
mm_kwarg
s
),
mm_kwargs
=
MultiModalKwargs
.
from_items
(
multimodal_kwargs_item
s
),
mm_hashes
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
{}
,
mm_placeholders
=
mm_placeholders
,
)
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
@
MULTIMODAL_REGISTRY
.
register_processor
(
PrithviGeoSpatialMAEMultiModalProcessor
,
PrithviGeoSpatialMAEMultiModalProcessor
,
info
=
PrithviGeoSpatialMAEProcessingInfo
,
info
=
PrithviGeoSpatialMAEProcessingInfo
,
dummy_inputs
=
PrithviGeoSpatialMAEInputBuilder
)
dummy_inputs
=
PrithviGeoSpatialMAEInputBuilder
,
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
,
)
SupportsV0Only
):
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModalWithRawInput
):
"""Prithvi Masked Autoencoder"""
"""Prithvi Masked Autoencoder"""
is_pooling_model
=
True
is_pooling_model
=
True
...
@@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
...
@@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
raise
ValueError
(
"Only image modality is supported"
)
raise
ValueError
(
"Only image modality is supported"
)
def
_instantiate_model
(
self
,
config
:
dict
)
->
Optional
[
nn
.
Module
]:
def
_instantiate_model
(
self
,
config
:
dict
)
->
Optional
[
nn
.
Module
]:
# We might be able/need to support different tasks with this same model
# We might be able/need to support different tasks with this same model
if
config
[
"task_args"
][
"task"
]
==
"SemanticSegmentationTask"
:
if
config
[
"task_args"
][
"task"
]
==
"SemanticSegmentationTask"
:
from
terratorch.cli_tools
import
SemanticSegmentationTask
from
terratorch.cli_tools
import
SemanticSegmentationTask
task
=
SemanticSegmentationTask
(
task
=
SemanticSegmentationTask
(
config
[
"model_args"
],
config
[
"model_args"
],
config
[
"task_args"
][
"model_factory"
],
config
[
"task_args"
][
"model_factory"
],
...
@@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
...
@@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
scheduler_hparams
=
config
[
"scheduler_params"
],
scheduler_hparams
=
config
[
"scheduler_params"
],
plot_on_val
=
config
[
"task_args"
][
"plot_on_val"
],
plot_on_val
=
config
[
"task_args"
][
"plot_on_val"
],
freeze_decoder
=
config
[
"task_args"
][
"freeze_decoder"
],
freeze_decoder
=
config
[
"task_args"
][
"freeze_decoder"
],
freeze_backbone
=
config
[
"task_args"
][
"freeze_backbone"
])
freeze_backbone
=
config
[
"task_args"
][
"freeze_backbone"
],
)
return
task
.
model
return
task
.
model
else
:
else
:
...
@@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
...
@@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
def
_parse_and_validate_multimodal_data
(
def
_parse_and_validate_multimodal_data
(
self
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
self
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
f
"Incorrect type of pixel_values. "
raise
ValueError
(
f
"Incorrect type of pixel_values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
pixel_values
=
torch
.
unbind
(
pixel_values
,
dim
=
0
)[
0
]
location_coords
=
kwargs
.
pop
(
"location_coords"
,
None
)
location_coords
=
kwargs
.
pop
(
"location_coords"
,
None
)
if
not
isinstance
(
location_coords
,
torch
.
Tensor
):
if
not
isinstance
(
location_coords
,
torch
.
Tensor
):
...
@@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
...
@@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return
pixel_values
,
location_coords
return
pixel_values
,
location_coords
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Return an empty placeholder tensor in place of token embeddings.

    The result has one row per entry along the first dimension of
    ``input_ids`` and zero columns. ``multimodal_embeddings`` is accepted
    but not used.
    """
    # We do not really use any input tokens and therefore no embeddings
    # to be calculated. However, due to the mandatory token ids in
    # the input prompt we pass one token and the size of the dummy
    # embedding tensors must reflect that.
    num_tokens = input_ids.shape[0]
    placeholder_shape = (num_tokens, 0)
    return torch.empty(placeholder_shape)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
...
...
vllm/model_executor/models/registry.py
View file @
8560a5b2
...
@@ -22,8 +22,8 @@ from vllm.logger import init_logger
...
@@ -22,8 +22,8 @@ from vllm.logger import init_logger
from
.interfaces
import
(
has_inner_state
,
has_noops
,
is_attention_free
,
from
.interfaces
import
(
has_inner_state
,
has_noops
,
is_attention_free
,
is_hybrid
,
supports_cross_encoding
,
is_hybrid
,
supports_cross_encoding
,
supports_multimodal
,
supports_
pp
,
supports_multimodal
,
supports_
multimodal_raw_input
,
supports_transcription
,
supports_v0_only
)
supports_pp
,
supports_transcription
,
supports_v0_only
)
from
.interfaces_base
import
is_text_generation_model
from
.interfaces_base
import
is_text_generation_model
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -287,6 +287,7 @@ class _ModelInfo:
...
@@ -287,6 +287,7 @@ class _ModelInfo:
is_pooling_model
:
bool
is_pooling_model
:
bool
supports_cross_encoding
:
bool
supports_cross_encoding
:
bool
supports_multimodal
:
bool
supports_multimodal
:
bool
supports_multimodal_raw_input
:
bool
supports_pp
:
bool
supports_pp
:
bool
has_inner_state
:
bool
has_inner_state
:
bool
is_attention_free
:
bool
is_attention_free
:
bool
...
@@ -304,6 +305,7 @@ class _ModelInfo:
...
@@ -304,6 +305,7 @@ class _ModelInfo:
is_pooling_model
=
True
,
# Can convert any model into a pooling model
is_pooling_model
=
True
,
# Can convert any model into a pooling model
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_multimodal_raw_input
=
supports_multimodal_raw_input
(
model
),
supports_pp
=
supports_pp
(
model
),
supports_pp
=
supports_pp
(
model
),
has_inner_state
=
has_inner_state
(
model
),
has_inner_state
=
has_inner_state
(
model
),
is_attention_free
=
is_attention_free
(
model
),
is_attention_free
=
is_attention_free
(
model
),
...
@@ -573,6 +575,13 @@ class _ModelRegistry:
...
@@ -573,6 +575,13 @@ class _ModelRegistry:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
supports_multimodal
return
model_cls
.
supports_multimodal
def supports_multimodal_raw_input(
    self,
    architectures: Union[str, list[str]],
) -> bool:
    """Report whether the resolved model accepts raw multi-modal input.

    Resolves ``architectures`` through ``inspect_model_cls`` and returns the
    ``supports_multimodal_raw_input`` attribute of the first element of the
    resulting pair.
    """
    model_info, _resolved_arch = self.inspect_model_cls(architectures)
    return model_info.supports_multimodal_raw_input
def
is_pp_supported_model
(
def
is_pp_supported_model
(
self
,
self
,
architectures
:
Union
[
str
,
list
[
str
]],
architectures
:
Union
[
str
,
list
[
str
]],
...
...
vllm/multimodal/registry.py
View file @
8560a5b2
...
@@ -266,7 +266,7 @@ class MultiModalRegistry:
...
@@ -266,7 +266,7 @@ class MultiModalRegistry:
if
not
model_config
.
is_multimodal_model
:
if
not
model_config
.
is_multimodal_model
:
raise
ValueError
(
f
"
{
model_config
.
model
}
is not a multimodal model"
)
raise
ValueError
(
f
"
{
model_config
.
model
}
is not a multimodal model"
)
if
tokenizer
is
None
:
if
tokenizer
is
None
and
not
model_config
.
skip_tokenizer_init
:
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
if
disable_cache
is
None
:
if
disable_cache
is
None
:
mm_config
=
model_config
.
get_multimodal_config
()
mm_config
=
model_config
.
get_multimodal_config
()
...
...
vllm/v1/engine/async_llm.py
View file @
8560a5b2
...
@@ -94,11 +94,14 @@ class AsyncLLM(EngineClient):
...
@@ -94,11 +94,14 @@ class AsyncLLM(EngineClient):
self
.
log_requests
=
log_requests
self
.
log_requests
=
log_requests
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
# Tokenizer (+ ensure liveness if running in another process).
if
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
init_tokenizer_from_configs
(
self
.
tokenizer
=
None
model_config
=
vllm_config
.
model_config
,
else
:
scheduler_config
=
vllm_config
.
scheduler_config
,
# Tokenizer (+ ensure liveness if running in another process).
lora_config
=
vllm_config
.
lora_config
)
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
# Processor (converts Inputs --> EngineCoreRequests).
# Processor (converts Inputs --> EngineCoreRequests).
self
.
processor
=
Processor
(
self
.
processor
=
Processor
(
...
@@ -525,6 +528,10 @@ class AsyncLLM(EngineClient):
...
@@ -525,6 +528,10 @@ class AsyncLLM(EngineClient):
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
)
->
AnyTokenizer
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
"Unable to get tokenizer because "
"skip_tokenizer_init is True"
)
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
async
def
is_tracing_enabled
(
self
)
->
bool
:
async
def
is_tracing_enabled
(
self
)
->
bool
:
...
...
vllm/v1/engine/llm_engine.py
View file @
8560a5b2
...
@@ -82,11 +82,14 @@ class LLMEngine:
...
@@ -82,11 +82,14 @@ class LLMEngine:
self
.
dp_group
=
None
self
.
dp_group
=
None
self
.
should_execute_dummy_batch
=
False
self
.
should_execute_dummy_batch
=
False
# Tokenizer (+ ensure liveness if running in another process).
if
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
init_tokenizer_from_configs
(
self
.
tokenizer
=
None
model_config
=
vllm_config
.
model_config
,
else
:
scheduler_config
=
vllm_config
.
scheduler_config
,
# Tokenizer (+ ensure liveness if running in another process).
lora_config
=
vllm_config
.
lora_config
)
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
# Processor (convert Inputs --> EngineCoreRequests)
# Processor (convert Inputs --> EngineCoreRequests)
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
self
.
processor
=
Processor
(
vllm_config
=
vllm_config
,
...
...
vllm/v1/engine/output_processor.py
View file @
8560a5b2
...
@@ -327,14 +327,16 @@ class OutputProcessor:
...
@@ -327,14 +327,16 @@ class OutputProcessor:
if
request_id
in
self
.
request_states
:
if
request_id
in
self
.
request_states
:
raise
ValueError
(
f
"Request id
{
request_id
}
already running."
)
raise
ValueError
(
f
"Request id
{
request_id
}
already running."
)
req_state
=
RequestState
.
from_new_request
(
tokenizer
=
None
if
not
self
.
tokenizer
else
\
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
request
.
lora_request
),
self
.
tokenizer
.
get_lora_tokenizer
(
request
.
lora_request
)
request
=
request
,
prompt
=
prompt
,
req_state
=
RequestState
.
from_new_request
(
tokenizer
=
tokenizer
,
parent_req
=
parent_req
,
request
=
request
,
request_index
=
request_index
,
prompt
=
prompt
,
queue
=
queue
,
parent_req
=
parent_req
,
log_stats
=
self
.
log_stats
)
request_index
=
request_index
,
queue
=
queue
,
log_stats
=
self
.
log_stats
)
self
.
request_states
[
request_id
]
=
req_state
self
.
request_states
[
request_id
]
=
req_state
self
.
lora_states
.
add_request
(
req_state
)
self
.
lora_states
.
add_request
(
req_state
)
if
parent_req
:
if
parent_req
:
...
...
vllm/v1/engine/processor.py
View file @
8560a5b2
...
@@ -380,7 +380,6 @@ class Processor:
...
@@ -380,7 +380,6 @@ class Processor:
prompt_type
:
Literal
[
"encoder"
,
"decoder"
],
prompt_type
:
Literal
[
"encoder"
,
"decoder"
],
):
):
model_config
=
self
.
model_config
model_config
=
self
.
model_config
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
if
not
prompt_ids
:
if
not
prompt_ids
:
...
@@ -389,9 +388,14 @@ class Processor:
...
@@ -389,9 +388,14 @@ class Processor:
else
:
else
:
raise
ValueError
(
f
"The
{
prompt_type
}
prompt cannot be empty"
)
raise
ValueError
(
f
"The
{
prompt_type
}
prompt cannot be empty"
)
max_input_id
=
max
(
prompt_ids
,
default
=
0
)
if
self
.
model_config
.
skip_tokenizer_init
:
if
max_input_id
>
tokenizer
.
max_token_id
:
tokenizer
=
None
raise
ValueError
(
f
"Token id
{
max_input_id
}
is out of vocabulary"
)
else
:
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
max_input_id
=
max
(
prompt_ids
,
default
=
0
)
if
max_input_id
>
tokenizer
.
max_token_id
:
raise
ValueError
(
f
"Token id
{
max_input_id
}
is out of vocabulary"
)
max_prompt_len
=
self
.
model_config
.
max_model_len
max_prompt_len
=
self
.
model_config
.
max_model_len
if
len
(
prompt_ids
)
>
max_prompt_len
:
if
len
(
prompt_ids
)
>
max_prompt_len
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
8560a5b2
...
@@ -126,6 +126,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -126,6 +126,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
is_multimodal_model
=
model_config
.
is_multimodal_model
self
.
is_pooling_model
=
model_config
.
pooler_config
is
not
None
self
.
is_pooling_model
=
model_config
.
pooler_config
is
not
None
self
.
model_supports_multimodal_raw_input
=
(
model_config
.
model_supports_multimodal_raw_input
)
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
...
@@ -328,6 +330,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -328,6 +330,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Args:
Args:
scheduler_output: The scheduler output.
scheduler_output: The scheduler output.
"""
"""
# Attention free models have zero kv_cache_goups, however models
# like Mamba are also attention free but use the kv_cache for
# keeping its internal state. This is why we check the number
# of kv_cache groups instead of solely checking
# for self.model_config.is_attention_free.
if
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
0
:
return
self
.
attn_metadata_builders
[
0
].
reorder_batch
(
self
.
input_batch
,
self
.
attn_metadata_builders
[
0
].
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
scheduler_output
)
...
@@ -565,6 +575,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -565,6 +575,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates.
# Refresh batch metadata with any pending updates.
self
.
input_batch
.
refresh_metadata
()
self
.
input_batch
.
refresh_metadata
()
def
_init_model_kwargs_for_multimodal_model
(
self
,
scheduler_output
:
Optional
[
"SchedulerOutput"
]
=
None
,
num_reqs
:
int
=
-
1
,
)
->
dict
[
str
,
Any
]:
model_kwargs
:
dict
[
str
,
Any
]
=
{}
if
self
.
model_supports_multimodal_raw_input
:
# This model requires the raw multimodal data in input.
if
scheduler_output
:
multi_modal_kwargs_list
=
[]
for
req
in
scheduler_output
.
scheduled_new_reqs
:
req_mm_inputs
=
req
.
mm_inputs
if
not
isinstance
(
req_mm_inputs
,
list
):
req_mm_inputs
=
list
(
req_mm_inputs
)
multi_modal_kwargs_list
.
extend
(
req_mm_inputs
)
multi_modal_kwargs
=
MultiModalKwargs
.
batch
(
multi_modal_kwargs_list
)
else
:
# The only case where SchedulerOutput is None is for
# a dummy run let's get some dummy data.
dummy_data
=
[
self
.
mm_registry
.
get_decoder_dummy_data
(
model_config
=
self
.
model_config
,
seq_len
=
1
).
multi_modal_data
for
i
in
range
(
num_reqs
)
]
multi_modal_kwargs
=
MultiModalKwargs
.
batch
(
dummy_data
)
model_kwargs
.
update
(
multi_modal_kwargs
)
return
model_kwargs
def
_get_cumsum_and_arange
(
def
_get_cumsum_and_arange
(
self
,
self
,
num_tokens
:
np
.
ndarray
,
num_tokens
:
np
.
ndarray
,
...
@@ -1359,10 +1401,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1359,10 +1401,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
# as input to the multimodal model, even when the input is text.
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
model_kwargs
=
self
.
_init_model_kwargs_for_multimodal_model
(
scheduler_output
=
scheduler_output
)
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
multimodal_embeddings
=
mm_embeds
or
None
,
multimodal_embeddings
=
mm_embeds
or
None
,
)
)
# TODO(woosuk): Avoid the copy. Optimize.
# TODO(woosuk): Avoid the copy. Optimize.
self
.
inputs_embeds
[:
num_scheduled_tokens
].
copy_
(
inputs_embeds
)
self
.
inputs_embeds
[:
num_scheduled_tokens
].
copy_
(
inputs_embeds
)
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
...
@@ -1374,6 +1420,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1374,6 +1420,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# then the embedding layer is not included in the CUDA graph.
# then the embedding layer is not included in the CUDA graph.
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
inputs_embeds
=
None
model_kwargs
=
{}
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
positions
=
self
.
mrope_positions
[:,
:
num_input_tokens
]
positions
=
self
.
mrope_positions
[:,
:
num_input_tokens
]
else
:
else
:
...
@@ -1406,6 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1406,6 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions
=
positions
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
**
MultiModalKwargs
.
as_kwargs
(
model_kwargs
,
device
=
self
.
device
,
),
)
)
self
.
maybe_wait_for_kv_save
()
self
.
maybe_wait_for_kv_save
()
...
@@ -2084,11 +2135,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2084,11 +2135,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens
):
num_scheduled_tokens
):
model
=
self
.
model
model
=
self
.
model
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
model_kwargs
=
self
.
_init_model_kwargs_for_multimodal_model
(
num_reqs
=
num_reqs
)
input_ids
=
None
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
else
:
else
:
input_ids
=
self
.
input_ids
[:
num_tokens
]
input_ids
=
self
.
input_ids
[:
num_tokens
]
inputs_embeds
=
None
inputs_embeds
=
None
model_kwargs
=
{}
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
positions
=
self
.
mrope_positions
[:,
:
num_tokens
]
positions
=
self
.
mrope_positions
[:,
:
num_tokens
]
else
:
else
:
...
@@ -2117,7 +2172,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2117,7 +2172,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions
=
positions
,
positions
=
positions
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
**
MultiModalKwargs
.
as_kwargs
(
model_kwargs
,
device
=
self
.
device
,
),
)
)
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
_
=
outputs
hidden_states
,
_
=
outputs
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment