Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
974dfd49
Unverified
Commit
974dfd49
authored
Feb 12, 2025
by
Christian Pinto
Committed by
GitHub
Feb 11, 2025
Browse files
[Model] IBM/NASA Prithvi Geospatial model (#12830)
parent
3ee696a6
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
811 additions
and
9 deletions
+811
-9
examples/offline_inference/prithvi_geospatial_mae.py
examples/offline_inference/prithvi_geospatial_mae.py
+530
-0
tests/models/registry.py
tests/models/registry.py
+4
-0
vllm/attention/backends/placeholder_attn.py
vllm/attention/backends/placeholder_attn.py
+8
-3
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+17
-5
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+238
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+4
-0
vllm/worker/pooling_model_runner.py
vllm/worker/pooling_model_runner.py
+10
-1
No files found.
examples/offline_inference/prithvi_geospatial_mae.py
0 → 100644
View file @
974dfd49
# SPDX-License-Identifier: Apache-2.0
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
The requirements for running this script are:
- Installing [terratorch, albumentations, rasterio] in your python environment
- downloading the model weights in a 'model' folder local to the script
(temporary measure until the proper config.json file is uploaded to HF)
- download an input example image (India_900498_S2Hand.tif) and place it in
the same folder with the script (or specify with the --data_file argument)
Run the example:
python prithvi_geospatial_mae.py
"""
# noqa: E501
import
argparse
import
datetime
import
os
import
re
from
typing
import
List
,
Union
import
albumentations
import
numpy
as
np
import
rasterio
import
torch
from
einops
import
rearrange
from
terratorch.datamodules
import
Sen1Floods11NonGeoDataModule
from
vllm
import
LLM
NO_DATA
=
-
9999
NO_DATA_FLOAT
=
0.0001
OFFSET
=
0
PERCENTILE
=
99
model_config
=
"""{
"architectures": ["PrithviGeoSpatialMAE"],
"num_classes": 0,
"pretrained_cfg": {
"task_args": {
"task": "SemanticSegmentationTask",
"model_factory": "EncoderDecoderFactory",
"loss": "ce",
"ignore_index": -1,
"lr": 0.001,
"freeze_backbone": false,
"freeze_decoder": false,
"plot_on_val": 10,
"optimizer": "AdamW",
"scheduler": "CosineAnnealingLR"
},
"model_args": {
"backbone_pretrained": false,
"backbone": "prithvi_eo_v2_300_tl",
"decoder": "UperNetDecoder",
"decoder_channels": 256,
"decoder_scale_modules": true,
"num_classes": 2,
"rescale": true,
"backbone_bands": [
"BLUE",
"GREEN",
"RED",
"NIR_NARROW",
"SWIR_1",
"SWIR_2"
],
"head_dropout": 0.1,
"necks": [
{
"name": "SelectIndices",
"indices": [
5,
11,
17,
23
]
},
{
"name": "ReshapeTokensToImage"
}
]
},
"optimizer_params" : {
"lr": 5.0e-05,
"betas": [0.9, 0.999],
"eps": [1.0e-08],
"weight_decay": 0.05,
"amsgrad": false,
"maximize": false,
"capturable": false,
"differentiable": false
},
"scheduler_params" : {
"T_max": 50,
"eta_min": 0,
"last_epoch": -1,
"verbose": "deprecated"
}
},
"torch_dtype": "float32"
}
"""
# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"./model/config.json"
),
'w'
)
as
config_file
:
config_file
.
write
(
model_config
)
datamodule_config
=
{
'bands'
:
[
'BLUE'
,
'GREEN'
,
'RED'
,
'NIR_NARROW'
,
'SWIR_1'
,
'SWIR_2'
],
'batch_size'
:
16
,
'constant_scale'
:
0.0001
,
'data_root'
:
'/dccstor/geofm-finetuning/datasets/sen1floods11'
,
'drop_last'
:
True
,
'no_data_replace'
:
0.0
,
'no_label_replace'
:
-
1
,
'num_workers'
:
8
,
'test_transform'
:
[
albumentations
.
Resize
(
always_apply
=
False
,
height
=
448
,
interpolation
=
1
,
p
=
1
,
width
=
448
),
albumentations
.
pytorch
.
ToTensorV2
(
transpose_mask
=
False
,
always_apply
=
True
,
p
=
1.0
)
],
}
class
PrithviMAE
:
def
__init__
(
self
):
print
(
"Initializing PrithviMAE model"
)
self
.
model
=
LLM
(
model
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"./model"
),
skip_tokenizer_init
=
True
,
dtype
=
"float32"
)
def
run
(
self
,
input_data
,
location_coords
):
print
(
"################ Running inference on vLLM ##############"
)
# merge the inputs into one data structure
mm_data
=
{
"pixel_values"
:
torch
.
empty
(
0
)
if
input_data
is
None
else
input_data
,
"location_coords"
:
torch
.
empty
(
0
)
if
location_coords
is
None
else
location_coords
}
prompt
=
{
"prompt_token_ids"
:
[
1
],
"multi_modal_data"
:
mm_data
}
outputs
=
self
.
model
.
encode
(
prompt
,
use_tqdm
=
False
)
print
(
"################ Inference done (it took seconds) ##############"
)
return
outputs
[
0
].
outputs
.
data
def
generate_datamodule
():
datamodule
=
Sen1Floods11NonGeoDataModule
(
data_root
=
datamodule_config
[
'data_root'
],
batch_size
=
datamodule_config
[
"batch_size"
],
num_workers
=
datamodule_config
[
"num_workers"
],
bands
=
datamodule_config
[
"bands"
],
drop_last
=
datamodule_config
[
"drop_last"
],
test_transform
=
datamodule_config
[
"test_transform"
""
])
return
datamodule
def
process_channel_group
(
orig_img
,
channels
):
"""
Args:
orig_img: torch.Tensor representing original image (reference)
with shape = (bands, H, W).
channels: list of indices representing RGB channels.
Returns:
torch.Tensor with shape (num_channels, height, width) for original image
"""
orig_img
=
orig_img
[
channels
,
...]
valid_mask
=
torch
.
ones_like
(
orig_img
,
dtype
=
torch
.
bool
)
valid_mask
[
orig_img
==
NO_DATA_FLOAT
]
=
False
# Rescale (enhancing contrast)
max_value
=
max
(
3000
,
np
.
percentile
(
orig_img
[
valid_mask
],
PERCENTILE
))
min_value
=
OFFSET
orig_img
=
torch
.
clamp
((
orig_img
-
min_value
)
/
(
max_value
-
min_value
),
0
,
1
)
# No data as zeros
orig_img
[
~
valid_mask
]
=
0
return
orig_img
def
read_geotiff
(
file_path
:
str
):
"""Read all bands from *file_path* and return image + meta info.
Args:
file_path: path to image file.
Returns:
np.ndarray with shape (bands, height, width)
meta info dict
"""
with
rasterio
.
open
(
file_path
)
as
src
:
img
=
src
.
read
()
meta
=
src
.
meta
try
:
coords
=
src
.
lnglat
()
except
Exception
:
# Cannot read coords
coords
=
None
return
img
,
meta
,
coords
def
save_geotiff
(
image
,
output_path
:
str
,
meta
:
dict
):
"""Save multi-band image in Geotiff file.
Args:
image: np.ndarray with shape (bands, height, width)
output_path: path where to save the image
meta: dict with meta info.
"""
with
rasterio
.
open
(
output_path
,
"w"
,
**
meta
)
as
dest
:
for
i
in
range
(
image
.
shape
[
0
]):
dest
.
write
(
image
[
i
,
:,
:],
i
+
1
)
return
def
_convert_np_uint8
(
float_image
:
torch
.
Tensor
):
image
=
float_image
.
numpy
()
*
255.0
image
=
image
.
astype
(
dtype
=
np
.
uint8
)
return
image
def
load_example
(
file_paths
:
List
[
str
],
mean
:
List
[
float
]
=
None
,
std
:
List
[
float
]
=
None
,
indices
:
Union
[
list
[
int
],
None
]
=
None
,
):
"""Build an input example by loading images in *file_paths*.
Args:
file_paths: list of file paths .
mean: list containing mean values for each band in the images
in *file_paths*.
std: list containing std values for each band in the images
in *file_paths*.
Returns:
np.array containing created example
list of meta info for each image in *file_paths*
"""
imgs
=
[]
metas
=
[]
temporal_coords
=
[]
location_coords
=
[]
for
file
in
file_paths
:
img
,
meta
,
coords
=
read_geotiff
(
file
)
# Rescaling (don't normalize on nodata)
img
=
np
.
moveaxis
(
img
,
0
,
-
1
)
# channels last for rescaling
if
indices
is
not
None
:
img
=
img
[...,
indices
]
if
mean
is
not
None
and
std
is
not
None
:
img
=
np
.
where
(
img
==
NO_DATA
,
NO_DATA_FLOAT
,
(
img
-
mean
)
/
std
)
imgs
.
append
(
img
)
metas
.
append
(
meta
)
if
coords
is
not
None
:
location_coords
.
append
(
coords
)
try
:
match
=
re
.
search
(
r
'(\d{7,8}T\d{6})'
,
file
)
if
match
:
year
=
int
(
match
.
group
(
1
)[:
4
])
julian_day
=
match
.
group
(
1
).
split
(
'T'
)[
0
][
4
:]
if
len
(
julian_day
)
==
3
:
julian_day
=
int
(
julian_day
)
else
:
julian_day
=
datetime
.
datetime
.
strptime
(
julian_day
,
'%m%d'
).
timetuple
().
tm_yday
temporal_coords
.
append
([
year
,
julian_day
])
except
Exception
as
e
:
print
(
f
'Could not extract timestamp for
{
file
}
(
{
e
}
)'
)
imgs
=
np
.
stack
(
imgs
,
axis
=
0
)
# num_frames, H, W, C
imgs
=
np
.
moveaxis
(
imgs
,
-
1
,
0
).
astype
(
"float32"
)
imgs
=
np
.
expand_dims
(
imgs
,
axis
=
0
)
# add batch di
return
imgs
,
temporal_coords
,
location_coords
,
metas
def
run_model
(
input_data
,
temporal_coords
,
location_coords
,
model
,
datamodule
,
img_size
,
lightning_model
=
None
):
# Reflect pad if not divisible by img_size
original_h
,
original_w
=
input_data
.
shape
[
-
2
:]
pad_h
=
(
img_size
-
(
original_h
%
img_size
))
%
img_size
pad_w
=
(
img_size
-
(
original_w
%
img_size
))
%
img_size
input_data
=
np
.
pad
(
input_data
,
((
0
,
0
),
(
0
,
0
),
(
0
,
0
),
(
0
,
pad_h
),
(
0
,
pad_w
)),
mode
=
"reflect"
)
# Build sliding window
batch_size
=
1
batch
=
torch
.
tensor
(
input_data
,
device
=
"cpu"
)
windows
=
(
batch
.
unfold
(
3
,
img_size
,
img_size
).
unfold
(
4
,
img_size
,
img_size
))
h1
,
w1
=
windows
.
shape
[
3
:
5
]
windows
=
rearrange
(
windows
,
"b c t h1 w1 h w -> (b h1 w1) c t h w"
,
h
=
img_size
,
w
=
img_size
)
# Split into batches if number of windows > batch_size
num_batches
=
windows
.
shape
[
0
]
//
batch_size
if
windows
.
shape
[
0
]
>
batch_size
else
1
windows
=
torch
.
tensor_split
(
windows
,
num_batches
,
dim
=
0
)
if
torch
.
cuda
.
is_available
():
device
=
torch
.
device
(
'cuda'
)
else
:
device
=
torch
.
device
(
'cpu'
)
if
temporal_coords
:
temporal_coords
=
torch
.
tensor
(
temporal_coords
,
device
=
device
).
unsqueeze
(
0
)
else
:
temporal_coords
=
None
if
location_coords
:
location_coords
=
torch
.
tensor
(
location_coords
[
0
],
device
=
device
).
unsqueeze
(
0
)
else
:
location_coords
=
None
# Run model
pred_imgs
=
[]
for
x
in
windows
:
# Apply standardization
x
=
datamodule
.
test_transform
(
image
=
x
.
squeeze
().
numpy
().
transpose
(
1
,
2
,
0
))
x
=
datamodule
.
aug
(
x
)[
'image'
]
with
torch
.
no_grad
():
x
=
x
.
to
(
device
)
pred
=
model
.
run
(
x
,
location_coords
=
location_coords
)
if
lightning_model
:
pred_lightning
=
lightning_model
(
x
,
temporal_coords
=
temporal_coords
,
location_coords
=
location_coords
)
pred_lightning
=
pred_lightning
.
output
.
detach
().
cpu
()
if
not
torch
.
equal
(
pred
,
pred_lightning
):
print
(
"Inference output is not equal"
)
y_hat
=
pred
.
argmax
(
dim
=
1
)
y_hat
=
torch
.
nn
.
functional
.
interpolate
(
y_hat
.
unsqueeze
(
1
).
float
(),
size
=
img_size
,
mode
=
"nearest"
)
pred_imgs
.
append
(
y_hat
)
pred_imgs
=
torch
.
concat
(
pred_imgs
,
dim
=
0
)
# Build images from patches
pred_imgs
=
rearrange
(
pred_imgs
,
"(b h1 w1) c h w -> b c (h1 h) (w1 w)"
,
h
=
img_size
,
w
=
img_size
,
b
=
1
,
c
=
1
,
h1
=
h1
,
w1
=
w1
,
)
# Cut padded area back to original size
pred_imgs
=
pred_imgs
[...,
:
original_h
,
:
original_w
]
# Squeeze (batch size 1)
pred_imgs
=
pred_imgs
[
0
]
return
pred_imgs
def
main
(
data_file
:
str
,
output_dir
:
str
,
rgb_outputs
:
bool
,
input_indices
:
list
[
int
]
=
None
,
):
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
# Load model ---------------------------------------------------------------
model_obj
=
PrithviMAE
()
datamodule
=
generate_datamodule
()
img_size
=
256
# Size of Sen1Floods11
# Loading data -------------------------------------------------------------
input_data
,
temporal_coords
,
location_coords
,
meta_data
=
load_example
(
file_paths
=
[
data_file
],
indices
=
input_indices
,
)
meta_data
=
meta_data
[
0
]
# only one image
if
input_data
.
mean
()
>
1
:
input_data
=
input_data
/
10000
# Convert to range 0-1
# Running model ------------------------------------------------------------
channels
=
[
datamodule_config
[
'bands'
].
index
(
b
)
for
b
in
[
"RED"
,
"GREEN"
,
"BLUE"
]
]
# BGR -> RGB
pred
=
run_model
(
input_data
,
temporal_coords
,
location_coords
,
model_obj
,
datamodule
,
img_size
)
# Save pred
meta_data
.
update
(
count
=
1
,
dtype
=
"uint8"
,
compress
=
"lzw"
,
nodata
=
0
)
pred_file
=
os
.
path
.
join
(
output_dir
,
f
"pred_
{
os
.
path
.
splitext
(
os
.
path
.
basename
(
data_file
))[
0
]
}
.tiff"
)
save_geotiff
(
_convert_np_uint8
(
pred
),
pred_file
,
meta_data
)
# Save image + pred
meta_data
.
update
(
count
=
3
,
dtype
=
"uint8"
,
compress
=
"lzw"
,
nodata
=
0
)
if
input_data
.
mean
()
<
1
:
input_data
=
input_data
*
10000
# Scale to 0-10000
rgb_orig
=
process_channel_group
(
orig_img
=
torch
.
Tensor
(
input_data
[
0
,
:,
0
,
...]),
channels
=
channels
,
)
pred
[
pred
==
0.
]
=
np
.
nan
img_pred
=
rgb_orig
*
0.7
+
pred
*
0.3
img_pred
[
img_pred
.
isnan
()]
=
rgb_orig
[
img_pred
.
isnan
()]
img_pred_file
=
os
.
path
.
join
(
output_dir
,
f
"rgb_pred_
{
os
.
path
.
splitext
(
os
.
path
.
basename
(
data_file
))[
0
]
}
.tiff"
)
save_geotiff
(
image
=
_convert_np_uint8
(
img_pred
),
output_path
=
img_pred_file
,
meta
=
meta_data
,
)
# Save image rgb
if
rgb_outputs
:
rgb_file
=
os
.
path
.
join
(
output_dir
,
"original_rgb_"
f
"
{
os
.
path
.
splitext
(
os
.
path
.
basename
(
data_file
))[
0
]
}
.tiff"
)
save_geotiff
(
image
=
_convert_np_uint8
(
rgb_orig
),
output_path
=
rgb_file
,
meta
=
meta_data
,
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
"MAE run inference"
,
add_help
=
False
)
parser
.
add_argument
(
"--data_file"
,
type
=
str
,
default
=
"./India_900498_S2Hand.tif"
,
help
=
"Path to the file."
,
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"output"
,
help
=
"Path to the directory where to save outputs."
,
)
parser
.
add_argument
(
"--input_indices"
,
default
=
[
1
,
2
,
3
,
8
,
11
,
12
],
type
=
int
,
nargs
=
"+"
,
help
=
"0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data."
,
)
parser
.
add_argument
(
"--rgb_outputs"
,
action
=
"store_true"
,
help
=
"If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved."
,
)
args
=
parser
.
parse_args
()
main
(
**
vars
(
args
))
tests/models/registry.py
View file @
974dfd49
...
@@ -214,6 +214,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
...
@@ -214,6 +214,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM"
:
_HfExamplesInfo
(
"TIGER-Lab/VLM2Vec-Full"
,
"Phi3VForCausalLM"
:
_HfExamplesInfo
(
"TIGER-Lab/VLM2Vec-Full"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"MrLight/dse-qwen2-2b-mrl-v1"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
_HfExamplesInfo
(
"MrLight/dse-qwen2-2b-mrl-v1"
),
# noqa: E501
# The model on Huggingface is currently being updated,
# hence I temporarily mark it as not available online
"PrithviGeoSpatialMAE"
:
_HfExamplesInfo
(
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
,
# noqa: E501
is_available_online
=
False
),
}
}
_CROSS_ENCODER_EXAMPLE_MODELS
=
{
_CROSS_ENCODER_EXAMPLE_MODELS
=
{
...
...
vllm/attention/backends/placeholder_attn.py
View file @
974dfd49
...
@@ -320,9 +320,14 @@ class PlaceholderAttentionMetadataBuilder(
...
@@ -320,9 +320,14 @@ class PlaceholderAttentionMetadataBuilder(
-1 if cuda graph is not used.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
batch_size: The maybe padded batch size.
"""
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
# Some input builders such as ModelInputForCPUBuilder do not have the
self
.
input_builder
.
chunked_prefill_enabled
)
# "inter_data_list" attribute.
# Let's check inter_data_list exists before we reference it.
if
hasattr
(
self
.
input_builder
,
"inter_data_list"
):
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
device
=
self
.
runner
.
device
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
...
...
vllm/inputs/preprocess.py
View file @
974dfd49
...
@@ -254,8 +254,14 @@ class InputPreprocessor:
...
@@ -254,8 +254,14 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
returning the corresponding token IDs and metadata.
"""
"""
tokenizer_group
=
self
.
get_tokenizer_group
()
# At the moment on model (PrithviGeoSpatialMAE) requires to be
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
# initialized without a tokenizer while using also multi-modal
# input.
if
not
self
.
tokenizer
:
tokenizer
=
None
else
:
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
)
self
.
model_config
,
tokenizer
)
...
@@ -273,9 +279,15 @@ class InputPreprocessor:
...
@@ -273,9 +279,15 @@ class InputPreprocessor:
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
)
->
MultiModalInputs
:
)
->
MultiModalInputs
:
"""Async version of :meth:`_process_multimodal`."""
"""Async version of :meth:`_process_multimodal`."""
tokenizer_group
=
self
.
get_tokenizer_group
()
# At the moment on model (PrithviGeoSpatialMAE) requires to be
tokenizer
=
await
tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
# initialized without a tokenizer while using also multi-modal
)
# input.
if
not
self
.
tokenizer
:
tokenizer
=
None
else
:
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
await
tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
)
self
.
model_config
,
tokenizer
)
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
0 → 100644
View file @
974dfd49
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
from
transformers
import
BatchFeature
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
SupportsMultiModal
)
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
PoolingSequenceGroupOutput
)
class
PrithviGeoSpatialMAEProcessingInfo
(
BaseProcessingInfo
):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
)
->
Mapping
[
str
,
int
]:
pass
class
PrithviGeoSpatialMAEInputBuilder
(
BaseDummyInputsBuilder
[
PrithviGeoSpatialMAEProcessingInfo
]):
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
return
ProcessorInputs
(
prompt_text
=
""
,
# This model input is fixed and is in the form of a torch Tensor.
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
mm_data
=
{
"pixel_values"
:
torch
.
full
((
1
,
6
,
512
,
512
),
1.0
),
"location_coords"
:
torch
.
full
((
1
,
2
),
1.0
)
})
class
PrithviGeoSpatialMAEMultiModalProcessor
(
BaseMultiModalProcessor
):
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
location_coords
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
pass
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
pass
def
apply
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
MultiModalInputs
:
mm_kwargs
=
{}
for
k
,
v
in
mm_data
.
items
():
mm_kwargs
[
k
]
=
v
return
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
prompt
,
prompt_token_ids
=
[
1
],
mm_kwargs
=
MultiModalKwargs
(
mm_kwargs
),
mm_placeholders
=
{},
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
PrithviGeoSpatialMAEMultiModalProcessor
,
info
=
PrithviGeoSpatialMAEProcessingInfo
,
dummy_inputs
=
PrithviGeoSpatialMAEInputBuilder
)
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
):
""" Prithvi Masked Autoencoder"""
def
_instantiate_model
(
self
,
config
:
dict
)
->
nn
.
Module
|
None
:
# We might be able/need to support different tasks with this same model
if
config
[
"task_args"
][
"task"
]
==
"SemanticSegmentationTask"
:
from
terratorch.cli_tools
import
SemanticSegmentationTask
task
=
SemanticSegmentationTask
(
config
[
"model_args"
],
config
[
"task_args"
][
"model_factory"
],
loss
=
config
[
"task_args"
][
"loss"
],
lr
=
config
[
"task_args"
][
"lr"
],
ignore_index
=
config
[
"task_args"
][
"ignore_index"
],
optimizer
=
config
[
"task_args"
][
"optimizer"
],
optimizer_hparams
=
config
[
"optimizer_params"
],
scheduler
=
config
[
"task_args"
][
"scheduler"
],
scheduler_hparams
=
config
[
"scheduler_params"
],
plot_on_val
=
config
[
"task_args"
][
"plot_on_val"
],
freeze_decoder
=
config
[
"task_args"
][
"freeze_decoder"
],
freeze_backbone
=
config
[
"task_args"
][
"freeze_backbone"
])
return
task
.
model
else
:
return
None
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
# the actual model is dynamically instantiated using terratorch
# allowing us to perform changes to the model architecture
# at startup time (e.g., change the model decoder class.)
self
.
model
=
self
.
_instantiate_model
(
vllm_config
.
model_config
.
hf_config
.
to_dict
()[
"pretrained_cfg"
])
if
self
.
model
is
None
:
raise
ValueError
(
"Unsupported task."
"Only SemanticSegmentationTask is supported for now"
"by PrithviGeospatialMAE."
)
def
_parse_and_validate_multimodal_data
(
self
,
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
f
"Incorrect type of pixel_values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
pixel_values
=
torch
.
unbind
(
pixel_values
,
dim
=
0
)[
0
]
location_coords
=
kwargs
.
pop
(
"location_coords"
,
None
)
if
not
isinstance
(
location_coords
,
torch
.
Tensor
):
raise
ValueError
(
f
"Incorrect type of location_coords. "
f
"Got type:
{
type
(
location_coords
)
}
"
)
location_coords
=
torch
.
unbind
(
location_coords
,
dim
=
0
)[
0
]
if
location_coords
.
shape
==
torch
.
Size
([
0
]):
location_coords
=
None
return
pixel_values
,
location_coords
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
):
pixel_values
,
location_coords
=
(
self
.
_parse_and_validate_multimodal_data
(
**
kwargs
))
model_output
=
self
.
model
(
pixel_values
,
location_coords
=
location_coords
)
return
model_output
.
output
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
PoolerOutput
([
PoolingSequenceGroupOutput
(
hidden_states
)])
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_list
=
[]
model_buffers
=
dict
(
self
.
named_buffers
())
loaded_buffers
=
[]
for
key
,
value
in
weights
:
if
key
==
"state_dict"
:
weights_to_parse
=
value
for
name
,
weight
in
weights_to_parse
.
items
():
if
"pos_embed"
in
name
:
continue
if
"_timm_module."
in
name
:
name
=
name
.
replace
(
"_timm_module."
,
""
)
# this model requires a couple of buffers to be loaded
# that are not loadable with the AutoWeightsLoader
if
name
in
model_buffers
:
if
"_timm_module."
in
name
:
name
=
name
.
replace
(
"_timm_module."
,
""
)
buffer
=
model_buffers
[
name
]
weight_loader
=
getattr
(
buffer
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
buffer
,
weight
)
loaded_buffers
.
append
(
name
)
else
:
params_list
.
append
((
name
,
weight
))
break
# Load the remaining model parameters
loader
=
AutoWeightsLoader
(
self
)
autoloaded_weights
=
loader
.
load_weights
(
params_list
)
return
autoloaded_weights
.
union
(
set
(
loaded_buffers
))
vllm/model_executor/models/registry.py
View file @
974dfd49
...
@@ -137,6 +137,10 @@ _EMBEDDING_MODELS = {
...
@@ -137,6 +137,10 @@ _EMBEDDING_MODELS = {
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
# noqa: E501
# [Auto-converted (see adapters.py)]
# [Auto-converted (see adapters.py)]
"Qwen2ForSequenceClassification"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2ForSequenceClassification"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE"
:
(
"prithvi_geospatial_mae"
,
"PrithviGeoSpatialMAE"
),
}
}
_CROSS_ENCODER_MODELS
=
{
_CROSS_ENCODER_MODELS
=
{
...
...
vllm/worker/pooling_model_runner.py
View file @
974dfd49
...
@@ -74,7 +74,16 @@ class PoolingModelRunner(
...
@@ -74,7 +74,16 @@ class PoolingModelRunner(
prefill_meta
=
model_input
.
attn_metadata
.
prefill_metadata
prefill_meta
=
model_input
.
attn_metadata
.
prefill_metadata
decode_meta
=
model_input
.
attn_metadata
.
decode_metadata
decode_meta
=
model_input
.
attn_metadata
.
decode_metadata
virtual_engine
=
model_input
.
virtual_engine
virtual_engine
=
model_input
.
virtual_engine
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
:
# Pooling models are (ab-)used also to integrate non text models that
# are not autoregressive (PrithviGeosaptialMAE).
# These model might not use attention and do not really have a prefill
# and decode phase. The model input is processed in one shot and both
# decode_metadata and prefill_metadata would be None for such models.
# See the PlaceholderAttentionMetadata class.
# TODO: Figure out if cuda_graph is of any use for these models and
# explore how to leverage it.
if
(
prefill_meta
is
None
and
decode_meta
is
not
None
and
decode_meta
.
use_cuda_graph
):
assert
model_input
.
input_tokens
is
not
None
assert
model_input
.
input_tokens
is
not
None
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_executable
=
self
.
graph_runners
[
virtual_engine
][
model_executable
=
self
.
graph_runners
[
virtual_engine
][
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment