Unverified Commit 8560a5b2 authored by Christian Pinto's avatar Christian Pinto Committed by GitHub
Browse files

[Core][Model] PrithviMAE Enablement on vLLM v1 engine (#20577)


Signed-off-by: default avatarChristian Pinto <christian.pinto@ibm.com>
parent 316b1bf7
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
The requirements for running this script are:
- Installing [terratorch, albumentations, rasterio] in your python environment
- downloading the model weights in a 'model' folder local to the script
(temporary measure until the proper config.json file is uploaded to HF)
- download an input example image (India_900498_S2Hand.tif) and place it in
the same folder with the script (or specify with the --data_file argument)
Run the example:
python prithvi_geospatial_mae.py
""" # noqa: E501
import argparse import argparse
import datetime import datetime
import os import os
import re
from typing import Union from typing import Union
import albumentations import albumentations
import numpy as np import numpy as np
import rasterio import rasterio
import regex as re
import torch import torch
from einops import rearrange from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm import LLM from vllm import LLM
torch.set_default_dtype(torch.float16)
NO_DATA = -9999 NO_DATA = -9999
NO_DATA_FLOAT = 0.0001 NO_DATA_FLOAT = 0.0001
OFFSET = 0 OFFSET = 0
PERCENTILE = 99 PERCENTILE = 99
model_config = """{
"architectures": ["PrithviGeoSpatialMAE"],
"num_classes": 0,
"pretrained_cfg": {
"task_args": {
"task": "SemanticSegmentationTask",
"model_factory": "EncoderDecoderFactory",
"loss": "ce",
"ignore_index": -1,
"lr": 0.001,
"freeze_backbone": false,
"freeze_decoder": false,
"plot_on_val": 10,
"optimizer": "AdamW",
"scheduler": "CosineAnnealingLR"
},
"model_args": {
"backbone_pretrained": false,
"backbone": "prithvi_eo_v2_300_tl",
"decoder": "UperNetDecoder",
"decoder_channels": 256,
"decoder_scale_modules": true,
"num_classes": 2,
"rescale": true,
"backbone_bands": [
"BLUE",
"GREEN",
"RED",
"NIR_NARROW",
"SWIR_1",
"SWIR_2"
],
"head_dropout": 0.1,
"necks": [
{
"name": "SelectIndices",
"indices": [
5,
11,
17,
23
]
},
{
"name": "ReshapeTokensToImage"
}
]
},
"optimizer_params" : {
"lr": 5.0e-05,
"betas": [0.9, 0.999],
"eps": [1.0e-08],
"weight_decay": 0.05,
"amsgrad": false,
"maximize": false,
"capturable": false,
"differentiable": false
},
"scheduler_params" : {
"T_max": 50,
"eta_min": 0,
"last_epoch": -1,
"verbose": "deprecated"
}
},
"torch_dtype": "float32"
}
"""
# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF
with open(
os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
) as config_file:
config_file.write(model_config)
datamodule_config = { datamodule_config = {
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
"batch_size": 16, "batch_size": 16,
...@@ -138,28 +43,24 @@ datamodule_config = { ...@@ -138,28 +43,24 @@ datamodule_config = {
class PrithviMAE: class PrithviMAE:
def __init__(self): def __init__(self, model):
print("Initializing PrithviMAE model") self.model = LLM(
self.llm = LLM( model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
model=os.path.join(os.path.dirname(__file__), "./model"),
skip_tokenizer_init=True,
dtype="float32",
) )
def run(self, input_data, location_coords): def run(self, input_data, location_coords):
print("################ Running inference on vLLM ##############")
# merge the inputs into one data structure # merge the inputs into one data structure
if input_data is not None and input_data.dtype == torch.float32:
input_data = input_data.to(torch.float16)
input_data = input_data[0]
mm_data = { mm_data = {
"pixel_values": torch.empty(0) if input_data is None else input_data, "pixel_values": input_data,
"location_coords": torch.empty(0) "location_coords": location_coords,
if location_coords is None
else location_coords,
} }
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.model.encode(prompt, use_tqdm=False)
outputs = self.llm.encode(prompt, use_tqdm=False)
print("################ Inference done (it took seconds) ##############")
return outputs[0].outputs.data return outputs[0].outputs.data
...@@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels): ...@@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
""" """
Args: Args:
orig_img: torch.Tensor representing original image (reference) orig_img: torch.Tensor representing original image (reference)
with shape = (bands, H, W). with shape = (bands, H, W).
channels: list of indices representing RGB channels. channels: list of indices representing RGB channels.
Returns: Returns:
torch.Tensor with shape (num_channels, height, width) for original image torch.Tensor with shape (num_channels, height, width)
for original image
""" """
orig_img = orig_img[channels, ...] orig_img = orig_img[channels, ...]
...@@ -260,10 +162,10 @@ def load_example( ...@@ -260,10 +162,10 @@ def load_example(
Args: Args:
file_paths: list of file paths . file_paths: list of file paths .
mean: list containing mean values for each band in the images mean: list containing mean values for each band in the
in *file_paths*. images in *file_paths*.
std: list containing std values for each band in the images std: list containing std values for each band in the
in *file_paths*. images in *file_paths*.
Returns: Returns:
np.array containing created example np.array containing created example
...@@ -308,7 +210,7 @@ def load_example( ...@@ -308,7 +210,7 @@ def load_example(
print(f"Could not extract timestamp for {file} ({e})") print(f"Could not extract timestamp for {file} ({e})")
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32") imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
imgs = np.expand_dims(imgs, axis=0) # add batch di imgs = np.expand_dims(imgs, axis=0) # add batch di
return imgs, temporal_coords, location_coords, metas return imgs, temporal_coords, location_coords, metas
...@@ -332,8 +234,10 @@ def run_model( ...@@ -332,8 +234,10 @@ def run_model(
) )
# Build sliding window # Build sliding window
batch_size = 1 batch_size = 1
batch = torch.tensor(input_data, device="cpu") # batch = torch.tensor(input_data, device="cpu")
batch = torch.tensor(input_data)
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size) windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5] h1, w1 = windows.shape[3:5]
windows = rearrange( windows = rearrange(
...@@ -344,18 +248,16 @@ def run_model( ...@@ -344,18 +248,16 @@ def run_model(
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1 num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
windows = torch.tensor_split(windows, num_batches, dim=0) windows = torch.tensor_split(windows, num_batches, dim=0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if temporal_coords: if temporal_coords:
temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0) temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
else: else:
temporal_coords = None temporal_coords = None
if location_coords: if location_coords:
location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0) location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
else: else:
location_coords = None location_coords = None
# Run model # Run Prithvi-EO-V2-300M-TL-Sen1Floods11
pred_imgs = [] pred_imgs = []
for x in windows: for x in windows:
# Apply standardization # Apply standardization
...@@ -363,15 +265,7 @@ def run_model( ...@@ -363,15 +265,7 @@ def run_model(
x = datamodule.aug(x)["image"] x = datamodule.aug(x)["image"]
with torch.no_grad(): with torch.no_grad():
x = x.to(device)
pred = model.run(x, location_coords=location_coords) pred = model.run(x, location_coords=location_coords)
if lightning_model:
pred_lightning = lightning_model(
x, temporal_coords=temporal_coords, location_coords=location_coords
)
pred_lightning = pred_lightning.output.detach().cpu()
if not torch.equal(pred, pred_lightning):
print("Inference output is not equal")
y_hat = pred.argmax(dim=1) y_hat = pred.argmax(dim=1)
y_hat = torch.nn.functional.interpolate( y_hat = torch.nn.functional.interpolate(
...@@ -403,52 +297,18 @@ def run_model( ...@@ -403,52 +297,18 @@ def run_model(
return pred_imgs return pred_imgs
def parse_args():
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
parser.add_argument(
"--data_file",
type=str,
default="./India_900498_S2Hand.tif",
help="Path to the file.",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Path to the directory where to save outputs.",
)
parser.add_argument(
"--input_indices",
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
help="0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
)
parser.add_argument(
"--rgb_outputs",
action="store_true",
help="If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved.",
)
def main( def main(
data_file: str, data_file: str,
model: str,
output_dir: str, output_dir: str,
rgb_outputs: bool, rgb_outputs: bool,
input_indices: list[int] = None, input_indices: list[int] = None,
): ):
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
# Load model --------------------------------------------------------------- model_obj = PrithviMAE(model=model)
model_obj = PrithviMAE()
datamodule = generate_datamodule() datamodule = generate_datamodule()
img_size = 256 # Size of Sen1Floods11 img_size = 512 # Size of Sen1Floods11
# Loading data -------------------------------------------------------------
input_data, temporal_coords, location_coords, meta_data = load_example( input_data, temporal_coords, location_coords, meta_data = load_example(
file_paths=[data_file], file_paths=[data_file],
...@@ -460,8 +320,6 @@ def main( ...@@ -460,8 +320,6 @@ def main(
if input_data.mean() > 1: if input_data.mean() > 1:
input_data = input_data / 10000 # Convert to range 0-1 input_data = input_data / 10000 # Convert to range 0-1
# Running model ------------------------------------------------------------
channels = [ channels = [
datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"] datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
] # BGR -> RGB ] # BGR -> RGB
...@@ -469,7 +327,6 @@ def main( ...@@ -469,7 +327,6 @@ def main(
pred = run_model( pred = run_model(
input_data, temporal_coords, location_coords, model_obj, datamodule, img_size input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
) )
# Save pred # Save pred
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
pred_file = os.path.join( pred_file = os.path.join(
...@@ -487,6 +344,7 @@ def main( ...@@ -487,6 +344,7 @@ def main(
orig_img=torch.Tensor(input_data[0, :, 0, ...]), orig_img=torch.Tensor(input_data[0, :, 0, ...]),
channels=channels, channels=channels,
) )
rgb_orig = rgb_orig.to(torch.float32)
pred[pred == 0.0] = np.nan pred[pred == 0.0] = np.nan
img_pred = rgb_orig * 0.7 + pred * 0.3 img_pred = rgb_orig * 0.7 + pred * 0.3
...@@ -503,9 +361,10 @@ def main( ...@@ -503,9 +361,10 @@ def main(
# Save image rgb # Save image rgb
if rgb_outputs: if rgb_outputs:
name_suffix = os.path.splitext(os.path.basename(data_file))[0]
rgb_file = os.path.join( rgb_file = os.path.join(
output_dir, output_dir,
f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff", f"original_rgb_{name_suffix}.tiff",
) )
save_geotiff( save_geotiff(
image=_convert_np_uint8(rgb_orig), image=_convert_np_uint8(rgb_orig),
...@@ -515,6 +374,42 @@ def main( ...@@ -515,6 +374,42 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() parser = argparse.ArgumentParser("MAE run inference", add_help=False)
parser.add_argument(
"--data_file",
type=str,
default="./India_900498_S2Hand.tif",
help="Path to the file.",
)
parser.add_argument(
"--model",
type=str,
default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
help="Path to a checkpoint file to load from.",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Path to the directory where to save outputs.",
)
parser.add_argument(
"--input_indices",
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
help="""
0-based indices of the six Prithvi channels to be selected from the input.
By default selects [1,2,3,8,11,12] for S2L1C data.
""",
)
parser.add_argument(
"--rgb_outputs",
action="store_true",
help="If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved.",
)
args = parser.parse_args()
main(**vars(args)) main(**vars(args))
...@@ -54,3 +54,4 @@ runai-model-streamer==0.11.0 ...@@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0 runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10 fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10 pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test
\ No newline at end of file
...@@ -6,6 +6,10 @@ accelerate==1.0.1 ...@@ -6,6 +6,10 @@ accelerate==1.0.1
# via # via
# lm-eval # lm-eval
# peft # peft
aenum==3.1.16
# via lightly
affine==2.4.0
# via rasterio
aiohappyeyeballs==2.4.3 aiohappyeyeballs==2.4.3
# via aiohttp # via aiohttp
aiohttp==3.10.11 aiohttp==3.10.11
...@@ -21,8 +25,18 @@ aiosignal==1.3.1 ...@@ -21,8 +25,18 @@ aiosignal==1.3.1
# via # via
# aiohttp # aiohttp
# ray # ray
albucore==0.0.16
# via terratorch
albumentations==1.4.6
# via terratorch
alembic==1.16.4
# via mlflow
annotated-types==0.7.0 annotated-types==0.7.0
# via pydantic # via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.6.2.post1 anyio==4.6.2.post1
# via # via
# httpx # httpx
...@@ -34,10 +48,12 @@ arrow==1.3.0 ...@@ -34,10 +48,12 @@ arrow==1.3.0
attrs==24.2.0 attrs==24.2.0
# via # via
# aiohttp # aiohttp
# fiona
# hypothesis # hypothesis
# jsonlines # jsonlines
# jsonschema # jsonschema
# pytest-subtests # pytest-subtests
# rasterio
# referencing # referencing
audioread==3.0.1 audioread==3.0.1
# via librosa # via librosa
...@@ -46,9 +62,13 @@ backoff==2.2.1 ...@@ -46,9 +62,13 @@ backoff==2.2.1
# -r requirements/test.in # -r requirements/test.in
# schemathesis # schemathesis
bitsandbytes==0.46.1 bitsandbytes==0.46.1
# via -r requirements/test.in # via
# -r requirements/test.in
# lightning
black==24.10.0 black==24.10.0
# via datamodel-code-generator # via datamodel-code-generator
blinker==1.9.0
# via flask
blobfile==3.0.0 blobfile==3.0.0
# via -r requirements/test.in # via -r requirements/test.in
bm25s==0.2.13 bm25s==0.2.13
...@@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3 ...@@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3
buildkite-test-collector==0.1.9 buildkite-test-collector==0.1.9
# via -r requirements/test.in # via -r requirements/test.in
cachetools==5.5.2 cachetools==5.5.2
# via google-auth # via
# google-auth
# mlflow-skinny
certifi==2024.8.30 certifi==2024.8.30
# via # via
# fiona
# httpcore # httpcore
# httpx # httpx
# lightly
# pyogrio
# pyproj
# rasterio
# requests # requests
cffi==1.17.1 cffi==1.17.1
# via soundfile # via soundfile
...@@ -79,11 +106,28 @@ charset-normalizer==3.4.0 ...@@ -79,11 +106,28 @@ charset-normalizer==3.4.0
click==8.1.7 click==8.1.7
# via # via
# black # black
# click-plugins
# cligj
# fiona
# flask
# jiwer # jiwer
# mlflow-skinny
# nltk # nltk
# rasterio
# ray # ray
# schemathesis # schemathesis
# typer # typer
# uvicorn
click-plugins==1.1.1.2
# via
# fiona
# rasterio
cligj==0.7.2
# via
# fiona
# rasterio
cloudpickle==3.1.1
# via mlflow-skinny
colorama==0.4.6 colorama==0.4.6
# via # via
# sacrebleu # sacrebleu
...@@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0 ...@@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0
# via ray # via ray
cycler==0.12.1 cycler==0.12.1
# via matplotlib # via matplotlib
databricks-sdk==0.59.0
# via mlflow-skinny
datamodel-code-generator==0.26.3 datamodel-code-generator==0.26.3
# via -r requirements/test.in # via -r requirements/test.in
dataproperty==1.0.1 dataproperty==1.0.1
...@@ -122,13 +168,21 @@ distlib==0.3.9 ...@@ -122,13 +168,21 @@ distlib==0.3.9
# via virtualenv # via virtualenv
dnspython==2.7.0 dnspython==2.7.0
# via email-validator # via email-validator
docker==7.1.0
# via mlflow
docopt==0.6.2 docopt==0.6.2
# via num2words # via num2words
einops==0.8.0 docstring-parser==0.17.0
# via jsonargparse
efficientnet-pytorch==0.7.1
# via segmentation-models-pytorch
einops==0.8.1
# via # via
# -r requirements/test.in # -r requirements/test.in
# encodec # encodec
# mamba-ssm # mamba-ssm
# terratorch
# torchgeo
# vector-quantize-pytorch # vector-quantize-pytorch
# vocos # vocos
einx==0.3.0 einx==0.3.0
...@@ -141,6 +195,8 @@ eval-type-backport==0.2.2 ...@@ -141,6 +195,8 @@ eval-type-backport==0.2.2
# via mteb # via mteb
evaluate==0.4.3 evaluate==0.4.3
# via lm-eval # via lm-eval
fastapi==0.116.1
# via mlflow-skinny
fastparquet==2024.11.0 fastparquet==2024.11.0
# via genai-perf # via genai-perf
fastrlock==0.8.2 fastrlock==0.8.2
...@@ -156,6 +212,10 @@ filelock==3.16.1 ...@@ -156,6 +212,10 @@ filelock==3.16.1
# torch # torch
# transformers # transformers
# virtualenv # virtualenv
fiona==1.10.1
# via torchgeo
flask==3.1.1
# via mlflow
fonttools==4.54.1 fonttools==4.54.1
# via matplotlib # via matplotlib
fqdn==1.5.1 fqdn==1.5.1
...@@ -173,6 +233,8 @@ fsspec==2024.9.0 ...@@ -173,6 +233,8 @@ fsspec==2024.9.0
# evaluate # evaluate
# fastparquet # fastparquet
# huggingface-hub # huggingface-hub
# lightning
# pytorch-lightning
# torch # torch
ftfy==6.3.1 ftfy==6.3.1
# via open-clip-torch # via open-clip-torch
...@@ -180,18 +242,41 @@ genai-perf==0.0.8 ...@@ -180,18 +242,41 @@ genai-perf==0.0.8
# via -r requirements/test.in # via -r requirements/test.in
genson==1.3.0 genson==1.3.0
# via datamodel-code-generator # via datamodel-code-generator
geopandas==1.0.1
# via terratorch
gitdb==4.0.12
# via gitpython
gitpython==3.1.44
# via mlflow-skinny
google-api-core==2.24.2 google-api-core==2.24.2
# via opencensus # via opencensus
google-auth==2.40.2 google-auth==2.40.2
# via google-api-core # via
# databricks-sdk
# google-api-core
googleapis-common-protos==1.70.0 googleapis-common-protos==1.70.0
# via google-api-core # via google-api-core
graphene==3.4.3
# via mlflow
graphql-core==3.2.6 graphql-core==3.2.6
# via hypothesis-graphql # via
# graphene
# graphql-relay
# hypothesis-graphql
graphql-relay==3.2.0
# via graphene
greenlet==3.2.3
# via sqlalchemy
grpcio==1.71.0 grpcio==1.71.0
# via ray # via ray
gunicorn==23.0.0
# via mlflow
h11==0.14.0 h11==0.14.0
# via httpcore # via
# httpcore
# uvicorn
h5py==3.13.0
# via terratorch
harfile==0.3.0 harfile==0.3.0
# via schemathesis # via schemathesis
hf-xet==1.1.3 hf-xet==1.1.3
...@@ -204,7 +289,7 @@ httpx==0.27.2 ...@@ -204,7 +289,7 @@ httpx==0.27.2
# via # via
# -r requirements/test.in # -r requirements/test.in
# schemathesis # schemathesis
huggingface-hub==0.33.0 huggingface-hub==0.33.1
# via # via
# -r requirements/test.in # -r requirements/test.in
# accelerate # accelerate
...@@ -212,13 +297,19 @@ huggingface-hub==0.33.0 ...@@ -212,13 +297,19 @@ huggingface-hub==0.33.0
# evaluate # evaluate
# open-clip-torch # open-clip-torch
# peft # peft
# segmentation-models-pytorch
# sentence-transformers # sentence-transformers
# terratorch
# timm # timm
# tokenizers # tokenizers
# transformers # transformers
# vocos # vocos
humanize==4.11.0 humanize==4.11.0
# via runai-model-streamer # via runai-model-streamer
hydra-core==1.3.2
# via
# lightly
# lightning
hypothesis==6.131.0 hypothesis==6.131.0
# via # via
# hypothesis-graphql # hypothesis-graphql
...@@ -236,6 +327,14 @@ idna==3.10 ...@@ -236,6 +327,14 @@ idna==3.10
# jsonschema # jsonschema
# requests # requests
# yarl # yarl
imageio==2.37.0
# via scikit-image
importlib-metadata==8.7.0
# via
# mlflow-skinny
# opentelemetry-api
importlib-resources==6.5.2
# via typeshed-client
inflect==5.6.2 inflect==5.6.2
# via datamodel-code-generator # via datamodel-code-generator
iniconfig==2.0.0 iniconfig==2.0.0
...@@ -244,9 +343,13 @@ isoduration==20.11.0 ...@@ -244,9 +343,13 @@ isoduration==20.11.0
# via jsonschema # via jsonschema
isort==5.13.2 isort==5.13.2
# via datamodel-code-generator # via datamodel-code-generator
itsdangerous==2.2.0
# via flask
jinja2==3.1.6 jinja2==3.1.6
# via # via
# datamodel-code-generator # datamodel-code-generator
# flask
# mlflow
# torch # torch
jiwer==3.0.5 jiwer==3.0.5
# via -r requirements/test.in # via -r requirements/test.in
...@@ -259,6 +362,10 @@ joblib==1.4.2 ...@@ -259,6 +362,10 @@ joblib==1.4.2
# librosa # librosa
# nltk # nltk
# scikit-learn # scikit-learn
jsonargparse==4.35.0
# via
# lightning
# terratorch
jsonlines==4.0.0 jsonlines==4.0.0
# via lm-eval # via lm-eval
jsonpointer==3.0.0 jsonpointer==3.0.0
...@@ -277,12 +384,33 @@ kaleido==0.2.1 ...@@ -277,12 +384,33 @@ kaleido==0.2.1
# via genai-perf # via genai-perf
kiwisolver==1.4.7 kiwisolver==1.4.7
# via matplotlib # via matplotlib
kornia==0.8.1
# via torchgeo
kornia-rs==0.1.9
# via kornia
lazy-loader==0.4 lazy-loader==0.4
# via librosa # via
# librosa
# scikit-image
libnacl==2.1.0 libnacl==2.1.0
# via tensorizer # via tensorizer
librosa==0.10.2.post1 librosa==0.10.2.post1
# via -r requirements/test.in # via -r requirements/test.in
lightly==1.5.20
# via
# terratorch
# torchgeo
lightly-utils==0.0.2
# via lightly
lightning==2.5.1.post0
# via
# terratorch
# torchgeo
lightning-utilities==0.14.3
# via
# lightning
# pytorch-lightning
# torchmetrics
llvmlite==0.44.0 llvmlite==0.44.0
# via numba # via numba
lm-eval==0.4.8 lm-eval==0.4.8
...@@ -291,16 +419,27 @@ lxml==5.3.0 ...@@ -291,16 +419,27 @@ lxml==5.3.0
# via # via
# blobfile # blobfile
# sacrebleu # sacrebleu
mako==1.3.10
# via alembic
mamba-ssm==2.2.4 mamba-ssm==2.2.4
# via -r requirements/test.in # via -r requirements/test.in
markdown==3.8.2
# via mlflow
markdown-it-py==3.0.0 markdown-it-py==3.0.0
# via rich # via rich
markupsafe==3.0.1 markupsafe==3.0.1
# via # via
# flask
# jinja2 # jinja2
# mako
# werkzeug # werkzeug
matplotlib==3.9.2 matplotlib==3.9.2
# via -r requirements/test.in # via
# -r requirements/test.in
# lightning
# mlflow
# pycocotools
# torchgeo
mbstrdecoder==1.1.3 mbstrdecoder==1.1.3
# via # via
# dataproperty # dataproperty
...@@ -310,6 +449,10 @@ mdurl==0.1.2 ...@@ -310,6 +449,10 @@ mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
mistral-common==1.8.0 mistral-common==1.8.0
# via -r requirements/test.in # via -r requirements/test.in
mlflow==2.22.0
# via terratorch
mlflow-skinny==2.22.0
# via mlflow
more-itertools==10.5.0 more-itertools==10.5.0
# via lm-eval # via lm-eval
mpmath==1.3.0 mpmath==1.3.0
...@@ -328,10 +471,14 @@ multiprocess==0.70.16 ...@@ -328,10 +471,14 @@ multiprocess==0.70.16
# via # via
# datasets # datasets
# evaluate # evaluate
munch==4.0.0
# via pretrainedmodels
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via black # via black
networkx==3.2.1 networkx==3.2.1
# via torch # via
# scikit-image
# torch
ninja==1.11.1.3 ninja==1.11.1.3
# via mamba-ssm # via mamba-ssm
nltk==3.9.1 nltk==3.9.1
...@@ -348,6 +495,8 @@ numpy==1.26.4 ...@@ -348,6 +495,8 @@ numpy==1.26.4
# via # via
# -r requirements/test.in # -r requirements/test.in
# accelerate # accelerate
# albucore
# albumentations
# bitsandbytes # bitsandbytes
# bm25s # bm25s
# contourpy # contourpy
...@@ -358,9 +507,15 @@ numpy==1.26.4 ...@@ -358,9 +507,15 @@ numpy==1.26.4
# evaluate # evaluate
# fastparquet # fastparquet
# genai-perf # genai-perf
# geopandas
# h5py
# imageio
# librosa # librosa
# lightly
# lightly-utils
# matplotlib # matplotlib
# mistral-common # mistral-common
# mlflow
# mteb # mteb
# numba # numba
# numexpr # numexpr
...@@ -368,18 +523,30 @@ numpy==1.26.4 ...@@ -368,18 +523,30 @@ numpy==1.26.4
# pandas # pandas
# patsy # patsy
# peft # peft
# pycocotools
# pyogrio
# rasterio
# rioxarray
# rouge-score # rouge-score
# runai-model-streamer # runai-model-streamer
# sacrebleu # sacrebleu
# scikit-image
# scikit-learn # scikit-learn
# scipy # scipy
# segmentation-models-pytorch
# shapely
# soxr # soxr
# statsmodels # statsmodels
# tensorboardx
# tensorizer # tensorizer
# tifffile
# torchgeo
# torchmetrics
# torchvision # torchvision
# transformers # transformers
# tritonclient # tritonclient
# vocos # vocos
# xarray
nvidia-cublas-cu12==12.8.3.14 nvidia-cublas-cu12==12.8.3.14
# via # via
# nvidia-cudnn-cu12 # nvidia-cudnn-cu12
...@@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61 ...@@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61
# torch # torch
nvidia-nvtx-cu12==12.8.55 nvidia-nvtx-cu12==12.8.55
# via torch # via torch
omegaconf==2.3.0
# via
# hydra-core
# lightning
open-clip-torch==2.32.0 open-clip-torch==2.32.0
# via -r requirements/test.in # via -r requirements/test.in
opencensus==0.11.4 opencensus==0.11.4
...@@ -426,7 +597,18 @@ opencensus-context==0.1.3 ...@@ -426,7 +597,18 @@ opencensus-context==0.1.3
opencv-python-headless==4.11.0.86 opencv-python-headless==4.11.0.86
# via # via
# -r requirements/test.in # -r requirements/test.in
# albucore
# albumentations
# mistral-common # mistral-common
opentelemetry-api==1.35.0
# via
# mlflow-skinny
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-sdk==1.35.0
# via mlflow-skinny
opentelemetry-semantic-conventions==0.56b0
# via opentelemetry-sdk
packaging==24.2 packaging==24.2
# via # via
# accelerate # accelerate
...@@ -435,26 +617,44 @@ packaging==24.2 ...@@ -435,26 +617,44 @@ packaging==24.2
# datasets # datasets
# evaluate # evaluate
# fastparquet # fastparquet
# geopandas
# gunicorn
# huggingface-hub # huggingface-hub
# hydra-core
# kornia
# lazy-loader # lazy-loader
# lightning
# lightning-utilities
# mamba-ssm # mamba-ssm
# matplotlib # matplotlib
# mlflow-skinny
# peft # peft
# plotly # plotly
# pooch # pooch
# pyogrio
# pytest # pytest
# pytest-rerunfailures # pytest-rerunfailures
# pytorch-lightning
# ray # ray
# rioxarray
# scikit-image
# statsmodels # statsmodels
# tensorboardx
# torchmetrics
# transformers # transformers
# typepy # typepy
# xarray
pandas==2.2.3 pandas==2.2.3
# via # via
# datasets # datasets
# evaluate # evaluate
# fastparquet # fastparquet
# genai-perf # genai-perf
# geopandas
# mlflow
# statsmodels # statsmodels
# torchgeo
# xarray
pathspec==0.12.1 pathspec==0.12.1
# via black # via black
pathvalidate==3.2.1 pathvalidate==3.2.1
...@@ -468,9 +668,14 @@ peft==0.13.2 ...@@ -468,9 +668,14 @@ peft==0.13.2
pillow==10.4.0 pillow==10.4.0
# via # via
# genai-perf # genai-perf
# imageio
# lightly-utils
# matplotlib # matplotlib
# mistral-common # mistral-common
# scikit-image
# segmentation-models-pytorch
# sentence-transformers # sentence-transformers
# torchgeo
# torchvision # torchvision
platformdirs==4.3.6 platformdirs==4.3.6
# via # via
...@@ -489,6 +694,8 @@ portalocker==2.10.1 ...@@ -489,6 +694,8 @@ portalocker==2.10.1
# via sacrebleu # via sacrebleu
pqdm==0.2.0 pqdm==0.2.0
# via -r requirements/test.in # via -r requirements/test.in
pretrainedmodels==0.7.4
# via segmentation-models-pytorch
prometheus-client==0.22.0 prometheus-client==0.22.0
# via ray # via ray
propcache==0.2.0 propcache==0.2.0
...@@ -499,8 +706,10 @@ protobuf==5.28.3 ...@@ -499,8 +706,10 @@ protobuf==5.28.3
# via # via
# google-api-core # google-api-core
# googleapis-common-protos # googleapis-common-protos
# mlflow-skinny
# proto-plus # proto-plus
# ray # ray
# tensorboardx
# tensorizer # tensorizer
psutil==6.1.0 psutil==6.1.0
# via # via
...@@ -515,6 +724,7 @@ pyarrow==18.0.0 ...@@ -515,6 +724,7 @@ pyarrow==18.0.0
# via # via
# datasets # datasets
# genai-perf # genai-perf
# mlflow
pyasn1==0.6.1 pyasn1==0.6.1
# via # via
# pyasn1-modules # pyasn1-modules
...@@ -523,6 +733,8 @@ pyasn1-modules==0.4.2 ...@@ -523,6 +733,8 @@ pyasn1-modules==0.4.2
# via google-auth # via google-auth
pybind11==2.13.6 pybind11==2.13.6
# via lm-eval # via lm-eval
pycocotools==2.0.8
# via terratorch
pycountry==24.6.1 pycountry==24.6.1
# via pydantic-extra-types # via pydantic-extra-types
pycparser==2.22 pycparser==2.22
...@@ -532,8 +744,12 @@ pycryptodomex==3.22.0 ...@@ -532,8 +744,12 @@ pycryptodomex==3.22.0
pydantic==2.11.5 pydantic==2.11.5
# via # via
# -r requirements/test.in # -r requirements/test.in
# albumentations
# datamodel-code-generator # datamodel-code-generator
# fastapi
# lightly
# mistral-common # mistral-common
# mlflow-skinny
# mteb # mteb
# pydantic-extra-types # pydantic-extra-types
# ray # ray
...@@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5 ...@@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5
# via mistral-common # via mistral-common
pygments==2.18.0 pygments==2.18.0
# via rich # via rich
pyogrio==0.11.0
# via geopandas
pyparsing==3.2.0 pyparsing==3.2.0
# via matplotlib # via
# matplotlib
# rasterio
pyproj==3.7.1
# via
# geopandas
# rioxarray
# torchgeo
pyrate-limiter==3.7.0 pyrate-limiter==3.7.0
# via schemathesis # via schemathesis
pystemmer==3.0.0 pystemmer==3.0.0
# via mteb # via mteb
pytablewriter==1.2.0 pytablewriter==1.2.0
# via lm-eval # via lm-eval
pytest==8.3.3 pytest==8.3.5
# via # via
# -r requirements/test.in # -r requirements/test.in
# buildkite-test-collector # buildkite-test-collector
...@@ -564,6 +789,7 @@ pytest==8.3.3 ...@@ -564,6 +789,7 @@ pytest==8.3.3
# pytest-subtests # pytest-subtests
# pytest-timeout # pytest-timeout
# schemathesis # schemathesis
# terratorch
pytest-asyncio==0.24.0 pytest-asyncio==0.24.0
# via -r requirements/test.in # via -r requirements/test.in
pytest-forked==1.6.0 pytest-forked==1.6.0
...@@ -578,15 +804,23 @@ pytest-subtests==0.14.1 ...@@ -578,15 +804,23 @@ pytest-subtests==0.14.1
# via schemathesis # via schemathesis
pytest-timeout==2.3.1 pytest-timeout==2.3.1
# via -r requirements/test.in # via -r requirements/test.in
python-box==7.3.2
# via terratorch
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
# via # via
# arrow # arrow
# botocore # botocore
# graphene
# lightly
# matplotlib # matplotlib
# pandas # pandas
# typepy # typepy
python-rapidjson==1.20 python-rapidjson==1.20
# via tritonclient # via tritonclient
pytorch-lightning==2.5.2
# via
# lightly
# lightning
pytrec-eval-terrier==0.5.7 pytrec-eval-terrier==0.5.7
# via mteb # via mteb
pytz==2024.2 pytz==2024.2
...@@ -596,11 +830,17 @@ pytz==2024.2 ...@@ -596,11 +830,17 @@ pytz==2024.2
pyyaml==6.0.2 pyyaml==6.0.2
# via # via
# accelerate # accelerate
# albumentations
# datamodel-code-generator # datamodel-code-generator
# datasets # datasets
# genai-perf # genai-perf
# huggingface-hub # huggingface-hub
# jsonargparse
# lightning
# mlflow-skinny
# omegaconf
# peft # peft
# pytorch-lightning
# ray # ray
# responses # responses
# schemathesis # schemathesis
...@@ -609,6 +849,11 @@ pyyaml==6.0.2 ...@@ -609,6 +849,11 @@ pyyaml==6.0.2
# vocos # vocos
rapidfuzz==3.12.1 rapidfuzz==3.12.1
# via jiwer # via jiwer
rasterio==1.4.3
# via
# rioxarray
# terratorch
# torchgeo
ray==2.43.0 ray==2.43.0
# via -r requirements/test.in # via -r requirements/test.in
redis==5.2.0 redis==5.2.0
...@@ -627,12 +872,16 @@ regex==2024.9.11 ...@@ -627,12 +872,16 @@ regex==2024.9.11
requests==2.32.3 requests==2.32.3
# via # via
# buildkite-test-collector # buildkite-test-collector
# databricks-sdk
# datasets # datasets
# docker
# evaluate # evaluate
# google-api-core # google-api-core
# huggingface-hub # huggingface-hub
# lightly
# lm-eval # lm-eval
# mistral-common # mistral-common
# mlflow-skinny
# mteb # mteb
# pooch # pooch
# ray # ray
...@@ -650,8 +899,11 @@ rfc3987==1.3.8 ...@@ -650,8 +899,11 @@ rfc3987==1.3.8
rich==13.9.4 rich==13.9.4
# via # via
# genai-perf # genai-perf
# lightning
# mteb # mteb
# typer # typer
rioxarray==0.19.0
# via terratorch
rouge-score==0.1.2 rouge-score==0.1.2
# via lm-eval # via lm-eval
rpds-py==0.20.1 rpds-py==0.20.1
...@@ -660,6 +912,8 @@ rpds-py==0.20.1 ...@@ -660,6 +912,8 @@ rpds-py==0.20.1
# referencing # referencing
rsa==4.9.1 rsa==4.9.1
# via google-auth # via google-auth
rtree==1.4.0
# via torchgeo
runai-model-streamer==0.11.0 runai-model-streamer==0.11.0
# via -r requirements/test.in # via -r requirements/test.in
runai-model-streamer-s3==0.11.0 runai-model-streamer-s3==0.11.0
...@@ -677,21 +931,32 @@ safetensors==0.4.5 ...@@ -677,21 +931,32 @@ safetensors==0.4.5
# transformers # transformers
schemathesis==3.39.15 schemathesis==3.39.15
# via -r requirements/test.in # via -r requirements/test.in
scikit-image==0.25.2
# via albumentations
scikit-learn==1.5.2 scikit-learn==1.5.2
# via # via
# albumentations
# librosa # librosa
# lm-eval # lm-eval
# mlflow
# mteb # mteb
# sentence-transformers # sentence-transformers
scipy==1.13.1 scipy==1.13.1
# via # via
# albumentations
# bm25s # bm25s
# librosa # librosa
# mlflow
# mteb # mteb
# scikit-image
# scikit-learn # scikit-learn
# sentence-transformers # sentence-transformers
# statsmodels # statsmodels
# vocos # vocos
segmentation-models-pytorch==0.4.0
# via
# terratorch
# torchgeo
sentence-transformers==3.2.1 sentence-transformers==3.2.1
# via # via
# -r requirements/test.in # -r requirements/test.in
...@@ -700,21 +965,30 @@ sentencepiece==0.2.0 ...@@ -700,21 +965,30 @@ sentencepiece==0.2.0
# via mistral-common # via mistral-common
setuptools==77.0.3 setuptools==77.0.3
# via # via
# lightning-utilities
# mamba-ssm # mamba-ssm
# pytablewriter # pytablewriter
# torch # torch
# triton # triton
shapely==2.1.1
# via
# geopandas
# torchgeo
shellingham==1.5.4 shellingham==1.5.4
# via typer # via typer
six==1.16.0 six==1.16.0
# via # via
# junit-xml # junit-xml
# lightly
# opencensus # opencensus
# python-dateutil # python-dateutil
# rfc3339-validator # rfc3339-validator
# rouge-score # rouge-score
# segmentation-models-pytorch
smart-open==7.1.0 smart-open==7.1.0
# via ray # via ray
smmap==5.0.2
# via gitdb
sniffio==1.3.1 sniffio==1.3.1
# via # via
# anyio # anyio
...@@ -727,10 +1001,17 @@ soundfile==0.12.1 ...@@ -727,10 +1001,17 @@ soundfile==0.12.1
# librosa # librosa
soxr==0.5.0.post1 soxr==0.5.0.post1
# via librosa # via librosa
sqlalchemy==2.0.41
# via
# alembic
# mlflow
sqlitedict==2.1.0 sqlitedict==2.1.0
# via lm-eval # via lm-eval
sqlparse==0.5.3
# via mlflow-skinny
starlette==0.46.2 starlette==0.46.2
# via # via
# fastapi
# schemathesis # schemathesis
# starlette-testclient # starlette-testclient
starlette-testclient==0.4.1 starlette-testclient==0.4.1
...@@ -751,18 +1032,29 @@ tenacity==9.0.0 ...@@ -751,18 +1032,29 @@ tenacity==9.0.0
# via # via
# lm-eval # lm-eval
# plotly # plotly
tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1 tensorizer==2.10.1
# via -r requirements/test.in # via -r requirements/test.in
terratorch==1.1rc2
# via -r requirements/test.in
threadpoolctl==3.5.0 threadpoolctl==3.5.0
# via scikit-learn # via scikit-learn
tifffile==2025.3.30
# via
# scikit-image
# terratorch
tiktoken==0.7.0 tiktoken==0.7.0
# via # via
# lm-eval # lm-eval
# mistral-common # mistral-common
timm==1.0.11 timm==1.0.15
# via # via
# -r requirements/test.in # -r requirements/test.in
# open-clip-torch # open-clip-torch
# segmentation-models-pytorch
# terratorch
# torchgeo
tokenizers==0.21.1 tokenizers==0.21.1
# via # via
# -r requirements/test.in # -r requirements/test.in
...@@ -776,18 +1068,28 @@ torch==2.7.1+cu128 ...@@ -776,18 +1068,28 @@ torch==2.7.1+cu128
# -r requirements/test.in # -r requirements/test.in
# accelerate # accelerate
# bitsandbytes # bitsandbytes
# efficientnet-pytorch
# encodec # encodec
# fastsafetensors # fastsafetensors
# kornia
# lightly
# lightning
# lm-eval # lm-eval
# mamba-ssm # mamba-ssm
# mteb # mteb
# open-clip-torch # open-clip-torch
# peft # peft
# pretrainedmodels
# pytorch-lightning
# runai-model-streamer # runai-model-streamer
# segmentation-models-pytorch
# sentence-transformers # sentence-transformers
# tensorizer # tensorizer
# terratorch
# timm # timm
# torchaudio # torchaudio
# torchgeo
# torchmetrics
# torchvision # torchvision
# vector-quantize-pytorch # vector-quantize-pytorch
# vocos # vocos
...@@ -796,22 +1098,40 @@ torchaudio==2.7.1+cu128 ...@@ -796,22 +1098,40 @@ torchaudio==2.7.1+cu128
# -r requirements/test.in # -r requirements/test.in
# encodec # encodec
# vocos # vocos
torchgeo==0.7.0
# via terratorch
torchmetrics==1.7.4
# via
# lightning
# pytorch-lightning
# terratorch
# torchgeo
torchvision==0.22.1+cu128 torchvision==0.22.1+cu128
# via # via
# -r requirements/test.in # -r requirements/test.in
# lightly
# open-clip-torch # open-clip-torch
# pretrainedmodels
# segmentation-models-pytorch
# terratorch
# timm # timm
# torchgeo
tqdm==4.66.6 tqdm==4.66.6
# via # via
# datasets # datasets
# evaluate # evaluate
# huggingface-hub # huggingface-hub
# lightly
# lightning
# lm-eval # lm-eval
# mteb # mteb
# nltk # nltk
# open-clip-torch # open-clip-torch
# peft # peft
# pqdm # pqdm
# pretrainedmodels
# pytorch-lightning
# segmentation-models-pytorch
# sentence-transformers # sentence-transformers
# tqdm-multiprocess # tqdm-multiprocess
# transformers # transformers
...@@ -843,18 +1163,34 @@ typer==0.15.2 ...@@ -843,18 +1163,34 @@ typer==0.15.2
# via fastsafetensors # via fastsafetensors
types-python-dateutil==2.9.0.20241206 types-python-dateutil==2.9.0.20241206
# via arrow # via arrow
typeshed-client==2.8.2
# via jsonargparse
typing-extensions==4.12.2 typing-extensions==4.12.2
# via # via
# albumentations
# alembic
# fastapi
# graphene
# huggingface-hub # huggingface-hub
# librosa # librosa
# lightning
# lightning-utilities
# mistral-common # mistral-common
# mlflow-skinny
# mteb # mteb
# opentelemetry-api
# opentelemetry-sdk
# opentelemetry-semantic-conventions
# pqdm # pqdm
# pydantic # pydantic
# pydantic-core # pydantic-core
# pydantic-extra-types # pydantic-extra-types
# pytorch-lightning
# sqlalchemy
# torch # torch
# torchgeo
# typer # typer
# typeshed-client
# typing-inspection # typing-inspection
typing-inspection==0.4.1 typing-inspection==0.4.1
# via pydantic # via pydantic
...@@ -866,9 +1202,13 @@ urllib3==2.2.3 ...@@ -866,9 +1202,13 @@ urllib3==2.2.3
# via # via
# blobfile # blobfile
# botocore # botocore
# docker
# lightly
# requests # requests
# responses # responses
# tritonclient # tritonclient
uvicorn==0.35.0
# via mlflow-skinny
vector-quantize-pytorch==1.21.2 vector-quantize-pytorch==1.21.2
# via -r requirements/test.in # via -r requirements/test.in
virtualenv==20.31.2 virtualenv==20.31.2
...@@ -880,11 +1220,15 @@ wcwidth==0.2.13 ...@@ -880,11 +1220,15 @@ wcwidth==0.2.13
webcolors==24.11.1 webcolors==24.11.1
# via jsonschema # via jsonschema
werkzeug==3.1.3 werkzeug==3.1.3
# via schemathesis # via
# flask
# schemathesis
word2number==1.1 word2number==1.1
# via lm-eval # via lm-eval
wrapt==1.17.2 wrapt==1.17.2
# via smart-open # via smart-open
xarray==2025.7.1
# via rioxarray
xxhash==3.5.0 xxhash==3.5.0
# via # via
# datasets # datasets
...@@ -893,5 +1237,7 @@ yarl==1.17.1 ...@@ -893,5 +1237,7 @@ yarl==1.17.1
# via # via
# aiohttp # aiohttp
# schemathesis # schemathesis
zipp==3.23.0
# via importlib-metadata
zstandard==0.23.0 zstandard==0.23.0
# via lm-eval # via lm-eval
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.utils import set_default_torch_num_threads
from ....conftest import VllmRunner
def generate_test_mm_data():
mm_data = {
"pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
return mm_data
def _run_test(
vllm_runner: type[VllmRunner],
model: str,
) -> None:
prompt = [
{
# This model deals with no text input
"prompt_token_ids": [1],
"multi_modal_data": generate_test_mm_data(),
} for _ in range(10)
]
with (
set_default_torch_num_threads(1),
vllm_runner(
model,
task="embed",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run
max_num_seqs=32,
) as vllm_model,
):
vllm_model.encode(prompt)
MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
) -> None:
_run_test(
vllm_runner,
model,
)
...@@ -651,6 +651,8 @@ class ModelConfig: ...@@ -651,6 +651,8 @@ class ModelConfig:
self.original_max_model_len = self.max_model_len self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.multimodal_config = self._init_multimodal_config() self.multimodal_config = self._init_multimodal_config()
self.model_supports_multimodal_raw_input = (
self.registry.supports_multimodal_raw_input(self.architectures))
if not self.skip_tokenizer_init: if not self.skip_tokenizer_init:
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
...@@ -1243,10 +1245,10 @@ class ModelConfig: ...@@ -1243,10 +1245,10 @@ class ModelConfig:
return self.get_hf_config_sliding_window() return self.get_hf_config_sliding_window()
def get_vocab_size(self) -> int: def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size return getattr(self.hf_text_config, "vocab_size", 0)
def get_hidden_size(self) -> int: def get_hidden_size(self) -> int:
return self.hf_text_config.hidden_size return getattr(self.hf_text_config, "hidden_size", 0)
@property @property
def is_deepseek_mla(self) -> bool: def is_deepseek_mla(self) -> bool:
......
...@@ -238,14 +238,14 @@ class LLMEngine: ...@@ -238,14 +238,14 @@ class LLMEngine:
self.log_stats = log_stats self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs self.use_cached_outputs = use_cached_outputs
if not self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
self.tokenizer = None self.tokenizer = None
self.detokenizer = None self.detokenizer = None
tokenizer_group = None tokenizer_group = None
else:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
# Ensure that the function doesn't contain a reference to self, # Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues # to avoid engine GC issues
......
...@@ -136,6 +136,40 @@ def supports_multimodal( ...@@ -136,6 +136,40 @@ def supports_multimodal(
return getattr(model, "supports_multimodal", False) return getattr(model, "supports_multimodal", False)
@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
"""The interface required for all multi-modal models."""
supports_multimodal_raw_input: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multi-modal inputs and processes
them in their raw form and not embeddings.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
@overload
def supports_multimodal_raw_input(
model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
...
@overload
def supports_multimodal_raw_input(
model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
...
def supports_multimodal_raw_input(
model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
TypeIs[SupportsMultiModalWithRawInput]]:
return getattr(model, "supports_multimodal_raw_input", False)
@runtime_checkable @runtime_checkable
class SupportsScoreTemplate(Protocol): class SupportsScoreTemplate(Protocol):
"""The interface required for all models that support score template.""" """The interface required for all models that support score template."""
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model.""" """Inference-only IBM/NASA Prithvi Geospatial model."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Union from typing import Optional, Union
...@@ -27,13 +28,14 @@ from vllm.config import VllmConfig ...@@ -27,13 +28,14 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead, from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
PoolerIdentity, SimplePooler) PoolerIdentity, SimplePooler)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (IsAttentionFree, from vllm.model_executor.models.interfaces import (
SupportsMultiModal, IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput)
SupportsV0Only)
from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs) MultiModalFieldElem, MultiModalInputs,
MultiModalKwargs, MultiModalKwargsItem,
MultiModalSharedField, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptUpdate) BaseProcessingInfo, PromptUpdate)
...@@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder( ...@@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
# The size of pixel_values might change in the cases where we resize # The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below. # the input but never exceeds the dimensions below.
return { return {
"pixel_values": torch.full((1, 6, 512, 512), 1.0), "pixel_values": torch.full((6, 512, 512), 1.0,
"location_coords": torch.full((1, 2), 1.0), dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
} }
...@@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.shared(batch_size=1,
location_coords=MultiModalFieldConfig.batched("image"), modality="image"),
location_coords=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
for k, v in mm_data.items(): for k, v in mm_data.items():
mm_kwargs[k] = v mm_kwargs[k] = v
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
# This model receives in input a multi-dimensional tensor representing
# a single image patch and therefore it is not to be split
# into multiple elements, but rather to be considered a single one.
# Hence, the decision of using a MultiModalSharedField.
# The expected shape is (num_channels, width, height).
# This model however allows the user to also submit multiple image
# patches as a batch, adding a further dimension to the above shape.
# At this stage we only support submitting one patch per request and
# batching is achieved via vLLM batching.
# TODO (christian-pinto): enable support for multi patch requests
# in tandem with vLLM batching.
multimodal_kwargs_items = [
MultiModalKwargsItem.from_elems([
MultiModalFieldElem(
modality="image",
key=key,
data=data,
field=MultiModalSharedField(1),
) for key, data in mm_kwargs.items()
])
]
return MultiModalInputs( return MultiModalInputs(
type="multimodal", type="multimodal",
prompt=prompt, prompt=prompt,
prompt_token_ids=[1], prompt_token_ids=[1],
mm_kwargs=MultiModalKwargs(mm_kwargs), mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
mm_hashes=None, mm_hashes=None,
mm_placeholders={}, mm_placeholders=mm_placeholders,
) )
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor, PrithviGeoSpatialMAEMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo, info=PrithviGeoSpatialMAEProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder) dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, )
SupportsV0Only): class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
SupportsMultiModalWithRawInput):
"""Prithvi Masked Autoencoder""" """Prithvi Masked Autoencoder"""
is_pooling_model = True is_pooling_model = True
...@@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, ...@@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
raise ValueError("Only image modality is supported") raise ValueError("Only image modality is supported")
def _instantiate_model(self, config: dict) -> Optional[nn.Module]: def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model # We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask": if config["task_args"]["task"] == "SemanticSegmentationTask":
from terratorch.cli_tools import SemanticSegmentationTask from terratorch.cli_tools import SemanticSegmentationTask
task = SemanticSegmentationTask( task = SemanticSegmentationTask(
config["model_args"], config["model_args"],
config["task_args"]["model_factory"], config["task_args"]["model_factory"],
...@@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, ...@@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
scheduler_hparams=config["scheduler_params"], scheduler_hparams=config["scheduler_params"],
plot_on_val=config["task_args"]["plot_on_val"], plot_on_val=config["task_args"]["plot_on_val"],
freeze_decoder=config["task_args"]["freeze_decoder"], freeze_decoder=config["task_args"]["freeze_decoder"],
freeze_backbone=config["task_args"]["freeze_backbone"]) freeze_backbone=config["task_args"]["freeze_backbone"],
)
return task.model return task.model
else: else:
...@@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, ...@@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
def _parse_and_validate_multimodal_data( def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor): if not isinstance(pixel_values, torch.Tensor):
raise ValueError(f"Incorrect type of pixel_values. " raise ValueError(f"Incorrect type of pixel_values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
pixel_values = torch.unbind(pixel_values, dim=0)[0]
location_coords = kwargs.pop("location_coords", None) location_coords = kwargs.pop("location_coords", None)
if not isinstance(location_coords, torch.Tensor): if not isinstance(location_coords, torch.Tensor):
...@@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, ...@@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return pixel_values, location_coords return pixel_values, location_coords
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
# We do not really use any input tokens and therefore no embeddings
# to be calculated. However, due to the mandatory token ids in
# the input prompt we pass one token and the size of the dummy
# embedding tensors must reflect that.
return torch.empty((input_ids.shape[0], 0))
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
......
...@@ -22,8 +22,8 @@ from vllm.logger import init_logger ...@@ -22,8 +22,8 @@ from vllm.logger import init_logger
from .interfaces import (has_inner_state, has_noops, is_attention_free, from .interfaces import (has_inner_state, has_noops, is_attention_free,
is_hybrid, supports_cross_encoding, is_hybrid, supports_cross_encoding,
supports_multimodal, supports_pp, supports_multimodal, supports_multimodal_raw_input,
supports_transcription, supports_v0_only) supports_pp, supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model from .interfaces_base import is_text_generation_model
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -287,6 +287,7 @@ class _ModelInfo: ...@@ -287,6 +287,7 @@ class _ModelInfo:
is_pooling_model: bool is_pooling_model: bool
supports_cross_encoding: bool supports_cross_encoding: bool
supports_multimodal: bool supports_multimodal: bool
supports_multimodal_raw_input: bool
supports_pp: bool supports_pp: bool
has_inner_state: bool has_inner_state: bool
is_attention_free: bool is_attention_free: bool
...@@ -304,6 +305,7 @@ class _ModelInfo: ...@@ -304,6 +305,7 @@ class _ModelInfo:
is_pooling_model=True, # Can convert any model into a pooling model is_pooling_model=True, # Can convert any model into a pooling model
supports_cross_encoding=supports_cross_encoding(model), supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model), supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
supports_pp=supports_pp(model), supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model), has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model), is_attention_free=is_attention_free(model),
...@@ -573,6 +575,13 @@ class _ModelRegistry: ...@@ -573,6 +575,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal return model_cls.supports_multimodal
def supports_multimodal_raw_input(
self,
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal_raw_input
def is_pp_supported_model( def is_pp_supported_model(
self, self,
architectures: Union[str, list[str]], architectures: Union[str, list[str]],
......
...@@ -266,7 +266,7 @@ class MultiModalRegistry: ...@@ -266,7 +266,7 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model") raise ValueError(f"{model_config.model} is not a multimodal model")
if tokenizer is None: if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
if disable_cache is None: if disable_cache is None:
mm_config = model_config.get_multimodal_config() mm_config = model_config.get_multimodal_config()
......
...@@ -94,11 +94,14 @@ class AsyncLLM(EngineClient): ...@@ -94,11 +94,14 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests self.log_requests = log_requests
self.log_stats = log_stats self.log_stats = log_stats
# Tokenizer (+ ensure liveness if running in another process). if self.model_config.skip_tokenizer_init:
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = None
model_config=vllm_config.model_config, else:
scheduler_config=vllm_config.scheduler_config, # Tokenizer (+ ensure liveness if running in another process).
lora_config=vllm_config.lora_config) self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (converts Inputs --> EngineCoreRequests). # Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor( self.processor = Processor(
...@@ -525,6 +528,10 @@ class AsyncLLM(EngineClient): ...@@ -525,6 +528,10 @@ class AsyncLLM(EngineClient):
self, self,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer: ) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
return self.tokenizer.get_lora_tokenizer(lora_request) return self.tokenizer.get_lora_tokenizer(lora_request)
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:
......
...@@ -82,11 +82,14 @@ class LLMEngine: ...@@ -82,11 +82,14 @@ class LLMEngine:
self.dp_group = None self.dp_group = None
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
# Tokenizer (+ ensure liveness if running in another process). if self.model_config.skip_tokenizer_init:
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = None
model_config=vllm_config.model_config, else:
scheduler_config=vllm_config.scheduler_config, # Tokenizer (+ ensure liveness if running in another process).
lora_config=vllm_config.lora_config) self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (convert Inputs --> EngineCoreRequests) # Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config, self.processor = Processor(vllm_config=vllm_config,
......
...@@ -327,14 +327,16 @@ class OutputProcessor: ...@@ -327,14 +327,16 @@ class OutputProcessor:
if request_id in self.request_states: if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.") raise ValueError(f"Request id {request_id} already running.")
req_state = RequestState.from_new_request( tokenizer = None if not self.tokenizer else \
tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), self.tokenizer.get_lora_tokenizer(request.lora_request)
request=request,
prompt=prompt, req_state = RequestState.from_new_request(tokenizer=tokenizer,
parent_req=parent_req, request=request,
request_index=request_index, prompt=prompt,
queue=queue, parent_req=parent_req,
log_stats=self.log_stats) request_index=request_index,
queue=queue,
log_stats=self.log_stats)
self.request_states[request_id] = req_state self.request_states[request_id] = req_state
self.lora_states.add_request(req_state) self.lora_states.add_request(req_state)
if parent_req: if parent_req:
......
...@@ -380,7 +380,6 @@ class Processor: ...@@ -380,7 +380,6 @@ class Processor:
prompt_type: Literal["encoder", "decoder"], prompt_type: Literal["encoder", "decoder"],
): ):
model_config = self.model_config model_config = self.model_config
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
prompt_ids = prompt_inputs["prompt_token_ids"] prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids: if not prompt_ids:
...@@ -389,9 +388,14 @@ class Processor: ...@@ -389,9 +388,14 @@ class Processor:
else: else:
raise ValueError(f"The {prompt_type} prompt cannot be empty") raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_input_id = max(prompt_ids, default=0) if self.model_config.skip_tokenizer_init:
if max_input_id > tokenizer.max_token_id: tokenizer = None
raise ValueError(f"Token id {max_input_id} is out of vocabulary") else:
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len: if len(prompt_ids) > max_prompt_len:
......
...@@ -126,6 +126,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -126,6 +126,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.is_multimodal_model = model_config.is_multimodal_model self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None self.is_pooling_model = model_config.pooler_config is not None
self.model_supports_multimodal_raw_input = (
model_config.model_supports_multimodal_raw_input)
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs self.max_num_reqs = scheduler_config.max_num_seqs
...@@ -328,6 +330,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -328,6 +330,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Args: Args:
scheduler_output: The scheduler output. scheduler_output: The scheduler output.
""" """
# Attention free models have zero kv_cache_goups, however models
# like Mamba are also attention free but use the kv_cache for
# keeping its internal state. This is why we check the number
# of kv_cache groups instead of solely checking
# for self.model_config.is_attention_free.
if len(self.kv_cache_config.kv_cache_groups) == 0:
return
self.attn_metadata_builders[0].reorder_batch(self.input_batch, self.attn_metadata_builders[0].reorder_batch(self.input_batch,
scheduler_output) scheduler_output)
...@@ -565,6 +575,38 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -565,6 +575,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates. # Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata() self.input_batch.refresh_metadata()
def _init_model_kwargs_for_multimodal_model(
self,
scheduler_output: Optional["SchedulerOutput"] = None,
num_reqs: int = -1,
) -> dict[str, Any]:
model_kwargs: dict[str, Any] = {}
if self.model_supports_multimodal_raw_input:
# This model requires the raw multimodal data in input.
if scheduler_output:
multi_modal_kwargs_list = []
for req in scheduler_output.scheduled_new_reqs:
req_mm_inputs = req.mm_inputs
if not isinstance(req_mm_inputs, list):
req_mm_inputs = list(req_mm_inputs)
multi_modal_kwargs_list.extend(req_mm_inputs)
multi_modal_kwargs = MultiModalKwargs.batch(
multi_modal_kwargs_list)
else:
# The only case where SchedulerOutput is None is for
# a dummy run let's get some dummy data.
dummy_data = [
self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=1).multi_modal_data for i in range(num_reqs)
]
multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
model_kwargs.update(multi_modal_kwargs)
return model_kwargs
def _get_cumsum_and_arange( def _get_cumsum_and_arange(
self, self,
num_tokens: np.ndarray, num_tokens: np.ndarray,
...@@ -1359,10 +1401,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1359,10 +1401,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text. # as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens] input_ids = self.input_ids[:num_scheduled_tokens]
model_kwargs = self._init_model_kwargs_for_multimodal_model(
scheduler_output=scheduler_output)
inputs_embeds = self.model.get_input_embeddings( inputs_embeds = self.model.get_input_embeddings(
input_ids=input_ids, input_ids=input_ids,
multimodal_embeddings=mm_embeds or None, multimodal_embeddings=mm_embeds or None,
) )
# TODO(woosuk): Avoid the copy. Optimize. # TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
inputs_embeds = self.inputs_embeds[:num_input_tokens] inputs_embeds = self.inputs_embeds[:num_input_tokens]
...@@ -1374,6 +1420,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1374,6 +1420,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# then the embedding layer is not included in the CUDA graph. # then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens] input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None inputs_embeds = None
model_kwargs = {}
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens] positions = self.mrope_positions[:, :num_input_tokens]
else: else:
...@@ -1406,6 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1406,6 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_kwargs,
device=self.device,
),
) )
self.maybe_wait_for_kv_save() self.maybe_wait_for_kv_save()
...@@ -2084,11 +2135,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2084,11 +2135,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens): num_scheduled_tokens):
model = self.model model = self.model
if self.is_multimodal_model: if self.is_multimodal_model:
model_kwargs = self._init_model_kwargs_for_multimodal_model(
num_reqs=num_reqs)
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens] inputs_embeds = self.inputs_embeds[:num_tokens]
else: else:
input_ids = self.input_ids[:num_tokens] input_ids = self.input_ids[:num_tokens]
inputs_embeds = None inputs_embeds = None
model_kwargs = {}
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens] positions = self.mrope_positions[:, :num_tokens]
else: else:
...@@ -2117,7 +2172,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2117,7 +2172,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_kwargs,
device=self.device,
),
) )
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
hidden_states, _ = outputs hidden_states, _ = outputs
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment