"magic_pdf/vscode:/vscode.git/clone" did not exist on "3bd0ecf16655ee5774c7c089ec6e181d11dd8004"
Commit 0063a668 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Magma model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class MagmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MagmaModel`]. It is used to instantiate a Magma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a configuration similar to that of Magma-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Magma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MagmaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention (GQA). If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Magma 1 supports up to 2048 tokens,
Magma 2 up to 4096, CodeMagma up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
```python
>>> from transformers import MagmaModel, MagmaConfig
>>> # Initializing a Magma magma-7b style configuration
>>> configuration = MagmaConfig()
>>> # Initializing a model from the magma-7b style configuration
>>> model = MagmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "magma"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vision_config=None,
text_config=None,
image_token_index=None,
tie_word_embeddings=False,
**kwargs,
):
self.vision_config = vision_config
self.image_token_index = image_token_index
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
if "model_type" in kwargs:
text_config = CONFIG_MAPPING[kwargs["model_type"]](**kwargs)
if text_config is not None:
# copy all variables in text_config to self
for key, value in text_config.__dict__.items():
if not key.startswith("_") and not key.startswith("__"):
setattr(self, key, value)
self.text_config = text_config
else:
self.text_config = None
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
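# Illustrative sketch (hypothetical values, not part of the original file): building a
# MagmaConfig from plain dicts. When text_config is a dict, "model_type" falls back to
# "llama" and every text-config field is mirrored onto the MagmaConfig instance itself.
def _example_build_magma_config():
    config = MagmaConfig(
        vision_config={"vision_backbone": "convnextxxlarge", "mm_projector_type": "mlp2x_gelu"},
        text_config={"model_type": "llama", "hidden_size": 4096, "num_hidden_layers": 32},
        image_token_index=32000,  # hypothetical id of the <image> placeholder token
    )
    # hidden_size from the text config is copied onto the top-level config as well
    return config.hidden_size, config.text_config.hidden_size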
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Magma."""
from typing import List, Optional, Union
import ast
import numpy as np
import torchvision
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
convert_to_rgb,
)
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ImageInput,
make_list_of_images,
valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging
from transformers import AutoImageProcessor
logger = logging.get_logger(__name__)
if is_vision_available():
from PIL import Image
import torch
import torchvision
def select_best_resolution(original_size, possible_resolutions):
"""
Selects the best resolution from a list of possible resolutions based on the original size.
Args:
original_size (tuple): The original size of the image in the format (width, height).
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
Returns:
tuple: The best fit resolution in the format (width, height).
"""
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float('inf')
for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
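# Illustrative sanity check (not part of the original module): for a 1000x800 image, a
# 1536x768 candidate keeps a 960x768 downscaled copy (more effective resolution) versus
# 768x614 inside a single 768x768 tile, so the wider candidate is selected.
def _example_select_best_resolution():
    assert select_best_resolution((1000, 800), [(768, 768), (1536, 768)]) == (1536, 768)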
def process_anyres_image(image, max_num_crops=None, base_width=768, base_height=768):
"""
Process an image with variable resolution: resize it to the best-fitting anyres resolution and split it into base-sized patches.
Args:
image (torch.Tensor): The input image of shape (C, H, W) to be processed.
max_num_crops (int): Maximum number of crops used to build the candidate resolution grid.
Returns:
tuple: A tensor of image patches of shape (num_patches, C, base_height, base_width) and the crop grid as (rows, cols).
"""
assert max_num_crops is not None
grid_pinpoints = []
for i in range(1, max_num_crops+1):
for j in range(1, max_num_crops // i + 1):
grid_pinpoints.append((i, j))
grid_pinpoints = [(int(res[0] * base_width), int(res[1] * base_height)) for res in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution((image.shape[2], image.shape[1]), possible_resolutions)
# NOTE: reverse best_resolution from (width, height) to (height, width)
best_resolution = (best_resolution[1], best_resolution[0])
best_resolution_grid = (best_resolution[0] // base_height, best_resolution[1] // base_width)
# resize image tensor to best resolution
image = torch.nn.functional.interpolate(image[None,:,:,:], size=best_resolution, mode='bilinear')
# divide image tensor into patches
patches = image.unfold(2, base_height, base_height).unfold(3, base_width, base_width)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(best_resolution_grid[0]*best_resolution_grid[1], -1, base_height, base_width)
return (patches, best_resolution_grid)
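# Illustrative sketch: a (3, 800, 1000) image with max_num_crops=4 and base size 768 is
# resized to its best-fit resolution 1536x1536 and unfolded into a 2x2 grid of 768x768
# patches, i.e. a (4, 3, 768, 768) tensor plus the grid shape (2, 2).
def _example_process_anyres_image():
    image = torch.zeros(3, 800, 1000)
    patches, grid = process_anyres_image(image, max_num_crops=4)
    assert patches.shape == (4, 3, 768, 768) and grid == (2, 2)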
def process_anyres_image_global(image, max_num_crops=None, base_width=768, base_height=768):
"""
Resize an image to the best-fitting anyres resolution without splitting it into patches (global strategy).
Args:
image (torch.Tensor): The input image of shape (C, H, W) to be processed.
max_num_crops (int): Maximum number of crops used to build the candidate resolution grid.
Returns:
torch.Tensor: The resized image with a leading batch dimension, of shape (1, C, best_height, best_width).
"""
assert max_num_crops is not None
grid_pinpoints = []
for i in range(1, max_num_crops+1):
for j in range(1, max_num_crops // i + 1):
grid_pinpoints.append((i, j))
grid_pinpoints = [(int(res[0] * base_width), int(res[1] * base_height)) for res in grid_pinpoints]
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution((image.shape[2], image.shape[1]), possible_resolutions)
# NOTE: reverse best_resolution from (width, height) to (height, width)
best_resolution = (best_resolution[1], best_resolution[0])
best_resolution_grid = (best_resolution[0] // base_height, best_resolution[1] // base_width)
# resize image tensor to best resolution
image = torch.nn.functional.interpolate(image[None,:,:,:], size=best_resolution, mode='bilinear')
return image
class preprocessor():
def __init__(self, image_preprocessor, base_resolution=(256, 256)):
self.image_preprocessor = image_preprocessor
self.crop_size = {
'height': base_resolution[0],
'width': base_resolution[1]
}
self.image_mean = image_preprocessor.transforms[-1].mean
def preprocess(self, image, return_tensors='pt'):
image = self.image_preprocessor(image).unsqueeze(0)
return {
'pixel_values': image,
}
class MagmaImageProcessor(BaseImageProcessor):
r"""
Constructs a Magma image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
for processing high resolution images as explained in the [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512)
Args:
anyres_strategy (`str`):
Strategy to cope with high-resolution images. One conventional way is multi-crop, used by many other works to accommodate CLIP-ViT models.
However, since we are using ConvNeXt, which is essentially a convnet, we can use images of arbitrary resolution. As such, we use the global strategy by default,
i.e., directly resize the image holistically to a certain resolution.
base_img_size (int, *optional*, defaults to 768):
As ConvNeXt has a 1/32 downsample rate, we use 768 as the base resolution so that the resulting feature map is 24x24.
num_crops (int, *optional*, defaults to 1):
Number of effective crops when coping with images of resolution higher than 768x768. Note that num_crops > 1 does not mean we are cropping the image.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
anyres_strategy: str = 'global',
base_img_size: int = 768,
num_crops: int = 1,
do_convert_rgb: bool = True,
image_mean: List[float] = OPENAI_CLIP_MEAN,
image_std: List[float] = OPENAI_CLIP_STD,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.base_img_size = base_img_size
self.anyres_strategy = anyres_strategy
self.num_crops = num_crops
self.do_convert_rgb = do_convert_rgb
self.image_mean = image_mean
self.image_std = image_std
def preprocess(
self,
images: Union[ImageInput, List[ImageInput]],
do_pad: bool = False,
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
num_crops: int = None,
):
"""
Args:
images (`ImageInput` or `List[ImageInput]`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
"""
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# tensor transform and normalize
img_processor = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(self.image_mean, self.image_std)
])
images = [img_processor(image) for image in images]
image_data_type = 'half' if images[0].type() == 'torch.HalfTensor' else 'float'
images = [image.float() for image in images]
# crop images to the same size
image_patches = [process_anyres_image(image, self.num_crops if num_crops is None else num_crops, base_width=self.base_img_size, base_height=self.base_img_size) for image in images]
pixel_values = torch.cat([image[0] for image in image_patches], dim=0)
# pixel_values = [image[0] for image in image_patches]
image_sizes = [image_patch[1] for image_patch in image_patches]
if image_data_type == 'half':
pixel_values = pixel_values.half()
data = {
"pixel_values": pixel_values,
"image_sizes": image_sizes,
}
return BatchFeature(data=data, tensor_type=return_tensors)
AutoImageProcessor.register("MagmaImageProcessor", MagmaImageProcessor)
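# Illustrative sketch: with the default global strategy and num_crops=1, a single PIL image
# is normalized with the CLIP mean/std and resized to one 768x768 crop, so pixel_values has
# shape (1, 3, 768, 768) and image_sizes records a 1x1 crop grid for that image.
def _example_magma_image_processor():
    processor = MagmaImageProcessor()
    image = Image.new("RGB", (1000, 800))
    batch = processor.preprocess(image, do_convert_rgb=True)
    return batch["pixel_values"].shape, batch["image_sizes"]  # torch.Size([1, 3, 768, 768]), [(1, 1)]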
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Magma."""
from typing import List, Optional, Union
import logging
# Configure root logger
logging.basicConfig(level=logging.INFO)
import numpy as np
import torchvision
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
convert_to_rgb,
)
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ImageInput,
make_list_of_images,
valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging
# logging.set_verbosity_info()
logger = logging.get_logger(__name__)
if is_vision_available():
from PIL import Image
import torchvision
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import open_clip
from open_clip.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from open_clip.pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf
from open_clip.model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
# additional helpers referenced below (load_openai_model, get_model_config, list_models, CoCa, os.path)
from open_clip.factory import get_model_config, list_models
from open_clip.openai import load_openai_model
from open_clip.coca_model import CoCa
import os
from pathlib import Path
from typing import Optional, Tuple, Type
from functools import partial
import torch.utils.checkpoint as checkpoint
from typing import Any, Dict, Optional, Tuple, Union
from dataclasses import asdict
HF_HUB_PREFIX = 'hf-hub:'
def _get_hf_config(model_id, cache_dir=None):
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
return config
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_path_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
force_preprocess_cfg: Optional[Dict[str, Any]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
**model_kwargs,
):
force_preprocess_cfg = force_preprocess_cfg or {}
preprocess_cfg = asdict(PreprocessCfg())
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
model_id = model_name[len(HF_HUB_PREFIX):]
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
config = _get_hf_config(model_id, cache_dir)
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
model_cfg = config['model_cfg']
pretrained_hf = False # override, no need to load original HF text weights
else:
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
checkpoint_path = None
model_cfg = None
if device == "auto":
device = {'': device}
else:
device = torch.device(device)
if pretrained and pretrained.lower() == 'openai':
logger.info(f'Loading pretrained {model_name} from OpenAI.')
model = load_openai_model(
model_name,
precision=precision,
device=device,
cache_dir=cache_dir,
)
else:
model_cfg = model_cfg or get_model_config(model_name)
if model_cfg is not None:
logger.info(f'Loaded {model_name} model config.')
else:
logger.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
if force_path_dropout is not None:
# override the default drop path (stochastic depth) rate for timm backbones
model_cfg["vision_cfg"]["timm_drop_path"] = force_path_dropout
if force_image_size is not None:
# override model config's image size
model_cfg["vision_cfg"]["image_size"] = force_image_size
is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
if pretrained_image:
if is_timm_model:
# pretrained weight loading for timm models set via vision_cfg
model_cfg['vision_cfg']['timm_model_pretrained'] = True
else:
assert False, 'pretrained image towers currently only supported for timm models'
# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
cast_dtype = get_cast_dtype(precision)
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
if is_hf_model:
# load pretrained weights for HF text model IFF no CLIP weights being loaded
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
# model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)
if custom_text:
if "multimodal_cfg" in model_cfg:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
if precision in ("fp16", "bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
# manual mixed precision that matches original OpenAI behaviour
if is_timm_model:
# FIXME this is a bit janky, create timm based model in low-precision and
# then cast only LayerNormFp32 instances back to float32 so they don't break.
# Why? The convert_weights_to_lp fn only works with native models.
if device != {'':'auto'}:
model.to(device=device, dtype=dtype)
else:
model.to(dtype=dtype)
else:
model.to(device=device)
convert_weights_to_lp(model, dtype=dtype)
elif precision in ("pure_fp16", "pure_bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
model.to(device=device, dtype=dtype)
# else:
# model.to(device=device)
pretrained_loaded = False
if pretrained:
checkpoint_path = ''
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
# if checkpoint_path:
# logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
# open_clip.load_checkpoint(model, checkpoint_path)
# else:
# error_str = (
# f'Pretrained weights ({pretrained}) not found for model {model_name}.'
# f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
# logger.warning(error_str)
# raise RuntimeError(error_str)
# pretrained_loaded = True
elif has_hf_hub_prefix and require_pretrained:
logger.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
print(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
open_clip.load_checkpoint(model, checkpoint_path)
pretrained_loaded = True
if require_pretrained and not pretrained_loaded:
# callers of create_model_from_pretrained always expect pretrained weights
raise RuntimeError(
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
if jit:
model = torch.jit.script(model)
# set image preprocessing configuration in model attributes for convenience
if getattr(model.visual, 'image_size', None) is not None:
# use image_size set on model creation (via config or force_image_size arg)
force_preprocess_cfg['size'] = model.visual.image_size
set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
return model
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_path_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
image_interpolation: Optional[str] = None,
image_resize_mode: Optional[str] = None, # only effective for inference
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = True,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
**model_kwargs,
):
force_preprocess_cfg = merge_preprocess_kwargs(
{}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
return create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_patch_dropout=force_patch_dropout,
force_path_dropout=force_path_dropout,
force_image_size=force_image_size,
force_preprocess_cfg=force_preprocess_cfg,
pretrained_image=pretrained_image,
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
**model_kwargs,
)
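# Illustrative sketch (requires network access to the Hugging Face Hub; it downloads the
# open_clip config and checkpoint files): construct the ConvNeXt-XXLarge CLIP visual tower
# the same way D2CLIP_HF below does, without forcing the pretrained checkpoint to be loaded.
def _example_create_convnext_clip():
    clip_model = create_model_and_transforms(
        'hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg',
        require_pretrained=False,
    )
    # .visual is the ConvNeXt trunk used as the Magma image backbone
    return clip_model.visual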
class D2CLIP_HF(nn.Module):
def __init__(self, config, **kwargs):
super().__init__()
self.model_name = config['vision_backbone']
require_pretrained = kwargs.get('require_pretrained', False)
if self.model_name == "convnextxxlarge":
clip_model = create_model_and_transforms('hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg', require_pretrained=require_pretrained)
elif self.model_name == "convnextlarge":
clip_model = create_model_and_transforms('hf-hub:laion/CLIP-convnext_large-laion2B-s34B-b82K-augreg', require_pretrained=require_pretrained)
self.clip_vision_model = clip_model.visual
model_name = self.model_name.lower()
assert 'convnext' in model_name, f"Only convnext backbone is supported for Magma model, but got {model_name}"
self.model_type = 'convnext'
if 'xxlarge' in model_name:
self.output_channels = [384, 384, 768, 1536, 3072]
elif 'large' in model_name:
self.output_channels = [192, 192, 384, 768, 1536]
elif 'base' in model_name:
self.output_channels = [128, 128, 256, 512, 1024]
self._out_feature_strides = {
"res2": 4,
"res3": 8,
"res4": 16,
"res5": 32,
}
self._out_feature_channels = {
"res2": self.output_channels[1],
"res3": self.output_channels[2],
"res4": self.output_channels[3],
"res5": self.output_channels[4],
}
def extract_features_convnext(self, x, gradient_checkpointing=True):
out = {}
x = self.clip_vision_model.trunk.stem(x)
if gradient_checkpointing:
x = checkpoint.checkpoint(self.clip_vision_model.trunk.stages, x)
else:
x = self.clip_vision_model.trunk.stages(x)
out['clip_vis_dense'] = x
return out
def forward(self, x, gradient_checkpointing=True):
"""
Args:
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
Returns:
dict[str->Tensor]: names and the corresponding features
"""
return self.extract_features_convnext(x, gradient_checkpointing=gradient_checkpointing)
@property
def size_divisibility(self):
return 32
class MagmaImageTower(D2CLIP_HF):
r"""
Constructs the Magma image tower (vision backbone). Wraps an open_clip ConvNeXt CLIP visual encoder, with support
for processing high resolution images as explained in the [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512) paper.
Args:
config (dict): Configuration dictionary containing the keys for the image processor.
"""
def __init__(
self,
config,
**kwargs
) -> None:
super().__init__(config, **kwargs)
@property
def hidden_size(self):
return self.output_channels[-1]
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
Args:
x (torch.Tensor): A tensor of shape (N, C, H, W) representing an image.
Returns:
dict[str, torch.Tensor]: A dictionary mapping 'clip_vis_dense' to the final ConvNeXt feature map of shape (N, hidden_size, H/32, W/32).
"""
return super().forward(x)
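# Illustrative sketch: the tower expects (N, C, H, W) images whose spatial sides are
# multiples of its size_divisibility (32) and returns a dict whose 'clip_vis_dense' entry
# is the final ConvNeXt stage output of shape (N, hidden_size, H/32, W/32).
def _example_image_tower_features(tower: MagmaImageTower, pixel_values: torch.Tensor) -> torch.Tensor:
    features = tower(pixel_values)
    return features['clip_vis_dense']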
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Magma model."""
import math
import re
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
import wandb
import torch.distributed as dist
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import ModelOutput
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from transformers import AutoConfig, AutoModelForCausalLM
from .configuration_magma import MagmaConfig
from .image_tower_magma import MagmaImageTower
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MagmaConfig"
@dataclass
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Magma
class MagmaCausalLMOutputWithPast(ModelOutput):
"""
Base class for Magma causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class MagmaMultiModalProjector(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
dim_vision = {'base': 640, 'large': 768, 'xxlarge': 1024}
vision_backbone = config.get('vision_backbone', 'convnextxxlarge')
vision_backbone_size = vision_backbone.replace('convnext', '')
projector_type = config.get('mm_projector_type', 'linear')
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config['mm_hidden_size'], config['hidden_size'])]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config['hidden_size'], config['hidden_size']))
self.proj = nn.Sequential(*modules)
# define a row separator (learned embedding appended after each row of visual features)
self.row_seperator = nn.Parameter(torch.zeros(1, 1, config['hidden_size']))
if config.get('mm_use_im_start_end', False):
self.img_start_seperator = nn.Parameter(torch.zeros(1, config['hidden_size']))
self.img_end_seperator = nn.Parameter(torch.zeros(1, config['hidden_size']))
def forward(self, x):
return self.proj(x)
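# Illustrative sketch (hypothetical sizes): an "mlp2x_gelu" projector maps ConvNeXt features
# of width mm_hidden_size into the language-model width hidden_size through
# Linear -> GELU -> Linear; the learned row_seperator is appended per feature row elsewhere.
def _example_multi_modal_projector():
    projector = MagmaMultiModalProjector({
        'vision_backbone': 'convnextxxlarge',
        'mm_projector_type': 'mlp2x_gelu',
        'mm_hidden_size': 3072,
        'hidden_size': 4096,
    })
    return projector(torch.zeros(1, 10, 3072)).shape  # torch.Size([1, 10, 4096])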
MAGMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MagmaConfig`] or [`MagmaVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
MAGMA_START_DOCSTRING,
)
class MagmaPreTrainedModel(PreTrainedModel):
config_class = MagmaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MagmaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
MAGMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`MagmaImageProcessor.__call__`] for details. [`MagmaProcessor`] uses
[`MagmaImageProcessor`] for processing images.
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
The sizes of the images in the batch, being (height, width) for each image.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`int`, *optional*, defaults to -2):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"""The Magma model which consists of a vision backbone and a language model.""",
MAGMA_START_DOCSTRING,
)
class MagmaForCausalLM(MagmaPreTrainedModel):
def __init__(self, config: MagmaConfig):
super().__init__(config)
self.vision_tower = MagmaImageTower(config.vision_config, require_pretrained=False)
config.vision_config['mm_hidden_size'] = config.vision_config['mm_hidden_size'] \
if 'mm_hidden_size' in config.vision_config else self.vision_tower.hidden_size
config.vision_config['hidden_size'] = config.vision_config['hidden_size'] \
if 'hidden_size' in config.vision_config else self.config.text_config.hidden_size
self.multi_modal_projector = MagmaMultiModalProjector(config.vision_config)
self.vocab_size = config.text_config.vocab_size
if hasattr(config.text_config, 'auto_map'):
del config.text_config.auto_map
try:
self.language_model = AutoModelForCausalLM.from_config(
config.text_config,
# attn_implementation=config._attn_implementation,
trust_remote_code=True
)
except:
self.language_model = AutoModelForCausalLM.from_pretrained(
config.text_config._name_or_path,
# attn_implementation=config._attn_implementation,
trust_remote_code=True
)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
try:
if dist.get_rank() == 0:
wandb.init(project=os.environ['WANDB_PROJECT'])
except:
pass
self.post_init()
# def from_pretrained(self, pretrained_model_name_or_path, *model_args, **kwargs):
# import pdb; pdb.set_trace()
# kwargs["_from_auto"] = True
# return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
@property
def padding_side(self):
return self._padding_side
@padding_side.setter
def padding_side(self, padding_side: str):
if padding_side not in ["left", "right"]:
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def load_special_module_from_ckpt(self, ckpt_path, torch_dtype=None):
from deepspeed.runtime.zero import Init
from deepspeed import zero
# Defer initialization for ZeRO-3 compatibility
# with Init(data_parallel_group=None):
# # Initialize the special module
# self.vision_tower = MagmaImageTower(self.config.vision_config, require_pretrained=False)
# Load checkpoint weights into the special module
checkpoint = torch.load(ckpt_path, map_location='cpu')
state_dict = {k.replace('visual.', ''): v for k, v in checkpoint.items() if 'visual.' in k}
# Convert checkpoint weights to match model's parameter dtype
if torch_dtype is None:
model_dtype = next(self.vision_tower.clip_vision_model.parameters()).dtype
for k, v in state_dict.items():
state_dict[k] = v.to(model_dtype)
else:
for k, v in state_dict.items():
state_dict[k] = v.to(torch_dtype)
# Temporarily gather parameters for loading (if ZeRO-3 is active)
with zero.GatheredParameters(list(self.vision_tower.parameters()), modifier_rank=0):
# Load the state dictionary
self.vision_tower.clip_vision_model.load_state_dict(state_dict, strict=False)
# After loading, ensure the module is on the correct device
for param in self.vision_tower.parameters():
param.data = param.data.to(self.device).to(torch_dtype)
# import pdb; pdb.set_trace()
# If using a DeepSpeed engine, attach the updated module
if hasattr(self, "deepspeed_engine"):
self.deepspeed_engine.module = self
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def _merge_input_ids_with_image_features(
self,
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids=None,
labels=None,
image_token_index=None,
ignore_index=-100,
):
"""
Merge input_ids with image features into final embeddings
Args:
image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
All vision vectors of all images in the batch
feature_lens (`torch.LongTensor` of shape `(num_images)`):
The length of visual embeddings of each image as stacked in `image_features`
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
Token embeddings before merging with visual embeddings
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Input_ids of tokens, possibly filled with image token
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels need to be recalculated to support training (if provided).
image_token_index (`int`, *optional*):
Token id used to indicate the special "image" token. Defaults to `config.image_token_index`.
ignore_index (`int`, *optional*):
Value that is used to pad `labels` and will be ignored when computing the loss. Defaults to -100.
Returns:
final_embedding, final_attention_mask, position_ids, final_labels
Explanation:
each image has variable length embeddings, with length specified by feature_lens
image_features is concatenation of all visual embed vectors
task: fill each <image> with the correct number of visual embeddings
Example:
X (5 patches), Y (3 patches), Z (8)
X, Y are in the same sequence (in-context learning)
if right padding
input_ids: [
a b c d e f X g h i j k Y l m
o p q r Z s t u v _ _ _ _ _ _
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
]
elif left padding
input_ids: [
a b c d e f X g h i j k Y l m
_ _ _ _ _ _ o p q r Z s t u v
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
]
Edge cases:
* If tokens are the same but image token sizes are different, then we cannot infer left or right padding
input_ids: [
a b c d X g h
i j Y k l m n
]
where X is 3 tokens while Y is 5, this means after merging
if left-padding (batched generation)
input_ids should be: [
_ _ a b c d X X X g h
i j Y Y Y Y Y k l m n
]
elif (right padding) (training)
input_ids should be: [
a b c d X X X g h _ _
i j Y Y Y Y Y k l m n
]
"""
image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
with torch.no_grad():
num_images = feature_lens.size(0)
num_image_features, embed_dim = image_features.shape
if feature_lens.sum() != num_image_features:
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
batch_size = input_ids.shape[0]
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)
left_padding = True
if batch_size > 1:
if _left_padding and not _right_padding:
left_padding = True
elif not _left_padding and _right_padding:
left_padding = False
elif not _left_padding and not _right_padding:
# both sides are all ones, so the padding side cannot be inferred; fall back to self.padding_side
left_padding = self.padding_side == "left"
else:
# invalid attention_mask
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
# Whether to turn off right padding
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == image_token_index
# special_image_token_mask: [bsz, seqlen]
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# num_special_image_tokens: [bsz]
# Reserve for padding of num_images
total_num_special_image_tokens = torch.sum(special_image_token_mask)
if total_num_special_image_tokens != num_images:
raise ValueError(
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
)
# Compute the maximum embed dimension
# max_image_feature_lens is max_feature_lens per batch
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=feature_lens.device)
embed_sequence_lengths = (
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
)
max_embed_dim = embed_sequence_lengths.max()
batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
# ! instead of special_image_token_mask * (num_image_patches - 1)
# special_image_token_mask * (num_feature_len - 1)
special_image_token_mask = special_image_token_mask.long()
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
if left_padding:
# shift right token positions so that they are ending at the same number
# the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
final_labels = None
if labels is not None:
# NOTE: this is a bug in the original code!!!
final_labels = torch.full_like(final_attention_mask.long(), ignore_index).to(torch.long)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
with torch.no_grad():
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
if left_padding:
# exclude padding on the left
val = (max_embed_dim - embed_indices) <= embed_seq_lens
else:
# exclude padding on the right
val = embed_indices < embed_seq_lens
image_to_overwrite &= val
if image_to_overwrite.sum() != num_image_features:
raise ValueError(
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. "
f"This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
return final_embedding, final_attention_mask, position_ids, final_labels
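# Illustrative walk-through of the position bookkeeping above (not part of the original
# code): each <image> token expands into feature_len slots, so a cumulative sum over the
# per-token widths minus 1 gives every token's index in the merged sequence. For
# input_ids [a, b, <image>, c] with feature_len = 4, the widths are [1, 1, 4, 1],
# cumsum - 1 = [0, 1, 5, 6]; the image token occupies slots 2..5 and the remaining text
# shifts right by 3 positions.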
@add_start_docstrings_to_model_forward(MAGMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MagmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: Union[torch.FloatTensor, List[torch.FloatTensor], List[List[torch.FloatTensor]]] = None,
image_sizes: Union[torch.LongTensor, List[torch.LongTensor], List[List[torch.LongTensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MagmaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor
>>> model = MagmaForCausalLM.from_pretrained("microsoft/magma-8b-hf")
>>> processor = AutoProcessor.from_pretrained("microsoft/magma-8b-hf")
>>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_config['vision_feature_layer']
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if inputs_embeds is None:
# 1. Extract the input embeddings
# In case image_token_index is not in the embeddings (an extra token the embedding matrix doesn't have)
for_inputs_embeds_ids = input_ids.clone()
for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
# 2. Merge text and images
if pixel_values is not None and input_ids.shape[1] != 1 and len(pixel_values) > 0:
# ! infer image_num_patches from image_sizes
if type(pixel_values) == list:
# nested list of pixel_values: each element is a list of pixel_values for one training instance;
# an instance can hold multiple images for video or interleaved settings,
# e.g., pixel_values = [[img1, img2], [img1, img2, img3]]
n_imgs_per_sample = [len(pv) for pv in pixel_values]
pixels_values_list = sum(pixel_values, [])
image_sizes_list = sum(image_sizes, [])
else:
image_num_patches = [(imsize[imsize.sum(1) > 0,0] * imsize[imsize.sum(1) > 0,1]).tolist() for imsize in image_sizes]
# image_num_patches = [(imsize[:,0]*imsize[:,1]).tolist() for imsize in image_sizes]
# figure out if pixel_values is concatenated or stacked
if pixel_values.dim() == 5:
# stacking when input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [
pix_val[:sum(num_patch)].split(num_patch, dim=0) for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
_image_sizes_list = [image_size[image_size.sum(-1) > 0].tolist() for image_size in image_sizes]
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
if self.config.vision_config['img_anyres_strategy'] == "global":
selected_image_features = []
# NOTE: both _image_sizes_list and _pixel_values_list are lists of lists, each item represents a training instance with one or multiple images
for idx, (image_size_for_instance, pixel_values_for_instance) in enumerate(zip(_image_sizes_list, _pixel_values_list)):
assert len(image_size_for_instance) == len(pixel_values_for_instance), f"{len(image_size_for_instance)} != {len(pixel_values_for_instance)}"
for image_size, pixel_values_for_image in zip(image_size_for_instance, pixel_values_for_instance):
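# reassemble the (h_crops x w_crops) grid of crops into a single large image of shape (1, C, h_crops*H, w_crops*W) so the vision tower encodes the image at its full resolution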
pixel_values_for_image = pixel_values_for_image.view(image_size[0], image_size[1], *pixel_values_for_image.shape[1:])
pixel_values_for_image = pixel_values_for_image.permute(2, 0, 3, 1, 4).flatten(3, 4).flatten(1, 2).unsqueeze(0)
image_features = self.vision_tower(pixel_values_for_image)
selected_image_feature = image_features[vision_feature_layer][0].permute(1, 2, 0)
selected_image_feature = self.multi_modal_projector(selected_image_feature)
selected_image_feature = torch.cat((selected_image_feature, self.multi_modal_projector.row_seperator.repeat(selected_image_feature.shape[0],1,1)), dim=1)
selected_image_features.append(selected_image_feature.flatten(0, 1))
elif self.config.vision_config['img_anyres_strategy'] == "crop":
# calculate number of crops for each instance in the batch given _image_sizes_list
_image_sizes_list_temp = sum(_image_sizes_list, [])
# concatenate all images in _pixel_values_list
_pixel_values_list_temp = sum(_pixel_values_list, ())
_pixel_values_list_temp = torch.cat(_pixel_values_list_temp, dim=0)
image_features = self.vision_tower(_pixel_values_list_temp)[vision_feature_layer].permute(0, 2, 3, 1)
image_features = self.multi_modal_projector(image_features)
num_crops_list = [_image_size[0]*_image_size[1] for _image_size in _image_sizes_list_temp]
image_features_split = torch.split(image_features, num_crops_list, dim=0)
selected_image_features = []
for image_feature, image_size in zip(image_features_split, _image_sizes_list_temp):
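# stitch the per-crop feature maps back into an (h_crops*H', w_crops*W') grid, then append the learnable row-separator token at the end of each row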
image_feature = image_feature.view(image_size[0], image_size[1], *image_feature.shape[1:])
image_feature = image_feature.permute(0, 2, 1, 3, 4).flatten(2, 3).flatten(0, 1)
image_feature = torch.cat((image_feature, self.multi_modal_projector.row_seperator.repeat(image_feature.shape[0],1,1)), dim=1)
selected_image_features.append(image_feature.flatten(0, 1))
# raise NotImplementedError("crop strategy is not implemented yet")
# image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features[vision_feature_layer]
# image_features = torch.split(image_features, image_num_patches, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
feature_lens = [elem.shape[0] for elem in selected_image_features]
image_features = torch.cat(selected_image_features, 0)
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
# inputs_embeds = inputs_embeds.to(image_features.dtype)
inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
)
# pixel_values is not None but is empty ---> text only cases
elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
# there are no images
pass
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended; this can happen with some cache implementations
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
outputs = self.language_model.model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
loss = None
if labels is not None and self.training:
valid_mask = labels[..., 1:] != -100
shift_logits = self.language_model.lm_head(hidden_states[:,:-1][valid_mask]).contiguous()
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
logits = shift_logits # dummy logits
shift_labels = labels[..., 1:][valid_mask].contiguous()
shift_labels = shift_labels.to(shift_logits.device)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
# localize the positions in shift_labels where the id is in between [config.tokenizer_vocab_size-256, config.tokenizer_vocab_size)
valid_indices = (shift_labels<self.config.tokenizer_vocab_size) & (shift_labels>=self.config.tokenizer_vocab_size-256)
if valid_indices.sum() > 0:
action_labels = shift_labels[valid_indices]
action_logits = shift_logits[valid_indices]
# calculate the accuracy
action_accuracy = (action_logits.argmax(-1) == action_labels).float().mean()
# log the action accuracy
else:
action_accuracy = torch.tensor(0.0).to(shift_logits.device)
# torch distributed gather the action accuracy across all devices
action_accuracy = action_accuracy.unsqueeze(0)
# gather the action accuracy across all devices
action_accuracy_gather = [torch.zeros_like(action_accuracy) for _ in range(dist.get_world_size())]
dist.all_gather(action_accuracy_gather, action_accuracy)
# concatenate the action accuracy across all devices
action_accuracy = torch.cat(action_accuracy_gather)
if dist.get_rank() == 0:
# remove zero values
if action_accuracy.mean() == 0:
wandb.log({"action_accuracy": action_accuracy.mean().item()})
else:
action_accuracy = action_accuracy[action_accuracy != 0]
wandb.log({"action_accuracy": action_accuracy.mean().item()})
else:
logits = self.language_model.lm_head(hidden_states)
logits = logits.float()
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return MagmaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
image_sizes=None,
attention_mask=None,
**kwargs,
):
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif self.config.image_token_index in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_sizes": image_sizes,
}
)
return model_inputs
def _reorder_cache(self, *args, **kwargs):
return self.language_model._reorder_cache(*args, **kwargs)
@add_start_docstrings(
"""The Magma model which consists of a vision backbone and a language model.""",
MAGMA_START_DOCSTRING,
)
class MagmaForConditionalGeneration(MagmaPreTrainedModel):
def __init__(self, config: MagmaConfig):
super().__init__(config)
self.vision_tower = MagmaImageTower(config.vision_config, require_pretrained=('magma' not in config.name_or_path))
self.multi_modal_projector = MagmaMultiModalProjector(config.vision_config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config,
# attn_implementation=config._attn_implementation,
trust_remote_code=True
)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.post_init()
@property
def padding_side(self):
return self._padding_side
@padding_side.setter
def padding_side(self, padding_side: str):
if padding_side not in ["left", "right"]:
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def _merge_input_ids_with_image_features(
self,
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids=None,
labels=None,
image_token_index=None,
ignore_index=-100,
):
"""
Merge input_ids with image features into final embeddings
Args:
image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
All vision vectors of all images in the batch
feature_lens (`torch.LongTensor` of shape `(num_images)`):
The length of visual embeddings of each image as stacked in `image_features`
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
Token embeddings before merging with visual embeddings
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Input_ids of tokens, possibly filled with image token
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
Labels need to be recalculated to support training (if provided)
image_token_index (`int`, *optional*)
Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
ignore_index (`int`, *optional*)
Value that is used to pad `labels` and will be ignored when calculating the loss. Default: -100.
Returns:
final_embedding, final_attention_mask, position_ids, final_labels
Explanation:
each image has variable length embeddings, with length specified by feature_lens
image_features is concatenation of all visual embed vectors
task: fill each <image> with the correct number of visual embeddings
Example:
X (5 patches), Y (3 patches), Z (8)
X, Y are in the same sequence (in-context learning)
if right padding
input_ids: [
a b c d e f X g h i j k Y l m
o p q r Z s t u v _ _ _ _ _ _
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
]
elif left padding
input_ids: [
a b c d e f X g h i j k Y l m
_ _ _ _ _ _ o p q r Z s t u v
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
]
Edge cases:
* If the tokens are the same but the image token sizes differ, we cannot infer left or right padding
input_ids: [
a b c d X g h
i j Y k l m n
]
where X is 3 tokens while Y is 5, this means after merging
if left-padding (batched generation)
input_ids should be: [
_ _ a b c d X X X g h
i j Y Y Y Y Y k l m n
]
elif (right padding) (training)
input_ids should be: [
a b c d X X X g h _ _
i j Y Y Y Y Y k l m n
]
"""
image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
with torch.no_grad():
num_images = feature_lens.size(0)
num_image_features, embed_dim = image_features.shape
if feature_lens.sum() != num_image_features:
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
batch_size = input_ids.shape[0]
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)
left_padding = True
if batch_size > 1:
if _left_padding and not _right_padding:
left_padding = True
elif not _left_padding and _right_padding:
left_padding = False
elif not _left_padding and not _right_padding:
# both sides are unpadded, so we cannot tell; fall back to self.padding_side
left_padding = self.padding_side == "left"
else:
# invalid attention_mask
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
# Whether to turn off right padding
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == image_token_index
# special_image_token_mask: [bsz, seqlen]
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# num_special_image_tokens: [bsz]
# Reserve for padding of num_images
total_num_special_image_tokens = torch.sum(special_image_token_mask)
if total_num_special_image_tokens != num_images:
raise ValueError(
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
)
# Compute the maximum embed dimension
# max_image_feature_lens is max_feature_lens per batch
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=feature_lens.device)
embed_sequence_lengths = (
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
)
max_embed_dim = embed_sequence_lengths.max()
batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
# ! instead of special_image_token_mask * (num_image_patches - 1)
# special_image_token_mask * (num_feature_len - 1)
special_image_token_mask = special_image_token_mask.long()
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
if left_padding:
# shift token positions to the right so that every sequence ends at the same index
# the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
final_labels = None
if labels is not None:
final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
with torch.no_grad():
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
if left_padding:
# exclude padding on the left
val = (max_embed_dim - embed_indices) <= embed_seq_lens
else:
# exclude padding on the right
val = embed_indices < embed_seq_lens
image_to_overwrite &= val
if image_to_overwrite.sum() != num_image_features:
raise ValueError(
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. "
f"This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
return final_embedding, final_attention_mask, position_ids, final_labels
@add_start_docstrings_to_model_forward(MAGMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MagmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MagmaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, MagmaForConditionalGeneration
>>> model = MagmaForConditionalGeneration.from_pretrained("microsoft/magma-8b-hf")
>>> processor = AutoProcessor.from_pretrained("microsoft/magma-8b-hf")
>>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_config['vision_feature_layer']
)
if inputs_embeds is None:
# 1. Extract the input embeddings
# In case image_token_index is not in the embeddings (extra token, but the embedding matrix doesn't have it)
for_inputs_embeds_ids = input_ids.clone()
for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
# 2. Merge text and images
if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
# ! infer image_num_patches from image_sizes
# figure out if pixel_values is concatenated or stacked
if pixel_values.dim() == 5:
image_num_patches = [(imsize[:,0]*imsize[:,1]).tolist() for imsize in image_sizes]
# stacking when input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [
pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
if self.config.vision_config['img_anyres_strategy'] == "global":
num_patches_for_images = [(imsize[0]*imsize[1]).item() for imsize in image_sizes]
pixel_values_for_images = pixel_values.split(num_patches_for_images, dim=0)
selected_image_features = []
for idx, (image_size, pixel_values_for_image) in enumerate(zip(image_sizes, pixel_values_for_images)):
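# reassemble the (h_crops x w_crops) grid of crops into a single large image of shape (1, C, h_crops*H, w_crops*W) so the vision tower encodes the image at its full resolution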
pixel_values_for_image = pixel_values_for_image.view(image_size[0], image_size[1], *pixel_values_for_image.shape[1:])
pixel_values_for_image = pixel_values_for_image.permute(2, 0, 3, 1, 4).flatten(3, 4).flatten(1, 2).unsqueeze(0)
image_features = self.vision_tower(pixel_values_for_image)
selected_image_feature = image_features[vision_feature_layer][0].permute(1, 2, 0)
selected_image_feature = self.multi_modal_projector(selected_image_feature)
selected_image_feature = torch.cat((selected_image_feature, self.multi_modal_projector.row_seperator.repeat(selected_image_feature.shape[0],1,1)), dim=1)
selected_image_features.append(selected_image_feature)
elif self.config.vision_config['img_anyres_strategy'] == "crop":
image_features = self.vision_tower(pixel_values)[vision_feature_layer].permute(0, 2, 3, 1)
image_features = self.multi_modal_projector(image_features)
num_patches_for_images = [(imsize[0]*imsize[1]).item() for imsize in image_sizes]
image_features_split = torch.split(image_features, num_patches_for_images, dim=0)
selected_image_features = []
for image_feature, image_size in zip(image_features_split, image_sizes):
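# stitch the per-crop feature maps back into an (h_crops*H', w_crops*W') grid, then append the learnable row-separator token at the end of each row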
image_feature = image_feature.view(image_size[0], image_size[1], *image_feature.shape[1:])
image_feature = image_feature.permute(0, 2, 1, 3, 4).flatten(2, 3).flatten(0, 1)
image_feature = torch.cat((image_feature, self.multi_modal_projector.row_seperator.repeat(image_feature.shape[0],1,1)), dim=1)
selected_image_features.append(image_feature)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
feature_lens = [elem.shape[0]*elem.shape[1] for elem in selected_image_features]
image_features = torch.cat([elem.flatten(0, 1) for elem in selected_image_features], 0)
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
# inputs_embeds = inputs_embeds.to(image_features.dtype)
inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
)
# pixel_values is not None but is empty ---> text only cases
elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
# there are no images
pass
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended; this can happen with some cache implementations
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = outputs[0]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:]
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return MagmaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
image_sizes=None,
attention_mask=None,
**kwargs,
):
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif self.config.image_token_index in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_sizes": image_sizes,
}
)
return model_inputs
def _reorder_cache(self, *args, **kwargs):
return self.language_model._reorder_cache(*args, **kwargs)
AutoConfig.register("magma", MagmaConfig)
AutoModelForCausalLM.register(MagmaConfig, MagmaForConditionalGeneration)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Magma.
"""
from typing import List, Optional, Union
import transformers
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy
from transformers.utils import TensorType
from .configuration_magma import MagmaConfig
class MagmaProcessor(ProcessorMixin):
r"""
Constructs a Magma processor which wraps a Magma image processor and a LLaMa tokenizer into a single processor.
[`MagmaProcessor`] offers all the functionalities of [`MagmaImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~MagmaProcessor.__call__`] and [`~MagmaProcessor.decode`] for more information.
Args:
image_processor ([`MagmaImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None):
# super().__init__(image_processor, tokenizer)
self.image_processor = image_processor
self.tokenizer = tokenizer
def __call__(
self,
texts: Union[TextInput, List[TextInput]],
images: Union[ImageInput, List[ImageInput]],
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
do_pad: Optional[bool] = False,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `texts`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `texts` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
MagmaImageProcessor's [`~MagmaImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
texts (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
do_pad (`bool`, *optional*, defaults to `False`):
Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch
and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is not None:
image_inputs = self.image_processor(images, do_pad=do_pad, return_tensors=return_tensors)
else:
image_inputs = {}
text_inputs = self.tokenizer(
texts, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
)
return BatchFeature(data={**text_inputs, **image_inputs})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
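# Example usage (hypothetical components; a Magma image processor and a Llama tokenizer must be supplied):
#   processor = MagmaProcessor(image_processor=magma_image_processor, tokenizer=llama_tokenizer)
#   batch = processor(texts="<image>\nDescribe the image.", images=[pil_image], return_tensors="pt")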
---
library_name: transformers
pipeline_tag: image-text-to-text
license: mit
---
# Model Card for Magma-8B
<!-- Provide a quick summary of what the model is/does. -->
<div align="center">
<h2>Magma: A Foundation Model for Multimodal AI Agents</h2>
[Jianwei Yang](https://jwyang.github.io/)<sup>*</sup><sup>1</sup><sup></sup>&nbsp;
[Reuben Tan](https://cs-people.bu.edu/rxtan/)<sup>1</sup><sup></sup>&nbsp;
[Qianhui Wu](https://qianhuiwu.github.io/)<sup>1</sup><sup></sup>&nbsp;
[Ruijie Zheng](https://ruijiezheng.com/)<sup>2</sup><sup></sup>&nbsp;
[Baolin Peng](https://scholar.google.com/citations?user=u1CNjgwAAAAJ&hl=en&oi=ao)<sup>1</sup><sup></sup>&nbsp;
[Yongyuan Liang](https://cheryyunl.github.io)<sup>2</sup><sup></sup>
[Yu Gu](http://yu-gu.me/)<sup>1</sup>&nbsp;
[Mu Cai](https://pages.cs.wisc.edu/~mucai/)<sup>3</sup>&nbsp;
[Seonghyeon Ye](https://seonghyeonye.github.io/)<sup>4</sup>&nbsp;
[Joel Jang](https://joeljang.github.io/)<sup>5</sup>&nbsp;
[Yuquan Deng](https://scholar.google.com/citations?user=LTC0Q6YAAAAJ&hl=en)<sup>5</sup>&nbsp;
[Lars Liden](https://sites.google.com/site/larsliden)<sup>1</sup>&nbsp;
[Jianfeng Gao](https://www.microsoft.com/en-us/research/people/jfgao/)<sup>1</sup><sup></sup>
<sup>1</sup> Microsoft Research; <sup>2</sup> University of Maryland; <sup>3</sup> University of Wisconsin-Madison
<sup>4</sup> KAIST; <sup>5</sup> University of Washington
<sup>*</sup> Project lead <sup></sup> First authors <sup></sup> Second authors <sup></sup> Leadership
\[[arXiv Paper](https://www.arxiv.org/pdf/2502.13130)\] &nbsp; \[[Project Page](https://microsoft.github.io/Magma/)\] &nbsp; \[[Hugging Face Paper](https://huggingface.co/papers/2502.13130)\] &nbsp; \[[Github Repo](https://github.com/microsoft/Magma)\] &nbsp; \[[Video](https://www.youtube.com/watch?v=SbfzvUU5yM8)\]
</div>
## Agents
### UI Navigation
<div align="center">
<div align="center" style="display: inline-block; width: 48%;">
<video autoplay muted loop controls playsinline style="margin-bottom: 2px;">
<source src="https://microsoft.github.io/Magma/static/videos/ui_weather_and_flight_mode.mp4" type="video/mp4">
</video>
<p class="is-5 has-text-centered" style="font-size: 14px;">What's weather in Seattle? & turn on flight mode</p>
</div>
<div align="center" style="display: inline-block; width: 48%;">
<video autoplay muted loop controls playsinline style="margin-bottom: 2px;">
<source src="https://microsoft.github.io/Magma/static/videos/ui_wordle.mp4" type="video/mp4">
</video>
<p class="is-5 has-text-centered" style="font-size: 14px;">Share and message this to Bob Steve. Click send button</p>
</div>
</div>
### Robot Manipulation
<div align="center">
<div align="center">
<div style="display: flex; justify-content: space-between; gap: 1%;">
<div style="width: 32%;">
<video autoplay muted loop controls playsinline height="98%" style="max-width: 450px; width: 100%; border-radius: 10px; overflow: hidden; margin-bottom: 5px;">
<source src="https://microsoft.github.io/Magma/static/videos/magma_hotdog.mp4" type="video/mp4">
</video>
</div>
<div style="width: 32%;">
<video autoplay muted loop controls playsinline height="98%" style="max-width: 450px; width: 100%; border-radius: 10px; overflow: hidden; margin-bottom: 5px;">
<source src="https://microsoft.github.io/Magma/static/videos/magma_mushroom.mp4" type="video/mp4">
</video>
</div>
<div style="width: 32%;">
<video autoplay muted loop controls playsinline height="98%" style="max-width: 450px; width: 100%; border-radius: 10px; overflow: hidden; margin-bottom: 5px;">
<source src="https://microsoft.github.io/Magma/static/videos/magma_left.mp4" type="video/mp4">
</video>
</div>
</div>
</div>
<div align="center">
<div style="display: flex; justify-content: space-between; gap: 1%;">
<div style="width: 32%;">
<p style="text-align: center;font-size: 14px;margin-top: 0;">Pick Place Hotdog Sausage</p>
</div>
<div style="width: 32%;">
<p style="text-align: center;font-size: 14px;margin-top: 0;">Put Mushroom Place Pot</p>
</div>
<div style="width: 32%;">
<p style="text-align: center;font-size: 14px;margin-top: 0;">Push Cloth Left to Right (Out-of-Dist.)</p>
</div>
</div>
</div>
</div>
### Gaming
Task: The model controls the robot to collect green blocks.
<div align="center">
<div align="center" style="display: inline-block; width: 48%;">
<video autoplay muted loop controls playsinline style="margin-bottom: 2px;">
<source src="https://microsoft.github.io/Magma/static/videos/magma_vs_llava.mp4" type="video/mp4">
</video>
<p class="is-5 has-text-centered" style="font-size: 14px;">Magma v.s. LLaVA-OneVision</p>
</div>
<div align="center" style="display: inline-block; width: 48%;">
<video autoplay muted loop controls playsinline style="margin-bottom: 2px;">
<source src="https://microsoft.github.io/Magma/static/videos/magma_vs_gpt4omini.mp4" type="video/mp4">
</video>
<p class="is-5 has-text-centered" style="font-size: 14px;">Magma v.s. GPT4o-minni</p>
</div>
</div>
## Model Details
<div align="center">
<img src="https://github.com/microsoft/Magma/blob/main/assets/images/magma_teaser.png?raw=true" width="100%">
</div>
### Model Description
<!-- Provide a longer summary of what this model is. -->
Magma is a multimodal agentic AI model that can generate text based on the input text and image. The model is designed for research purposes and aimed at knowledge-sharing and accelerating research in multimodal AI, in particular multimodal agentic AI. The main innovation of this model lies in two techniques, **Set-of-Mark** and **Trace-of-Mark**, and in leveraging a **large amount of unlabeled video data** to learn spatial-temporal grounding and planning. Please refer to our paper for more technical details.
### Highlights
* **Digital and Physical Worlds:** Magma is the first-ever foundation model for multimodal AI agents, designed to handle complex interactions across both virtual and real environments!
* **Versatile Capabilities:** As a single model, Magma not only possesses generic image and video understanding ability, but also generates goal-driven visual plans and actions, making it versatile for different agentic tasks!
* **State-of-the-art Performance:** Magma achieves state-of-the-art performance on various multimodal tasks, including UI navigation, robotics manipulation, as well as generic image and video understanding, in particular spatial understanding and reasoning!
* **Scalable Pretraining Strategy:** Magma is designed to be **learned scalably from unlabeled videos** in the wild in addition to existing agentic data, giving it strong generalization ability and making it suitable for real-world applications!
## License
The model is developed by Microsoft and is funded by Microsoft Research. The model is shared by Microsoft Research and is licensed under the MIT License.
<!-- {{ model_description | default("", true) }}
- **Developed by:** {{ developers | default("[More Information Needed]", true)}}
- **Funded by [optional]:** {{ funded_by | default("[More Information Needed]", true)}}
- **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}}
- **Model type:** {{ model_type | default("[More Information Needed]", true)}}
- **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}}
- **License:** {{ license | default("[More Information Needed]", true)}}
- **Finetuned from model [optional]:** {{ base_model | default("[More Information Needed]", true)}} -->
## How to Get Started with the Model
<!-- {{ get_started_code | default("[More Information Needed]", true)}} -->
To get started with the model, you first need to make sure that `transformers` and `torch` are installed, and then install the following dependencies:
```bash
pip install torchvision Pillow open_clip_torch
```
⚠️ Please note that you need to install our customized transformers lib:
```bash
pip install git+https://github.com/jwyang/transformers.git@dev/jwyang-v4.48.2
```
See [here](https://github.com/microsoft/Magma?tab=readme-ov-file#installation) for the reason why you need this.
Then you can run the following code:
```python
import torch
from PIL import Image
from io import BytesIO
import requests
from transformers import AutoModelForCausalLM, AutoProcessor
# Load the model and processor
dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained("microsoft/Magma-8B", trust_remote_code=True, torch_dtype=dtype)
processor = AutoProcessor.from_pretrained("microsoft/Magma-8B", trust_remote_code=True)
model.to("cuda")
# Inference
url = "https://assets-c4akfrf5b4d3f4b7.z01.azurefd.net/assets/2024/04/BMDataViz_661fb89f3845e.png"
image = Image.open(BytesIO(requests.get(url, stream=True).content))
image = image.convert("RGB")
convs = [
{"role": "system", "content": "You are agent that can see, talk and act."},
{"role": "user", "content": "<image_start><image><image_end>\nWhat is in this image?"},
]
prompt = processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
inputs = processor(images=[image], texts=prompt, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
inputs = inputs.to("cuda").to(dtype)
generation_args = {
"max_new_tokens": 128,
"temperature": 0.0,
"do_sample": False,
"use_cache": True,
"num_beams": 1,
}
with torch.inference_mode():
generate_ids = model.generate(**inputs, **generation_args)
generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
response = processor.decode(generate_ids[0], skip_special_tokens=True).strip()
print(response)
```
## Training Details
### Training Data
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
<!-- {{ training_data | default("[More Information Needed]", true)}} -->
Our training data consists of:
* Generic Image SFT Data: [LLaVA-Next](https://llava-vl.github.io/blog/2024-01-30-llava-next/), [InfographicVQA](https://www.docvqa.org/datasets/infographicvqa), [ChartQA_Augmented](https://github.com/vis-nlp/ChartQA), [FigureQA](https://www.microsoft.com/en-us/research/project/figureqa-dataset/), [TQA](https://paperswithcode.com/dataset/tqa), [ScienceQA](https://scienceqa.github.io/).
* Generic Video SFT Data: [ShareGPT4Video](https://sharegpt4video.github.io/) and [LLaVA-Video](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K).
* Instructional Video Data: [Ego4d](https://ego4d-data.org/), [Something-Something v2](https://www.qualcomm.com/developer/software/something-something-v-2-dataset), [Epic-Kitchens](https://epic-kitchens.github.io/2025) and other related instructional videos.
* Robotics Manipulation Data: [Open-X-Embodiment](https://robotics-transformer-x.github.io/).
* UI Grounding Data: [SeeClick](https://github.com/njucckevin/SeeClick).
* UI Navigation Data: [Mind2web](https://osu-nlp-group.github.io/Mind2Web/) and [AITW](https://github.com/google-research/google-research/tree/master/android_in_the_wild).
The data collection process involved sourcing information from publicly available documents, with a meticulous approach to filtering out undesirable documents and images. To safeguard privacy, we carefully filtered various image and text data sources to remove or scrub any potentially personal data from the training data.
More details can be found in our paper.
[Microsoft Privacy Notice](https://go.microsoft.com/fwlink/?LinkId=521839)
### Training Procedure
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
#### Preprocessing
<!-- {{ preprocessing | default("[More Information Needed]", true)}} -->
In addition to the text-related preprocessing, we mainly undertake the following image and video preprocessing steps:
* UI Grounding and Navigation Data: For each UI screenshot, we extract the bounding boxes for the UI elements, and apply [Set-of-Mark Prompting](https://arxiv.org/abs/2310.11441) to overlay numeric marks on the raw image. The model is trained to generate the UI grounding text based on the image and the Set-of-Mark prompts.
* Instructional Video Data: For each video clip, we apply [Co-Tracker](https://co-tracker.github.io/) to extract grid traces and then apply a filtering algorithm to remove noisy or static points. For videos with camera motion, we further apply a homography transformation to stabilize the video clips. In the end, we assign a numeric mark to each trace, which gives us a set of trace-of-marks. The model is trained to generate the trace-of-mark given the video clips and instructional text.
* Robotics Manipulation Data: For robotics data in Open-X Embodiment, we extract the 7 DoF robot gripper state and also extract the trace-of-mark from the video clips. Similar filtering and stabilization steps are applied to the video clips. The model is trained to generate the robot manipulation action as well as the trace-of-mark given the video clips and instructional text.
After all these preprocessing steps, we combine the results with existing text annotations to form our final multimodal training data; a minimal sketch of the Set-of-Mark overlay step is shown below. We refer to our paper for more technical details.
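For illustration only, here is a minimal sketch of Set-of-Mark style overlaying. It assumes bounding boxes are already available as `(x1, y1, x2, y2)` tuples and is not the exact annotation pipeline used for Magma:
```python
from PIL import Image, ImageDraw

def overlay_marks(image: Image.Image, boxes: list) -> Image.Image:
    """Draw an outline and a numeric mark for each (x1, y1, x2, y2) box."""
    marked = image.copy()
    draw = ImageDraw.Draw(marked)
    for idx, (x1, y1, x2, y2) in enumerate(boxes, start=1):
        draw.rectangle((x1, y1, x2, y2), outline="red", width=2)
        draw.text((x1 + 2, y1 + 2), str(idx), fill="red")
    return marked
```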
#### Training Hyperparameters
<!-- - **Training regime:** {{ training_regime | default("[More Information Needed]", true)}} fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
We used bf16 mixed precision for training on H100s and MI300s. We used the following hyperparameters for training:
* Batch size: 1024
* Learning rate: 1e-5
* Max sequence length: 4096
* Resolution: at most 1024x1024 for images, 512x512 for video frames.
* Pretraining Epochs: 3
## Evaluation
<!-- This section describes the evaluation protocols and provides the results. -->
We evaluate the model in a zero-shot manner on a wide range of tasks, mostly agent-related tasks.
### Testing Data, Factors & Metrics
<!-- This should link to a Dataset Card if possible. -->
<!-- {{ testing_data | default("[More Information Needed]", true)}} -->
<!-- #### Factors
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
<!-- {{ testing_factors | default("[More Information Needed]", true)}} -->
#### Zero-shot Testing Data
We evaluate the model's zero-shot performance on the following datasets:
* UI Grounding: [ScreenSpot](https://huggingface.co/datasets/rootsautomation/ScreenSpot) and [VisualWebArena](https://jykoh.com/vwa).
* Robotics Manipulation: [SimplerEnv](https://github.com/simpler-env/SimplerEnv) and WidowX real robot.
* Spatial Understanding and Reasoning: [VSR](https://github.com/cambridgeltl/visual-spatial-reasoning), [BLINK](https://zeyofu.github.io/blink/) and [SpatialEval](https://spatialeval.github.io/).
#### Finetuned Testing Data
We evaluate the model's performance after finetuning on the following datasets:
* UI Navigation: [Mind2Web](https://osu-nlp-group.github.io/Mind2Web/) and [AITW](https://github.com/google-research/google-research/tree/master/android_in_the_wild).
* Robotics Manipulation: [SimplerEnv](https://github.com/simpler-env/SimplerEnv) and WidowX real robot.
* Multimodal Image Understanding and Reasoning: [VQAv2](https://visualqa.org/), [GQA](https://cs.stanford.edu/people/dorarad/gqa/about.html), [MME](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation), [POPE](https://huggingface.co/datasets/lmms-lab/POPE), [TextVQA](https://textvqa.org/), [ChartQA](https://github.com/vis-nlp/ChartQA), [DocVQA](https://www.docvqa.org/).
* Multimodal Video Understanding and Reasoning: [Next-QA](https://github.com/doc-doc/NExT-QA), [VideoMME](https://video-mme.github.io/home_page.html), [MVBench](https://huggingface.co/datasets/OpenGVLab/MVBench).
#### Metrics
<!-- {{ testing_metrics | default("[More Information Needed]", true)}} -->
We follow the individual dataset's evaluation metrics for the evaluation. Please refer to the original dataset for more details.
### Results on Agentic Intelligence
Zero-shot evaluation on agentic intelligence. We report the results for pretrained Magma without any domain-specific finetuning. Magma is the only model that can cover the full task spectrum.
| Model | VQAv2 | TextVQA | POPE | SS-Mobile | SS-Desktop | SS-Web | VWB-Ele-G | VWB-Act-G | SE-Google Robot | SE-Bridge |
|-----------------------|------|--------|------|----------|-----------|------|----------|----------|---------------|-----------|
| GPT-4V | 77.2 | 78.0 | n/a | 23.6 | 16.0 | 9.0 | 67.5 | 75.7 | - | - |
| GPT-4V-OmniParser | n/a | n/a | n/a | 71.1 | 45.6 | 58.5 | - | - | - | - |
| LLaVA-1.5 | 78.5 | 58.2 | 85.9 | - | - | - | 12.1 | 13.6 | - | - |
| LLaVA-Next | 81.3 | 64.9 | 86.5 | - | - | - | 15.0 | 8.7 | - | - |
| Qwen-VL | 78.8 | 63.8 | n/a | 6.2 | 6.3 | 3.0 | 14.0 | 0.7 | - | - |
| Qwen-VL-Chat | 78.2 | 61.5 | n/a | - | - | - | - | - | - | - |
| Fuyu | 74.2 | n/a | n/a | 21.2 | 20.8 | 19.2 | 19.4 | 15.5 | - | - |
| SeeClick | - | - | - | 65.0 | 51.1 | 44.1 | 9.9 | 1.9 | - | - |
| Octo | - | - | - | - | - | - | - | - | - | - |
| RT-1-X | - | - | - | - | - | - | - | - | 6.0 | 15.9 |
| OpenVLA | - | - | - | - | - | - | - | - | 34.2 | 1.1 |
| Magma-8B | 80.0 | 66.5 | 87.4 | 59.5 | 64.1 | 60.6 | 96.3 | 71.8 | 52.3 | 35.4 |
*Notes: SS - ScreenSpot, VWB - VisualWebArena, SE - SimplerEnv*
## Technical Specifications
### Model Architecture and Objective
* Language Model: We use [Meta Llama-3](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) as the backbone LLM.
* Vision Encoder: We use [CLIP-ConvNeXt-XXLarge](https://huggingface.co/laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg), trained by the LAION team, as the vision encoder to tokenize images and videos.
The overall pipeline follows common practice for multimodal LLMs: the vision encoder tokenizes images and videos into visual tokens, and these visual tokens are fed into the LLM together with the textual tokens to generate text outputs.
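To make this pipeline concrete, below is a minimal inference sketch that mirrors the `AutoModelForCausalLM`/`AutoProcessor` usage in the evaluation code included later in this repository. The image URL, the question, and the generation settings are placeholders for illustration, not recommended values.
```python
# Minimal inference sketch (not the official example): assumes the public
# "microsoft/Magma-8B" checkpoint with its remote-code model/processor classes.
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/Magma-8B"
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Placeholder image URL; any RGB image works here.
image = Image.open(requests.get("https://example.com/demo.jpg", stream=True).raw).convert("RGB")

convs = [
    {"role": "system", "content": "You are agent that can see, talk and act."},
    {"role": "user", "content": "<image>\nWhat is shown in this image?"},
]
prompt = processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
if model.config.mm_use_image_start_end:
    prompt = prompt.replace("<image>", "<image_start><image><image_end>")

# The vision encoder tokenizes the image; the LLM consumes visual + text tokens.
inputs = processor(images=image, texts=prompt, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
inputs = inputs.to("cuda").to(torch.bfloat16)

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=128, do_sample=False)
output = output[:, inputs["input_ids"].shape[-1]:]
print(processor.decode(output[0], skip_special_tokens=True).strip())
```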
### Compute Infrastructure
We used [Azure ML](https://azure.microsoft.com/en-us/products/machine-learning) for our model training.
#### Hardware
Our model is trained on two types of GPUs:
* Nvidia H100
* AMD MI300
#### Software
Our model is built based on:
* [PyTorch](https://pytorch.org/)
* [Transformers](https://huggingface.co/transformers/)
* [TorchVision](https://pytorch.org/vision/stable/index.html)
* [DeepSpeed](https://www.deepspeed.ai/)
* [FlashAttention](https://github.com/HazyResearch/flash-attention)
## Intended Uses
This model is intended for broad research use in English. It is designed only for research purposes and aimed at knowledge-sharing and accelerating research in multimodal AI, particularly in multimodal agentic AI. It is intended to be used by domain experts who are independently capable of evaluating the quality of outputs before acting on them.
### Direct Use
The model takes images and text as inputs and produces textual outputs for the following uses:
* **Image/Video-Conditioned Text Generation:** The model can generate text (e.g., descriptions, answers) based on the input text and image.
* **Visual Planning Capabilities:** The model can also produce a visual trace as a plan for accomplishing a task (e.g., moving an object from one place to another).
* **Agentic Capabilities:** The model can also generate UI grounding actions (e.g., clicking a "search" button) and robotics manipulation actions (e.g., 7-DoF commands for a robot gripper); a decoding sketch for these action tokens is shown below.
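For the robotics case, the sketch below shows, under stated assumptions, how the generated action tokens can be decoded into continuous 7-DoF values. It mirrors the SimplerEnv inference code included later in this repository (last seven token ids, 256 uniform bins over [-1, 1], un-normalization with per-dataset percentile statistics); the vocabulary size, token ids, and statistics used here are made-up placeholders, not real model outputs.
```python
import numpy as np

# Illustrative decoding of 7-DoF action tokens, following the bin-based scheme in
# this repository's SimplerEnv inference code. All numbers below are placeholders.
vocab_size = 128256                      # placeholder; use processor.tokenizer.vocab_size in practice
n_action_bins = 256
bins = np.linspace(-1, 1, n_action_bins)
bin_centers = (bins[:-1] + bins[1:]) / 2.0

# Pretend these are the last seven token ids returned by model.generate().
action_token_ids = np.array([128200, 128190, 128128, 128100, 128060, 128020, 128001])

# Action tokens count down from the end of the vocabulary; recover bin indices.
discretized = vocab_size - action_token_ids
discretized = np.clip(discretized - 1, a_min=0, a_max=bin_centers.shape[0] - 1)
normalized = bin_centers[discretized]    # values in [-1, 1]

# Un-normalize with per-dataset 1st/99th percentile statistics (placeholder values).
q01 = np.array([-0.03, -0.04, -0.03, -0.08, -0.09, -0.21, 0.0])
q99 = np.array([0.03, 0.04, 0.04, 0.08, 0.08, 0.20, 1.0])
actions = 0.5 * (normalized + 1) * (q99 - q01) + q01

print("xyz delta :", actions[:3])
print("rpy delta :", actions[3:6])
print("gripper   :", actions[6])
```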
### Downstream Use
The model can be further finetuned for different downstream tasks, such as:
* **Image Captioning and QA:** We can further finetune this model for image captioning and QA tasks under the multimodal LLM pipeline. Based on our experiments, the finetuned model achieves competitive performance on these tasks while showing better spatial understanding and reasoning.
* **Video Captioning and QA:** We can further finetune this model for video captioning and QA tasks under the multimodal LLM pipeline. Based on our experiments, the finetuned model achieves competitive performance on these tasks while showing better temporal understanding and reasoning.
* **UI Navigation:** We can finetune this model for specific UI navigation tasks, such as web navigation or mobile navigation. The model can achieve superior performance on these tasks.
* **Robotics Manipulation:** Our model can be further finetuned for robotics tasks given its general agentic capabilities as a vision-language-action model. After finetuning, our model significantly outperforms the state-of-the-art models such as OpenVLA on robotics manipulation tasks.
## Bias, Risks, and Limitations
Please note that this model is not specifically designed or evaluated for all downstream purposes.
The model is not intended to be deployed in production settings. It should not be used in high-risk scenarios, such as military and defense, financial services, and critical infrastructure systems.
Developers should consider common limitations of multimodal models as they select use cases, and evaluate and mitigate for accuracy, safety, and fairness before using within a specific downstream use case.
Developers should be aware of and adhere to applicable laws or regulations (including privacy, trade compliance laws, etc.) that are relevant to their use case. Like other multimodal models, Magma can potentially behave in ways that are unfair, unreliable, or offensive.
The models' outputs do not reflect the opinions of Microsoft.
Some of the limiting behaviors to be aware of include:
* **Quality of Service:** The model is trained primarily on English text. Languages other than English will experience worse performance. English language varieties with less representation in the training data might experience worse performance than standard American English. Magma is not intended to support multilingual use.
* **Representation of Harms & Perpetuation of Stereotypes:** These models can over- or under-represent groups of people, erase representation of some groups, or reinforce demeaning or negative stereotypes. Despite safety post-training, these limitations may still be present due to differing levels of representation of different groups or prevalence of examples of negative stereotypes in training data that reflect real-world patterns and societal biases.
* **Inappropriate or Offensive Content:** These models may produce other types of inappropriate or offensive content, which may make it inappropriate to deploy for sensitive contexts without additional mitigations that are specific to the use case.
* **Information Reliability:** Multimodal models can generate nonsensical content or fabricate content that might sound reasonable but is inaccurate or outdated.
Developers should apply responsible AI best practices and are responsible for ensuring that a specific use case complies with relevant laws and regulations (e.g. privacy, trade, etc.). Using safety services like [Azure AI Content Safety](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety) that have advanced guardrails is highly recommended.
### Recommendations
Magma was developed for research purposes only. Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model.
The recommended usage for the finetuned models is within the research settings they were trained on, namely:
- An Android simulator running on a computer, for UI manipulation.
- An enclosure equipped with a robotic arm and everyday objects, for robotic manipulation.
For the UI navigation task, researchers should make sure a human is in the loop and in control of every action the agentic system generates. Since the model cannot act by itself, the sub-module a researcher uses to actually perform the UI navigation action should ensure that no unintended consequences occur as a result of performing the UI action proposed by the model.
For the robotic manipulation task, some mitigation strategies to use for human safety when operating robotic arms include:
* **Safety Zones and Barriers:** Establish physical barriers or safety zones around robotic workspaces to prevent unauthorized access.
* **Emergency Stop Systems:** Equip robotic arms with easily accessible emergency stop buttons. Implement a fail-safe mechanism that triggers an immediate stop of operations in case of an emergency.
* **Safety Standards and Compliance:** Adhere to established safety standards (e.g., ISO 10218, ISO/TS 15066) for industrial robots and collaborative robots.
* **User Training and Awareness:** Provide comprehensive training for all personnel working around robotic arms to understand their functions, safety features, and emergency procedures. Promote awareness of the potential risks associated with robotic manipulation.
## Citation
```bibtex
@misc{yang2025magmafoundationmodelmultimodal,
title={Magma: A Foundation Model for Multimodal AI Agents},
author={Jianwei Yang and Reuben Tan and Qianhui Wu and Ruijie Zheng and Baolin Peng and Yongyuan Liang and Yu Gu and Mu Cai and Seonghyeon Ye and Joel Jang and Yuquan Deng and Lars Liden and Jianfeng Gao},
year={2025},
eprint={2502.13130},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2502.13130},
}
```
# Model code
modelCode=1521
# Model name
modelName=Magma_pytorch
# Model description
modelDescription=A new era of embodied intelligence! Magma, the strongest VLA foundation model: an all-rounder for UI navigation and robot manipulation.
# Application scenarios
appScenario=Inference,Embodied Intelligence,Manufacturing,Home,Healthcare,Energy,Education
# Framework type
frameType=pytorch
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "magma"
version = "0.0.1"
description = "A Foundation Model for Multimodal AI Agents."
readme = "README.md"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
# "torch==2.3.1",
# "torchvision==0.18.1",
"pytorch-lightning>=1.0.8",
"transformers>=4.49.0",
"tokenizers>=0.15.0",
"sentencepiece",
"shortuuid",
"accelerate==0.34.2",
"peft==0.4.0",
# "bitsandbytes==0.44.1",
"pydantic>=2.0",
"markdown2[all]",
"numpy",
"scikit-learn==1.5.0",
"gradio==4.44.1",
"gradio_client",
"spaces",
"requests",
"httpx",
"uvicorn",
"fastapi",
"einops==0.6.1",
"einops-exts==0.0.4",
"timm==0.9.12",
# "tensorflow==2.15.0",
"tensorflow_datasets==4.9.3",
"tensorflow_graphics==2021.12.3",
"draccus",
"av", # "pyav"
"numba",
# "dlimp @ git+https://github.com/moojink/dlimp_openvla",
"loguru",
"sacrebleu",
"evaluate",
"sqlitedict",
"open_clip_torch",
"flash-attn",
]
[project.optional-dependencies]
train = [
"deepspeed",
"ninja",
"wandb"
]
eval = [
"azure-ai-ml",
"datasets",
"fire",
"openai==1.8.0",
"opencv-python",
"openpyxl==3.1.2",
"pillow==9.4.0",
"python-Levenshtein",
"rich",
"streamlit==1.29.0",
"typer[all]",
"word2number",
]
agent = [
"pygame",
"easyocr",
"paddleocr",
"common==0.1.2",
"dual==0.0.10",
"tight==0.1.0",
"prox==0.0.17",
"paddle==1.0.2",
# "paddlepaddle==2.6.2",
"supervision==0.18.0",
"ultralytics==8.3.78",
]
[tool.setuptools.packages.find]
exclude = [
"assets*",
"benchmark*",
"docs",
"dist*",
"playground*",
"scripts*",
"tests*",
"azureblobs*",
"azure*"
]
[tool.wheel]
exclude = [
"assets*",
"benchmark*",
"docs",
"dist*",
"playground*",
"scripts*",
"tests*",
"azureblobs*",
"azure*"
]
[tool.black]
line-length = 120
skip-string-normalization = true
[tool.pyright]
exclude = [
"**/__pycache__",
"playground",
"_results",
"_data",
"models",
"checkpoints",
"wandb",
"docs",
]
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "magma"
version = "0.0.1"
description = "A Foundation Model for Multimodal AI Agents."
readme = "README.md"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"torch==2.3.1",
"torchvision==0.18.1",
"pytorch-lightning>=1.0.8",
"transformers>=4.49.0",
"tokenizers>=0.15.0",
"sentencepiece==0.1.99",
"shortuuid",
"accelerate==0.34.2",
"peft==0.4.0",
"bitsandbytes==0.44.1",
"pydantic>=2.0",
"markdown2[all]",
"numpy",
"scikit-learn==1.5.0",
"gradio==4.44.1",
"gradio_client",
"spaces",
"requests",
"httpx",
"uvicorn",
"fastapi",
"einops==0.6.1",
"einops-exts==0.0.4",
"timm==0.9.12",
"tensorflow==2.15.0",
"tensorflow_datasets==4.9.3",
"tensorflow_graphics==2021.12.3",
"draccus",
"pyav",
"numba",
"dlimp @ git+https://github.com/moojink/dlimp_openvla",
"loguru",
"sacrebleu",
"evaluate",
"sqlitedict",
"open_clip_torch",
"flash-attn",
]
[project.optional-dependencies]
train = [
"deepspeed",
"ninja",
"wandb"
]
eval = [
"azure-ai-ml",
"datasets",
"fire",
"openai==1.8.0",
"opencv-python",
"openpyxl==3.1.2",
"pillow==9.4.0",
"python-Levenshtein",
"rich",
"streamlit==1.29.0",
"typer[all]",
"word2number",
]
agent = [
"pygame",
"easyocr",
"paddleocr",
"common==0.1.2",
"dual==0.0.10",
"tight==0.1.0",
"prox==0.0.17",
"paddle==1.0.2",
"paddlepaddle==2.6.2",
"supervision==0.18.0",
"ultralytics==8.3.78",
]
[tool.setuptools.packages.find]
exclude = [
"assets*",
"benchmark*",
"docs",
"dist*",
"playground*",
"scripts*",
"tests*",
"azureblobs*",
"azure*"
]
[tool.wheel]
exclude = [
"assets*",
"benchmark*",
"docs",
"dist*",
"playground*",
"scripts*",
"tests*",
"azureblobs*",
"azure*"
]
[tool.black]
line-length = 120
skip-string-normalization = true
[tool.pyright]
exclude = [
"**/__pycache__",
"playground",
"_results",
"_data",
"models",
"checkpoints",
"wandb",
"docs",
]
#!/bin/bash
# Evaluate Magma with lmms-eval. Arg 1: task list (default: textvqa); arg 2: number of processes (default: 4).
eval_tasks=${1:-textvqa}
NUM_PROCESSES=${2:-4}
# Export so the allocator setting is visible to the launched Python processes.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python3 -m accelerate.commands.launch --num_processes=$NUM_PROCESSES -m lmms_eval --model magma --model_args pretrained="microsoft/Magma-8B" \
--tasks $eval_tasks --batch_size 1 --log_samples --log_samples_suffix magma_textvqa --output_path ./logs/
#!/bin/bash
# SimplerEnv (real-to-sim) evaluation of Magma on WidowX bridge tasks with RGB overlays.
gpu_id=0
policy_model=magma
ckpt_path="microsoft/Magma-8B"
scene_name=bridge_table_1_v1
robot=widowx
rgb_overlay_path=ManiSkill2_real2sim/data/real_inpainting/bridge_real_eval_1.png
robot_init_x=0.147
robot_init_y=0.028
CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference_magma.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \
--robot ${robot} --policy-setup widowx_bridge \
--control-freq 5 --sim-freq 500 --max-episode-steps 60 \
--env-name PutCarrotOnPlateInScene-v0 --scene-name ${scene_name} \
--rgb-overlay-path ${rgb_overlay_path} \
--robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \
--robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 2>&1 | tee -a ./logs/Magma_PutCarrotOnPlateInScene.log;
CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference_magma.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \
--robot ${robot} --policy-setup widowx_bridge \
--control-freq 5 --sim-freq 500 --max-episode-steps 60 \
--env-name StackGreenCubeOnYellowCubeBakedTexInScene-v0 --scene-name ${scene_name} \
--rgb-overlay-path ${rgb_overlay_path} \
--robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \
--robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 2>&1 | tee -a ./logs/Magma_StackGreenCubeOnYellowCubeBakedTexInScene.log;
CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference_magma.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \
--robot ${robot} --policy-setup widowx_bridge \
--control-freq 5 --sim-freq 500 --max-episode-steps 60 \
--env-name PutSpoonOnTableClothInScene-v0 --scene-name ${scene_name} \
--rgb-overlay-path ${rgb_overlay_path} \
--robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \
--robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 2>&1 | tee -a ./logs/Magma_PutSpoonOnTableClothInScene.log;
scene_name=bridge_table_1_v2
robot=widowx_sink_camera_setup
rgb_overlay_path=ManiSkill2_real2sim/data/real_inpainting/bridge_sink.png
robot_init_x=0.127
robot_init_y=0.06
CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference_magma.py --policy-model ${policy_model} --ckpt-path ${ckpt_path} \
--robot ${robot} --policy-setup widowx_bridge \
--control-freq 5 --sim-freq 500 --max-episode-steps 120 \
--env-name PutEggplantInBasketScene-v0 --scene-name ${scene_name} \
--rgb-overlay-path ${rgb_overlay_path} \
--robot-init-x ${robot_init_x} ${robot_init_x} 1 --robot-init-y ${robot_init_y} ${robot_init_y} 1 --obj-variation-mode episode --obj-episode-range 0 24 \
--robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 0 0 1 2>&1 | tee -a ./logs/Magma_PutEggplantInBasketScene.log;
#!/bin/bash
gpu_id=0
declare -a arr=(
"microsoft/Magma-8B"
)
env_name=MoveNearGoogleBakedTexInScene-v0
# env_name=MoveNearGoogleBakedTexInScene-v1
scene_name=google_pick_coke_can_1_v4
rgb_overlay_path=./ManiSkill2_real2sim/data/real_inpainting/google_move_near_real_eval_1.png
# URDF variations
declare -a urdf_version_arr=("recolor_cabinet_visual_matching_1" "recolor_tabletop_visual_matching_1" "recolor_tabletop_visual_matching_2" None)
# Create a logs directory if it doesn't exist
mkdir -p logs
for ckpt_path in "${arr[@]}"; do
echo "Checkpoint path: $ckpt_path"
done
for urdf_version in "${urdf_version_arr[@]}"; do
for ckpt_path in "${arr[@]}"; do
# Create a unique log file name
timestamp=$(date +"%Y%m%d_%H%M%S")
log_file="logs/Magma_${env_name}_${urdf_version}_${timestamp}.txt"
echo "Starting experiment with URDF version: $urdf_version" | tee -a "$log_file"
echo "Checkpoint path: $ckpt_path" | tee -a "$log_file"
echo "Environment: $env_name" | tee -a "$log_file"
echo "Scene: $scene_name" | tee -a "$log_file"
echo "GPU ID: $gpu_id" | tee -a "$log_file"
echo "-------------------------------------------" | tee -a "$log_file"
CUDA_VISIBLE_DEVICES=${gpu_id} python simpler_env/main_inference_magma.py --policy-model magma \
--ckpt-path ${ckpt_path} \
--robot google_robot_static \
--control-freq 3 --sim-freq 513 --max-episode-steps 80 \
--env-name ${env_name} --scene-name ${scene_name} \
--rgb-overlay-path ${rgb_overlay_path} \
--robot-init-x 0.35 0.35 1 --robot-init-y 0.21 0.21 1 --obj-variation-mode episode --obj-episode-range 0 60 \
--robot-init-rot-quat-center 0 0 0 1 --robot-init-rot-rpy-range 0 0 1 0 0 1 -0.09 -0.09 1 \
--additional-env-build-kwargs urdf_version=${urdf_version} \
--additional-env-save-tags baked_except_bpb_orange 2>&1 | tee -a "$log_file"
echo "-------------------------------------------" | tee -a "$log_file"
echo "Experiment completed" | tee -a "$log_file"
echo "" | tee -a "$log_file"
done
done
echo "All experiments completed. Log files are in the 'logs' directory."
#!/bin/bash
# Default MODEL_PATH; can be overridden from the environment.
MODEL_PATH=${MODEL_PATH:-"microsoft/Magma-8B"}
# Default OUTPUT_DIR; can be overridden from the environment.
OUTPUT_DIR=${OUTPUT_DIR:-"./checkpoints/finetune-magma_820k"}
torchrun --nproc_per_node=4 train.py \
--deepspeed ./trainer/deepspeed/zero3.json \
--model_name_or_path $MODEL_PATH \
--version magma_instruct \
--data_path "data_configs/magma_820k.yaml" \
--vision_tower convnext_xxlarge \
--img_size 512 \
--max_num_crops 4 \
--img_anyres_strategy crop \
--vision_backbone "convnextxxlarge" \
--is_multimodal True \
--mm_projector_type mlp2x_gelu \
--tune_mm_mlp_adapter True \
--tune_vision_tokenizer 'none' \
--mm_vision_select_layer -2 \
--mm_use_image_start_end False \
--group_by_modality_length True \
--bf16 True \
--output_dir $OUTPUT_DIR \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 10 \
--tf32 True \
--model_max_length 4096 \
--gradient_checkpointing True \
--dataloader_num_workers 8 \
--lazy_preprocess True \
--flash_attn_2_enabled True \
--local_run False \
--show_trace False \
--run_name finetune_anyres \
--remove_static_trace_pts True
#!/bin/bash
# Default MODEL_PATH; can be overridden from the environment.
MODEL_PATH=${MODEL_PATH:-"meta-llama/Meta-Llama-3-8B-Instruct"}
# Default OUTPUT_DIR; can be overridden from the environment.
OUTPUT_DIR=${OUTPUT_DIR:-"./checkpoints/pretrain-openx"}
torchrun --nproc_per_node=4 train.py \
--deepspeed ./trainer/deepspeed/zero3.json \
--model_name_or_path $MODEL_PATH \
--version magma_instruct \
--data_path "data_configs/openx.yaml" \
--vision_tower convnext_xxlarge \
--img_size 512 \
--max_num_crops 4 \
--img_anyres_strategy crop \
--vision_backbone "convnextxxlarge" \
--is_multimodal True \
--mm_projector_type mlp2x_gelu \
--tune_mm_mlp_adapter True \
--tune_vision_tokenizer 'none' \
--mm_vision_select_layer -2 \
--mm_use_image_start_end False \
--group_by_modality_length False \
--bf16 True \
--output_dir $OUTPUT_DIR \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 10 \
--tf32 True \
--model_max_length 4096 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--flash_attn_2_enabled True \
--local_run False \
--show_trace False \
--run_name finetune_anyres \
--remove_static_trace_pts True
## LMMs-Eval for MAGMA
To facilitate quantitative evaluation of our model, we also provide a model class for [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval).
After installing lmms-eval, copy `magma.py` into the `lmms-eval/lmms_eval/models` folder.
Then register the model by adding it to `AVAILABLE_MODELS` in `lmms-eval/lmms_eval/models/__init__.py` as follows:
```python
AVAILABLE_MODELS = {
# many previous registered models
"magma": Magma,
}
```
import importlib
import os
import sys
import hf_transfer
from loguru import logger
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
logger.remove()
logger.add(sys.stdout, level="WARNING")
AVAILABLE_MODELS = {
"batch_gpt4": "BatchGPT4",
"claude": "Claude",
"cogvlm2": "CogVLM2",
"from_log": "FromLog",
"fuyu": "Fuyu",
"gemini_api": "GeminiAPI",
"gpt4v": "GPT4V",
"idefics2": "Idefics2",
"instructblip": "InstructBLIP",
"internvl": "InternVLChat",
"internvl2": "InternVL2",
"llama_vid": "LLaMAVid",
"llava": "Llava",
"llava_hf": "LlavaHf",
"llava_onevision": "Llava_OneVision",
"llava_sglang": "LlavaSglang",
"llava_vid": "LlavaVid",
"longva": "LongVA",
"mantis": "Mantis",
"minicpm_v": "MiniCPM_V",
"minimonkey": "MiniMonkey",
"mplug_owl_video": "mplug_Owl",
"phi3v": "Phi3v",
"qwen_vl": "Qwen_VL",
"qwen2_vl": "Qwen2_VL",
"qwen_vl_api": "Qwen_VL_API",
"reka": "Reka",
"srt_api": "SRT_API",
"tinyllava": "TinyLlava",
"videoChatGPT": "VideoChatGPT",
"video_llava": "VideoLLaVA",
"vila": "VILA",
"xcomposer2_4KHD": "XComposer2_4KHD",
"internvideo2": "InternVideo2",
"xcomposer2d5": "XComposer2D5",
"oryx": "Oryx",
"videochat2": "VideoChat2",
"llama_vision": "LlamaVision",
"magma": "Magma",
}
def get_model(model_name):
if model_name not in AVAILABLE_MODELS:
raise ValueError(f"Model {model_name} not found in available models.")
model_class = AVAILABLE_MODELS[model_name]
if "." not in model_class:
model_class = f"lmms_eval.models.{model_name}.{model_class}"
try:
model_module, model_class = model_class.rsplit(".", 1)
module = __import__(model_module, fromlist=[model_class])
return getattr(module, model_class)
except Exception as e:
logger.error(f"Failed to import {model_class} from {model_name}: {e}")
raise
if os.environ.get("LMMS_EVAL_PLUGINS", None):
# Allow specifying other packages to import models from
for plugin in os.environ["LMMS_EVAL_PLUGINS"].split(","):
m = importlib.import_module(f"{plugin}.models")
for model_name, model_class in getattr(m, "AVAILABLE_MODELS").items():
AVAILABLE_MODELS[model_name] = f"{plugin}.models.{model_name}.{model_class}"
import os
import uuid
import warnings
from typing import List, Optional, Tuple, Union
import torch
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from tqdm import tqdm
import PIL
from torchvision.transforms.functional import to_pil_image
from decord import VideoReader, cpu
import numpy as np
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.qwen.qwen_generate_utils import make_context
warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")
from loguru import logger as eval_logger
from transformers import AutoModelForCausalLM, AutoProcessor
@register_model("magma")
class Magma(lmms):
"""
Magma Model
"""
def __init__(
self,
pretrained: str = "Magma/Magma-8b",
device: str = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "auto",
batch_size: int = 1,
trust_remote_code: Optional[bool] = True,
attn_implementation: Optional[str] = None,
device_map: str = "",
max_frames_num: Optional[int] = 32,
**kwargs,
) -> None:
super().__init__()
# Do not use kwargs for now
assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
accelerator = Accelerator()
if accelerator.num_processes >= 1 and device_map == "":
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
else:
self._device = torch.device(device)
self.device_map = device_map
if isinstance(dtype, str) and dtype != "auto":
dtype = getattr(torch, dtype)
self.dtype = torch.bfloat16
        self.max_frames_num = max_frames_num
        # Store the per-GPU batch size; it is referenced below when configuring DeepSpeed.
        self.batch_size_per_gpu = int(batch_size)
self._model = AutoModelForCausalLM.from_pretrained(pretrained, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
self.model.eval()
self.processor = AutoProcessor.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
if accelerator.num_processes > 1 and device_map == "":
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
# Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
# I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == "auto":
eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism")
self._rank = 0
            self._world_size = 1
else:
eval_logger.info(f"Using single device: {self._device}")
self.model.to(self._device)
self._rank = 0
            self._world_size = 1
self.accelerator = accelerator
@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config
@property
def tokenizer(self):
return self._tokenizer
@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
return self._max_length
@property
def batch_size(self):
return self.batch_size_per_gpu
@property
def device(self):
return self._device
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
""" """
add_special_tokens = False if add_special_tokens is None else add_special_tokens
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
assert False, "Not implemented"
def flatten(self, input):
new_list = []
for i in input:
for j in i:
new_list.append(j)
return new_list
def load_video(self, video_path, max_frames_num):
if type(video_path) == str:
vr = VideoReader(video_path, ctx=cpu(0))
else:
vr = VideoReader(video_path[0], ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
spare_frames = vr.get_batch(frame_idx).asnumpy()
return spare_frames # (frames, height, width, channels)
def generate_until(self, requests: List[Instance]) -> List[str]:
res = []
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
visuals = self.flatten(visuals)
messages = [{"role": "user", "content": []}]
images = []
for visual in visuals:
if isinstance(visual, str):
frames = self.load_video(visual, self.max_frames_num)
frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
images.extend([to_pil_image(frame) for frame in frames])
elif isinstance(visual, PIL.Image.Image):
images.append(visual)
for _ in range(len(images)):
messages[-1]["content"].append({"type": "image"})
messages[-1]["content"].append({"type": "text", "content": contexts})
convs = [
{"role": "user", "content": ''.join(["<image>\n"]*len(images)) + contexts},
# {"role": "user", "content": contexts},
]
convs = [
{
"role": "system",
"content": "You are agent that can see, talk and act.",
},
] + convs
prompt = self.processor.tokenizer.apply_chat_template(
convs,
tokenize=False,
add_generation_prompt=True
)
if self.model.config.mm_use_image_start_end:
prompt = prompt.replace("<image>", "<image_start><image><image_end>")
inputs = self.processor(images=images, texts=prompt, return_tensors="pt").to(self.model.device)
# convert inputs to the same data type
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
inputs = inputs.to(self.dtype)
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
if "do_sample" not in gen_kwargs:
gen_kwargs["do_sample"] = False
self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id
with torch.no_grad():
output = self.model.generate(
**inputs,
max_new_tokens=gen_kwargs["max_new_tokens"],
temperature=gen_kwargs["temperature"],
do_sample=gen_kwargs["do_sample"],
)
output = output[:, inputs["input_ids"].shape[-1] :]
if 'Phi-3-mini-128k-instruct' in self.processor.tokenizer.name_or_path:
decoded_text = self.processor.decode(output[0], skip_special_tokens=False).strip()
res.append(decoded_text.split('<|end|>')[0])
else:
res.append(self.processor.decode(output[0], skip_special_tokens=True).strip())
pbar.update(1)
pbar.close()
return res
def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError("TODO: Implement multi-round generation for Magma")
## SimplerEnv Evaluation for MAGMA
To facilitate quantitative evaluation of our model, we also provide a model class for [SimplerEnv](https://github.com/simpler-env/SimplerEnv).
After installing SimplerEnv, copy `./simpler_env` into the corresponding folder under `SimplerEnv/simpler_env`. Make sure the hierarchy matches the original one.
import os
import numpy as np
import tensorflow as tf
from simpler_env.evaluation.argparse import get_args
from simpler_env.evaluation.maniskill2_evaluator import maniskill2_evaluator
from simpler_env.policies.rt1.rt1_model import RT1Inference
from simpler_env.policies.octo.octo_server_model import OctoServerInference
from simpler_env.policies.magma.magma_model import MagmaInference
try:
from simpler_env.policies.octo.octo_model import OctoInference
except ImportError as e:
print("Octo is not correctly imported.")
print(e)
if __name__ == "__main__":
args = get_args()
os.environ["DISPLAY"] = ""
# prevent a single jax process from taking up all the GPU memory
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
gpus = tf.config.list_physical_devices("GPU")
if len(gpus) > 0:
# prevent a single tf process from taking up all the GPU memory
tf.config.set_logical_device_configuration(
gpus[0],
[tf.config.LogicalDeviceConfiguration(memory_limit=args.tf_memory_limit)],
)
# policy model creation; update this if you are using a new policy model
if args.policy_model == "rt1":
assert args.ckpt_path is not None
model = RT1Inference(
saved_model_path=args.ckpt_path,
policy_setup=args.policy_setup,
action_scale=args.action_scale,
)
elif "octo" in args.policy_model:
if args.ckpt_path is None or args.ckpt_path == "None":
args.ckpt_path = args.policy_model
if "server" in args.policy_model:
model = OctoServerInference(
model_type=args.ckpt_path,
policy_setup=args.policy_setup,
action_scale=args.action_scale,
)
else:
model = OctoInference(
model_type=args.ckpt_path,
policy_setup=args.policy_setup,
init_rng=args.octo_init_rng,
action_scale=args.action_scale,
)
elif "magma" in args.policy_model:
assert args.ckpt_path is not None
model = MagmaInference(
model_name=args.ckpt_path,
policy_setup=args.policy_setup,
action_scale=args.action_scale,
# sticky_gripper_num_repeat=args.sticky_gripper_num_repeat,
# unnorm_key=args.unnorm_key,s
)
else:
raise NotImplementedError()
# run real-to-sim evaluation
success_arr = maniskill2_evaluator(model, args)
print(args)
print(" " * 10, "Average success", np.mean(success_arr))
import numpy as np
from PIL import Image
import random
import torch
import torchvision
import json
import sys
import os
from transformers import AutoModelForVision2Seq, AutoProcessor
from magma.image_processing_magma import MagmaImageProcessor
from magma.processing_magma import MagmaProcessor
from magma.modeling_magma import MagmaForCausalLM
from transforms3d.euler import euler2axangle
action_norm_stats = {
"bridge_orig": {'mask': [True, True, True, True, True, True, False], 'max': [0.41691166162490845, 0.25864794850349426, 0.21218234300613403, 3.122201919555664, 1.8618112802505493, 6.280478477478027, 1.0], 'mean': [0.0002334194869035855, 0.00013004911306779832, -0.00012762474943883717, -0.0001556558854645118, -0.0004039328487124294, 0.00023557482927571982, 0.5764579176902771], 'min': [-0.4007510244846344, -0.13874775171279907, -0.22553899884223938, -3.2010786533355713, -1.8618112802505493, -6.279075622558594, 0.0], 'q01': [-0.02872725307941437, -0.04170349963009357, -0.026093858778476715, -0.08092105075716972, -0.09288699507713317, -0.20718276381492615, 0.0], 'q99': [0.028309678435325586, 0.040855254605412394, 0.040161586627364146, 0.08192047759890528, 0.07792850524187081, 0.20382574498653397, 1.0], 'std': [0.009765930473804474, 0.013689135201275349, 0.012667362578213215, 0.028534092009067535, 0.030637972056865692, 0.07691419124603271, 0.4973701536655426]},
"google_robot": {'mask': [True, True, True, True, True, True, False], 'max': [2.9984593391418457, 22.09052848815918, 2.7507524490356445, 1.570636510848999, 1.5321086645126343, 1.5691522359848022, 1.0], 'mean': [0.006987582892179489, 0.006265917327255011, -0.01262515690177679, 0.04333311319351196, -0.005756212864071131, 0.0009130256366916001, 0.5354204773902893], 'min': [-2.0204520225524902, -5.497899532318115, -2.031663417816162, -1.569917917251587, -1.569892168045044, -1.570419430732727, 0.0], 'q01': [-0.22453527510166169, -0.14820013284683228, -0.231589707583189, -0.3517994859814644, -0.4193011274933815, -0.43643461108207704, 0.0], 'q99': [0.17824687153100965, 0.14938379630446405, 0.21842354819178575, 0.5892666035890578, 0.35272657424211445, 0.44796681255102094, 1.0], 'std': [0.0692116990685463, 0.05970962345600128, 0.07353084534406662, 0.15610496699810028, 0.13164450228214264, 0.14593800902366638, 0.497110515832901]}
}
class MagmaInference:
def __init__(self, model_name, policy_setup, action_scale=1.0, sticky_gripper_num_repeat=10, unnorm_key=None, sample=False):
if policy_setup == "widowx_bridge":
self.unnorm_key = "bridge_orig" if unnorm_key is None else unnorm_key
        elif policy_setup == "google_robot":
            # The stats table above keys the Google robot data as "google_robot".
            self.unnorm_key = "google_robot" if unnorm_key is None else unnorm_key
self.sticky_gripper_num_repeat = sticky_gripper_num_repeat
self.processor = MagmaProcessor.from_pretrained(model_name, trust_remote_code=True)
self.vla = MagmaForCausalLM.from_pretrained(
model_name,
device_map="cuda",
low_cpu_mem_usage=True,
attn_implementation="flash_attention_2", # [Optional] Requires `flash_attn`
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
self.task_description = None
self.policy_setup = policy_setup
self.action_scale = action_scale
self.sample = sample
self.sticky_action_is_on = False
self.gripper_action_repeat = 0
self.sticky_gripper_action = 0.0
self.previous_gripper_action = None
self.action_norm_stats = action_norm_stats[self.unnorm_key]
self.n_action_bins = 256
self.vocab_size = self.processor.tokenizer.vocab_size
self.bins = np.linspace(-1, 1, self.n_action_bins)
self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
def reset(self, task_description):
self.task_description = task_description
def step(self, image: np.ndarray, task_description: str | None = None):
if task_description is not None and task_description != self.task_description:
self.reset(task_description)
convs = [
{"role": "user", "content": f"<image>\nWhat action should the robot take to {self.task_description}?"},
]
convs = [
{
"role": "system",
"content": "You are agent that can see, talk and act.",
},
] + convs
prompt = self.processor.tokenizer.apply_chat_template(
convs,
tokenize=False,
add_generation_prompt=True
)
if self.vla.config.mm_use_image_start_end:
prompt = prompt.replace("<image>", "<image_start><image><image_end>")
image = Image.fromarray(image)
# resize image to 256x256
# image = image.resize((512, 512))
image = image.resize((256, 256))
inputs = self.processor(images=image, texts=prompt, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
inputs = inputs.to("cuda").to(torch.bfloat16)
self.vla.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id
with torch.inference_mode():
output_ids = self.vla.generate(
**inputs,
temperature=0.7,
do_sample=self.sample,
num_beams=1,
max_new_tokens=1000,
use_cache=True,
)
action_ids = output_ids[0, -8:-1].cpu().tolist()
if random.random() < 0.1:
print("Action ids", action_ids)
predicted_action_ids = np.array(action_ids).astype(np.int64)
discretized_actions = self.vocab_size - predicted_action_ids
discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
normalized_actions = self.bin_centers[discretized_actions]
# Unnormalize actions
mask = self.action_norm_stats.get("mask", np.ones_like(self.action_norm_stats["q01"], dtype=bool))
action_high, action_low = np.array(self.action_norm_stats["q99"]), np.array(self.action_norm_stats["q01"])
raw_actions = np.where(
mask,
0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
normalized_actions,
)
raw_action = {
"world_vector": np.array(raw_actions[:3]),
"rotation_delta": np.array(raw_actions[3:6]),
"open_gripper": np.array(raw_actions[6:7]), # range [0, 1]; 1 = open; 0 = close
}
# print(raw_action)
# Process raw_action to obtain the action for the maniskill2 environment
action = {}
action["world_vector"] = raw_action["world_vector"] * self.action_scale
action_rotation_delta = np.asarray(raw_action["rotation_delta"], dtype=np.float64)
roll, pitch, yaw = action_rotation_delta
action_rotation_ax, action_rotation_angle = euler2axangle(roll, pitch, yaw)
action_rotation_axangle = action_rotation_ax * action_rotation_angle
action["rot_axangle"] = action_rotation_axangle * self.action_scale
if self.policy_setup == "google_robot":
current_gripper_action = raw_action["open_gripper"]
if self.previous_gripper_action is None:
relative_gripper_action = np.array([0])
else:
relative_gripper_action = self.previous_gripper_action - current_gripper_action
self.previous_gripper_action = current_gripper_action
if np.abs(relative_gripper_action) > 0.5 and not self.sticky_action_is_on:
self.sticky_action_is_on = True
self.sticky_gripper_action = relative_gripper_action
if self.sticky_action_is_on:
self.gripper_action_repeat += 1
relative_gripper_action = self.sticky_gripper_action
if self.gripper_action_repeat == self.sticky_gripper_num_repeat:
self.sticky_action_is_on = False
self.gripper_action_repeat = 0
self.sticky_gripper_action = 0.0
action["gripper"] = relative_gripper_action
elif self.policy_setup == "widowx_bridge":
action["gripper"] = 2.0 * (raw_action["open_gripper"] > 0.5) - 1.0
action["terminate_episode"] = np.array([0.0])
return raw_action, action