):
    """Hook to ignore missing keys during checkpoint loading.

    By default, this should not be used to avoid accidentally missing weights in checkpoint loading.

    Example use case: Use this if you want to load a checkpoint that contains vision and language
    model weights but not the vision projection weights.

    Args:
        param_names (list str): Parameter names allowed to be missing when calling load_state_dict.
        module (torch.nn.Module): The torch module this hook applies to. Required by the torch API.
        incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys,
            which collect the missing and unexpected keys, respectively.
    """
    for param_name in param_names:
        if param_name in incompatible_keys.missing_keys:
            logging.getLogger(__name__).warning(
                f"{param_name} being removed from incompatible_keys.missing_keys in LlavaModel"
            )
            incompatible_keys.missing_keys.remove(param_name)

        f"_extra_state key {key} being removed from {name}"
    )
    keys.remove(key)

# pylint: disable-next=line-too-long
# Based on https://github.com/OpenGVLab/InternVL/blob/c7c5af1a8930b4862afe8ed14672307082ef61fa/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py#L218
# Copyright (c) 2023 OpenGVLab.
def pixel_shuffle(x, scale_factor=0.5, version=2):
    """Pixel shuffle based on InternVL but adapted for our use case.

    Args:
        x (torch.Tensor): Vision model outputs [num_tiles, img_seq_len, h_vision]