Unverified commit 51d5e9be, authored by mgazz, committed by GitHub
Browse files

[Core][Model] Terratorch backend integration (#23513)


Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent e7fc7001
...@@ -171,6 +171,7 @@ class ModelImpl(str, enum.Enum): ...@@ -171,6 +171,7 @@ class ModelImpl(str, enum.Enum):
AUTO = "auto" AUTO = "auto"
VLLM = "vllm" VLLM = "vllm"
TRANSFORMERS = "transformers" TRANSFORMERS = "transformers"
TERRATORCH = "terratorch"
def get_attr_docs(cls: type[Any]) -> dict[str, str]: def get_attr_docs(cls: type[Any]) -> dict[str, str]:
...@@ -496,7 +497,9 @@ class ModelConfig: ...@@ -496,7 +497,9 @@ class ModelConfig:
back to the Transformers implementation if no vLLM implementation is back to the Transformers implementation if no vLLM implementation is
available.\n available.\n
- "vllm" will use the vLLM model implementation.\n - "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation.""" - "transformers" will use the Transformers model implementation.\n
- "terratorch" will use the TerraTorch model implementation.
"""
override_attention_dtype: Optional[str] = None override_attention_dtype: Optional[str] = None
"""Override dtype for attention""" """Override dtype for attention"""
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
......
...@@ -184,10 +184,11 @@ _EMBEDDING_MODELS = { ...@@ -184,10 +184,11 @@ _EMBEDDING_MODELS = {
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# Technically PrithviGeoSpatialMAE is a model that works on images, both in # Technically Terratorch models work on images, both in
# input and output. I am adding it here because it piggybacks on embedding # input and output. I am adding it here because it piggy-backs on embedding
# models for the time being. # models for the time being.
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
"Terratorch": ("terratorch", "Terratorch"),
} }
_CROSS_ENCODER_MODELS = { _CROSS_ENCODER_MODELS = {
...@@ -639,6 +640,9 @@ class _ModelRegistry: ...@@ -639,6 +640,9 @@ class _ModelRegistry:
model_info = self._try_inspect_model_cls(arch) model_info = self._try_inspect_model_cls(arch)
if model_info is not None: if model_info is not None:
return (model_info, arch) return (model_info, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
model_info = self._try_inspect_model_cls("Terratorch")
return (model_info, "Terratorch")
# Fallback to transformers impl (after resolving convert_type) # Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures) if (all(arch not in self.models for arch in architectures)
...@@ -687,6 +691,11 @@ class _ModelRegistry: ...@@ -687,6 +691,11 @@ class _ModelRegistry:
model_cls = self._try_load_model_cls(arch) model_cls = self._try_load_model_cls(arch)
if model_cls is not None: if model_cls is not None:
return (model_cls, arch) return (model_cls, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
arch = "Terratorch"
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
# Fallback to transformers impl (after resolving convert_type) # Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures) if (all(arch not in self.models for arch in architectures)
......
...@@ -15,13 +15,16 @@ ...@@ -15,13 +15,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model.""" """Wrapper around `Terratorch` models"""
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, Union from typing import Any, Callable, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
InputDefinition, InputTypeEnum)
from transformers import BatchFeature from transformers import BatchFeature
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -29,6 +32,7 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler ...@@ -29,6 +32,7 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import (ImageItem, ModalityData, from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig, MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems, MultiModalInputs, MultiModalKwargsItems,
...@@ -45,52 +49,46 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings, ...@@ -45,52 +49,46 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): def _terratorch_field_names(pretrained_cfg: dict):
# This model receives in input a multi-dimensional tensor representing input_definition = InputDefinition(**pretrained_cfg["input"])
# a single image patch and therefore it is not to be split return set(input_definition.data.keys())
# into multiple elements, but rather to be considered a single one.
# Hence, the decision of using a MultiModalSharedField.
# The expected shape is (num_channels, width, height).
# This model however allows the user to also submit multiple image
# patches as a batch, adding a further dimension to the above shape.
# At this stage we only support submitting one patch per request and
# batching is achieved via vLLM batching.
# TODO (christian-pinto): enable support for multi patch requests
# in tandem with vLLM batching.
return dict(
pixel_values=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
location_coords=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
)
class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser): def _terratorch_field_factory(
pretrained_cfg: dict
) -> Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
]:
def _parse_image_data( def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
self, input_definition = InputDefinition(**pretrained_cfg["input"])
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], fields = {}
) -> Optional[ModalityDataItems[Any, Any]]: for input_name, input in input_definition.data.items():
if isinstance(data, dict): if input.type == InputTypeEnum.tensor:
return DictEmbeddingItems( fields[input_name] = "image"
data,
modality="image",
required_fields={"pixel_values", "location_coords"},
fields_factory=_prithvi_field_config,
)
return super()._parse_image_data(data) mm_fields_config = {}
for field_name, field_modality in fields.items():
mm_fields_config[field_name] = MultiModalFieldConfig.shared(
batch_size=1, modality=field_modality)
return mm_fields_config
return _terratorch_field_config
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
class TerratorchProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
class PrithviGeoSpatialMAEInputBuilder( class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
def __init__(self, info: TerratorchProcessingInfo):
super().__init__(info)
self.dummy_data_generator = DummyDataGenerator(
self.info.get_hf_config().to_dict()["pretrained_cfg"])
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return "" return ""
...@@ -100,29 +98,57 @@ class PrithviGeoSpatialMAEInputBuilder( ...@@ -100,29 +98,57 @@ class PrithviGeoSpatialMAEInputBuilder(
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> MultiModalDataDict: ) -> MultiModalDataDict:
# This model input is fixed and is in the form of a torch Tensor. # Dummy data is generated based on the 'input' section
# The size of pixel_values might change in the cases where we resize # defined in the HF configuration file
# the input but never exceeds the dimensions below. return self.dummy_data_generator.get_dummy_mm_data()
image_data = {
"pixel_values": torch.full((6, 512, 512), 1.0,
dtype=torch.float16), class TerratorchMultiModalDataParser(MultiModalDataParser):
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
} def __init__(self, pretrained_cfg: dict, *args, **kwargs):
self._pretrained_cfg = pretrained_cfg
super().__init__(*args, **kwargs)
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return {"image": image_data} terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
return DictEmbeddingItems(
data,
modality="image",
required_fields=terratorch_fields,
fields_factory=_terratorch_field_factory(self._pretrained_cfg),
)
return super()._parse_image_data(data)
class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
def __init__(
self,
info: TerratorchProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
*,
cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return PrithviGeoSpatialMAEMultiModalDataParser() return TerratorchMultiModalDataParser(
pretrained_cfg=self.pretrained_cfg)
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return _prithvi_field_config(hf_inputs) return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
...@@ -173,13 +199,11 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -173,13 +199,11 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
@default_pooling_type("All") @default_pooling_type("All")
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor, TerratorchMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo, info=TerratorchProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder, dummy_inputs=TerratorchInputBuilder,
) )
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
"""Prithvi Masked Autoencoder"""
supports_multimodal_raw_input_only = True supports_multimodal_raw_input_only = True
is_pooling_model = True is_pooling_model = True
...@@ -190,43 +214,13 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -190,43 +214,13 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
raise ValueError("Only image modality is supported") raise ValueError("Only image modality is supported")
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
from terratorch.cli_tools import SemanticSegmentationTask
task = SemanticSegmentationTask(
config["model_args"],
config["task_args"]["model_factory"],
loss=config["task_args"]["loss"],
lr=config["task_args"]["lr"],
ignore_index=config["task_args"]["ignore_index"],
optimizer=config["task_args"]["optimizer"],
optimizer_hparams=config["optimizer_params"],
scheduler=config["task_args"]["scheduler"],
scheduler_hparams=config["scheduler_params"],
plot_on_val=config["task_args"]["plot_on_val"],
freeze_decoder=config["task_args"]["freeze_decoder"],
freeze_backbone=config["task_args"]["freeze_backbone"],
)
return task.model
else:
return None
def __init__(self, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
# the actual model is dynamically instantiated using terratorch config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]
# allowing us to perform changes to the model architecture
# at startup time (e.g., change the model decoder class.) self.inference_runner = InferenceRunner(config)
self.model = self._instantiate_model( self.model = self.inference_runner.model
vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
if self.model is None:
raise ValueError(
"Unsupported task. "
"Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE.")
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None assert pooler_config is not None
...@@ -234,23 +228,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -234,23 +228,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
self.pooler = DispatchPooler( self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, ) {"encode": Pooler.for_encode(pooler_config)}, )
def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
raise ValueError(f"Incorrect type of pixel_values. "
f"Got type: {type(pixel_values)}")
location_coords = kwargs.pop("location_coords", None)
if not isinstance(location_coords, torch.Tensor):
raise ValueError(f"Incorrect type of location_coords. "
f"Got type: {type(location_coords)}")
location_coords = torch.unbind(location_coords, dim=0)[0]
if location_coords.shape == torch.Size([0]):
location_coords = None
return pixel_values, location_coords
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -270,10 +247,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -270,10 +247,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
): ):
pixel_values, location_coords = ( model_output = self.inference_runner.forward(**kwargs)
self._parse_and_validate_multimodal_data(**kwargs))
model_output = self.model(pixel_values,
location_coords=location_coords)
return model_output.output return model_output.output
...@@ -283,9 +257,12 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -283,9 +257,12 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
model_buffers = dict(self.named_buffers()) model_buffers = dict(self.named_buffers())
loaded_buffers = [] loaded_buffers = []
for key, value in weights: for key, value in weights:
if isinstance(value, (dict, OrderedDict)):
if key == "state_dict": if key == "state_dict":
weights_to_parse = value weights_to_parse = value
for name, weight in weights_to_parse.items(): for name, weight in weights_to_parse.items():
name = f"inference_runner.{name}"
if "pos_embed" in name: if "pos_embed" in name:
continue continue
...@@ -306,6 +283,9 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -306,6 +283,9 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
params_list.append((name, weight)) params_list.append((name, weight))
break break
elif isinstance(value, torch.Tensor):
params_list.append((f"inference_runner.model.{key}", value))
# Load the remaining model parameters # Load the remaining model parameters
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(params_list) autoloaded_weights = loader.load_weights(params_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment