import copy
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

import sentencepiece as spm
from transformers import (
    TOKENIZER_MAPPING,
    LlamaConfig,
    Phi3Config,
    PretrainedConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    Qwen2Config,
)

from sglang.utils import logger

# Copied from: https://github.com/OpenGVLab/InternVL/blob/34a81000402bf8f716bab8c9b57aff1f6b436bd0/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py#L21


VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}

PRETRAINED_VOCAB_FILES_MAP = {}


# Modified from transformers.model.llama.configuration_llama.LlamaConfig
class InternLM2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
    an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 103168):
            Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`InternLM2Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details checkout
            [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-sequence token id, forwarded to [`PretrainedConfig`].
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        bias (`bool`, *optional*, defaults to `True`):
            Stored as `config.bias`; presumably toggles bias terms in the model's linear layers (confirm against the
            model implementation).
        rope_theta (`float`, *optional*, defaults to 10000):
            Stored as `config.rope_theta`; the RoPE base used by the model (confirm in the model code).
        rope_scaling (`dict`, *optional*):
            RoPE scaling settings. Must be `None` or a dict with exactly two keys: `type` (one of `"linear"`,
            `"dynamic"`) and `factor` (a `float` >= 1.0). Checked by `_rope_scaling_validation`.
        attn_implementation (`str`, *optional*, defaults to `"eager"`):
            Attention implementation name; `None` is normalized to `"eager"`.
    """

    model_type = "internlm2"
    _auto_class = "AutoConfig"

    def __init__(  # pylint: disable=W0102
        self,
        vocab_size=103168,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        bias=True,
        rope_theta=10000,
        rope_scaling=None,
        attn_implementation="eager",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.bias = bias

        # Without an explicit KV-head count, fall back to plain multi-head attention.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        # An explicit `attn_implementation=None` is normalized to the eager default.
        self.attn_implementation = attn_implementation
        if self.attn_implementation is None:
            self.attn_implementation = "eager"
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        `None` is accepted as "no scaling". Otherwise `rope_scaling` must be a dict with exactly two entries:
        `type` in {"linear", "dynamic"} and a `float` `factor` >= 1.0 (an `int` factor is rejected).

        Raises:
            ValueError: if any of the above constraints is violated.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if (
            rope_scaling_factor is None
            or not isinstance(rope_scaling_factor, float)
            or rope_scaling_factor < 1.0
        ):
            raise ValueError(
                f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}"
            )


class InternVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
    instantiate a vision encoder according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            Number of color channels in the input images (e.g., 3 for RGB).
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        qkv_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the queries and values in the self-attention layers.
        hidden_size (`int`, *optional*, defaults to 3200):
            Dimensionality of the encoder layers and the pooler layer.
        num_attention_heads (`int`, *optional*, defaults to 25):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 12800):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        qk_normalization (`bool`, *optional*, defaults to `True`):
            Whether to normalize the queries and keys in the self-attention layers.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        use_flash_attn (`bool`, *optional*, defaults to `True`):
            Whether to use flash attention mechanism.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Dropout rate for stochastic depth.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 0.1):
            A factor for layer scale.
    """

    model_type = "intern_vit_6b"

    def __init__(
        self,
        num_channels=3,
        patch_size=14,
        image_size=224,
        qkv_bias=False,
        hidden_size=3200,
        num_attention_heads=25,
        intermediate_size=12800,
        qk_normalization=True,
        num_hidden_layers=48,
        use_flash_attn=True,
        hidden_act="gelu",
        layer_norm_eps=1e-6,
        dropout=0.0,
        drop_path_rate=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.qkv_bias = qkv_bias
        self.qk_normalization = qk_normalization
        self.use_flash_attn = use_flash_attn

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """
        Load an `InternVisionConfig` from a checkpoint name or path.

        The checkpoint may hold either a standalone vision config or a full InternVL chat config. In the latter
        case, the nested `vision_config` dict is extracted and used. A warning is logged when the resolved
        `model_type` differs from `cls.model_type`; loading still proceeds.
        """
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # A full InternVL chat config nests the vision settings under `vision_config`.
        if "vision_config" in config_dict:
            config_dict = config_dict["vision_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class InternVLChatConfig(PretrainedConfig):
    """
    Composite configuration for InternVL chat models.

    Combines an `InternVisionConfig` (built from `vision_config`) with a language-model config whose class is
    selected by `llm_config["architectures"][0]`. Supported architectures are `LlamaForCausalLM`,
    `InternLM2ForCausalLM`, `Phi3ForCausalLM` and `Qwen2ForCausalLM`. The remaining keyword arguments are stored
    as plain attributes.
    """

    model_type = "internvl_chat"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
        pad2square=False,
        select_layer=-1,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        dynamic_image_size=False,
        use_thumbnail=False,
        ps_version="v1",
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        **kwargs,
    ):
        """
        Args:
            vision_config (`dict`, *optional*): kwargs for `InternVisionConfig`.
            llm_config (`dict`, *optional*): kwargs for the language-model config; must contain an
                `architectures` list whose first entry picks the config class.

        Raises:
            ValueError: if the first entry of `llm_config["architectures"]` is not a supported architecture. This
                includes the default used when `llm_config` is `None`.
        """
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {"architectures": ["InternVisionModel"]}
            logger.info(
                "vision_config is None. Initializing the InternVisionConfig with default values."
            )

        if llm_config is None:
            # TODO: There might still be a bug in transformers version 4.44 and above.
            # NOTE(review): the empty architecture name falls through to the ValueError below, so
            # `InternVLChatConfig()` with no `llm_config` raises despite the log message. This looks
            # deliberate (presumably transformers probes `self.__class__()` and handles a ValueError);
            # confirm before changing it.
            llm_config = {"architectures": [""]}
            logger.info(
                "llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
            )

        self.vision_config = InternVisionConfig(**vision_config)

        # Map the LLM architecture name to the config class that understands its kwargs.
        llm_config_classes = {
            "LlamaForCausalLM": LlamaConfig,
            "InternLM2ForCausalLM": InternLM2Config,
            "Phi3ForCausalLM": Phi3Config,
            "Qwen2ForCausalLM": Qwen2Config,
        }
        architecture = llm_config["architectures"][0]
        llm_config_cls = llm_config_classes.get(architecture)
        if llm_config_cls is None:
            raise ValueError("Unsupported architecture: {}".format(architecture))
        self.llm_config = llm_config_cls(**llm_config)

        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch

        # Expose the LLM hidden size at the top level of the composite config.
        self.hidden_size = self.llm_config.hidden_size
        # By default, we use tie_word_embeddings=False for models of all sizes.
        self.tie_word_embeddings = False
        self.llm_config.tie_word_embeddings = self.tie_word_embeddings

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        The nested `vision_config` and `llm_config` objects are converted to plain dicts via their own `to_dict`.

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["llm_config"] = self.llm_config.to_dict()
        output["model_type"] = self.__class__.model_type
        output["use_backbone_lora"] = self.use_backbone_lora
        output["use_llm_lora"] = self.use_llm_lora
        output["select_layer"] = self.select_layer
        output["force_image_size"] = self.force_image_size
        output["downsample_ratio"] = self.downsample_ratio
        output["template"] = self.template
        output["dynamic_image_size"] = self.dynamic_image_size
        output["use_thumbnail"] = self.use_thumbnail
        output["ps_version"] = self.ps_version
        output["min_dynamic_patch"] = self.min_dynamic_patch
        output["max_dynamic_patch"] = self.max_dynamic_patch

        return output


# # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
# class InternLM2TokenizerFast(PreTrainedTokenizerFast):
#     vocab_files_names = VOCAB_FILES_NAMES
#     slow_tokenizer_class = InternLM2Tokenizer
#     padding_side = 'left'
#     model_input_names = ['input_ids', 'attention_mask']
#     _auto_class = 'AutoTokenizer'
#
#     def __init__(
#         self,
#         vocab_file,
#         unk_token='',
#         bos_token='',
#         eos_token='',
#         pad_token='',
#         sp_model_kwargs: Optional[Dict[str, Any]] = None,
#         add_bos_token=True,
#         add_eos_token=False,
#         decode_with_prefix_space=False,
#         clean_up_tokenization_spaces=False,
#         **kwargs,
#     ):
#         super().__init__(
#             vocab_file=vocab_file,
#             unk_token=unk_token,
#             bos_token=bos_token,
#             eos_token=eos_token,
#             pad_token=pad_token,
#             sp_model_kwargs=sp_model_kwargs,
#             add_bos_token=add_bos_token,
#             add_eos_token=add_eos_token,
#             decode_with_prefix_space=decode_with_prefix_space,
#             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
#             **kwargs,
#         )
#         self._add_bos_token = add_bos_token
#         self._add_eos_token = add_eos_token
#         self.update_post_processor()
#         self.vocab_file = vocab_file
#
#     @property
#     def can_save_slow_tokenizer(self) -> bool:
#         return os.path.isfile(self.vocab_file) if self.vocab_file else False
#
#     def update_post_processor(self):
#         """
#         Updates the underlying post processor with the current `bos_token` and `eos_token`.
#         """
#         bos = self.bos_token
#         bos_token_id = self.bos_token_id
#         if bos is None and self.add_bos_token:
#             raise ValueError('add_bos_token = True but bos_token = None')
#
#         eos = self.eos_token
#         eos_token_id = self.eos_token_id
#         if eos is None and self.add_eos_token:
#             raise ValueError('add_eos_token = True but eos_token = None')
#
#         single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
#         pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
#
#         special_tokens = []
#         if self.add_bos_token:
#             special_tokens.append((bos, bos_token_id))
#         if self.add_eos_token:
#             special_tokens.append((eos, eos_token_id))
#         self._tokenizer.post_processor = processors.TemplateProcessing(
#             single=single, pair=pair, special_tokens=special_tokens
#         )
#
#     @property
#     def add_eos_token(self):
#         return self._add_eos_token
#
#     @property
#     def add_bos_token(self):
#         return self._add_bos_token
#
#     @add_eos_token.setter
#     def add_eos_token(self, value):
#         self._add_eos_token = value
#         self.update_post_processor()
#
#     @add_bos_token.setter
#     def add_bos_token(self, value):
#         self._add_bos_token = value
#         self.update_post_processor()
#
#     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
#         if not self.can_save_slow_tokenizer:
#             raise ValueError(
#                 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
#                 'tokenizer.'
#             )
#
#         if not os.path.isdir(save_directory):
#             logger.error(f'Vocabulary path ({save_directory}) should be a directory')
#             return
#         out_vocab_file = os.path.join(
#             save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
#         )
#
#         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
#             copyfile(self.vocab_file, out_vocab_file)
#
#         return (out_vocab_file,)


# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
class InternLM2Tokenizer(PreTrainedTokenizer):
    """
    Construct an InternLM2 tokenizer backed by a SentencePiece model.

    Args:
        vocab_file (`str`):
            Path to the SentencePiece model file, loaded with `SentencePieceProcessor.Load`.
        unk_token, bos_token, eos_token, pad_token (`str`, *optional*):
            Special-token strings forwarded to [`PreTrainedTokenizer`]. `bos_token_id` and `eos_token_id` are
            always read from the SentencePiece model, not looked up from these strings.
        sp_model_kwargs (`dict`, *optional*):
            Keyword arguments passed to `spm.SentencePieceProcessor`.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Prepend the BOS id in `build_inputs_with_special_tokens`.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Append the EOS id in `build_inputs_with_special_tokens`.
        decode_with_prefix_space (`bool`, *optional*, defaults to `False`):
            Stored on the instance; not read by the methods of this class.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Forwarded to [`PreTrainedTokenizer`].
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]
    _auto_class = "AutoTokenizer"

    def __init__(
        self,
        vocab_file,
        # Conventional SentencePiece special-token strings (as in upstream InternLM2). Empty-string
        # defaults are skipped by the base class and leave `<s>`/`</s>` unrecognized as special.
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="</s>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        decode_with_prefix_space=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.decode_with_prefix_space = decode_with_prefix_space
        # Load the SentencePiece model before super().__init__(), presumably because the
        # base-class constructor may already call vocab-dependent methods — confirm.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        self._no_prefix_space_tokens = None  # computed lazily by `no_prefix_space_tokens`
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def no_prefix_space_tokens(self):
        """Set of token IDs (ints) whose piece does not start with the SentencePiece space marker '▁'."""
        if self._no_prefix_space_tokens is None:
            vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
            self._no_prefix_space_tokens = {
                i for i, tok in enumerate(vocab) if not tok.startswith("▁")
            }
        return self._no_prefix_space_tokens

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    @property
    def bos_token_id(self) -> Optional[int]:
        """BOS id as defined by the SentencePiece model."""
        return self.sp_model.bos_id()

    @property
    def eos_token_id(self) -> Optional[int]:
        """EOS id as defined by the SentencePiece model."""
        return self.sp_model.eos_id()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def _maybe_add_prefix_space(self, tokens, decoded):
        # NOTE(review): `tokens` holds piece strings while `no_prefix_space_tokens` holds ids, so
        # this membership test is always True and a space is always prepended for non-empty
        # `tokens`. `convert_tokens_to_string` strips it again with `[1:]`; changing one side
        # without the other alters decode output.
        if tokens and tokens[0] not in self.no_prefix_space_tokens:
            return " " + decoded
        else:
            return decoded

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        out_string = self.clean_up_tokenization(out_string)
        out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
        # Drop the first character (normally the space added by `_maybe_add_prefix_space`).
        return out_string[1:]

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Copies `self.vocab_file` when it exists and differs from the target path; if the original file no longer
        exists, writes the in-memory serialized SentencePiece model instead.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved, or `None` (after logging an error) if `save_directory` is not a
            directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Frame one or two sequences as `[BOS] + ids_0 (+ ids_1) [+ EOS]`.

        BOS is added only when `add_bos_token` is set and EOS only when `add_eos_token` is set. A pair is
        concatenated without a separator.
        """
        if self.add_bos_token:
            bos_token_ids = [self.bos_token_id]
        else:
            bos_token_ids = []

        output = bos_token_ids + token_ids_0

        if token_ids_1 is not None:
            output = output + token_ids_1

        if self.add_eos_token:
            output = output + [self.eos_token_id]

        return output

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        The mask mirrors `build_inputs_with_special_tokens` position by position, so both always have the same
        length.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        # Same framing as build_inputs_with_special_tokens: optional BOS, both sequences
        # back to back, optional EOS.
        bos_mask = [1] if self.add_bos_token else []
        eos_mask = [1] if self.add_eos_token else []
        sequence_mask = [0] * len(token_ids_0)
        if token_ids_1 is not None:
            sequence_mask += [0] * len(token_ids_1)
        return bos_mask + sequence_mask + eos_mask

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. InternLM2
        does not make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros, as long as the output of `build_inputs_with_special_tokens`.
        """
        # Derive the length from the actual framing so it honours add_bos_token/add_eos_token.
        return [0] * len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1))


# Register InternLM2Tokenizer as the (slow-only) tokenizer for InternVLChatConfig in
# transformers' tokenizer mapping; `exist_ok=True` tolerates an existing entry.
TOKENIZER_MAPPING.register(
    InternVLChatConfig, (InternLM2Tokenizer, None), exist_ok=True
)