"vscode:/vscode.git/clone" did not exist on "270194c4690f6adeefe874692f139bf2949450ce"
factory.py 11.2 KB
Newer Older
yangzhong's avatar
v1.0  
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import os
from typing import Optional

import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, Lambda
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPVisionModel, CLIPImageProcessor, AutoModel, AutoProcessor
import open_clip

from .xgenmm import XGenMMPerceiver
from .utils import hasattr_recursive, setattr_recursive

from PIL import Image
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


MODEL_FAMILY_TO_CLASS = {
    "xgenmm_v1": XGenMMPerceiver,
}
SUPPORTED_MODEL_FAMILIES = MODEL_FAMILY_TO_CLASS.keys()

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_model_path: str,
    tokenizer_path: str,
    model_family: str = "flamingo",
    pretrained_vision_tokenizer: Optional[str] = None,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    cache_dir: Optional[str] = None,
    gradient_checkpointing: bool = False,
    verbose: bool = True,
    **model_kwargs,
):
    """
    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
    Appends special tokens to the tokenizer and freezes backbones.

    Args:
        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
        lang_model_path (str): path to pretrained language encoder
        tokenizer_path (str): path to pretrained tokenizer
        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
        use_local_files (bool, optional): whether to use local files. Defaults to False.
        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
        cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
        gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
        verbose (bool, optional): whether to print model info. Defaults to True.
    Returns:
        Flamingo: Flamingo model from pretrained vision and language encoders
        Image processor: Pipeline to preprocess input images
        Tokenizer: A tokenizer for the language model
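
    Example (illustrative only; the checkpoint names below are placeholders and
    not necessarily the exact checkpoints used for training):

        >>> model, image_processor, tokenizer = create_model_and_transforms(
        ...     clip_vision_encoder_path="google/siglip-so400m-patch14-384",
        ...     clip_vision_encoder_pretrained="google",
        ...     lang_model_path="microsoft/Phi-3-mini-4k-instruct",
        ...     tokenizer_path="microsoft/Phi-3-mini-4k-instruct",
        ...     model_family="xgenmm_v1",
        ... )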
    """

    assert model_family in SUPPORTED_MODEL_FAMILIES, f"Unsupported model_family '{model_family}'; expected one of {list(SUPPORTED_MODEL_FAMILIES)}."

    # load vision encoder
    if clip_vision_encoder_pretrained == 'openai':
        vision_encoder = CLIPVisionModel.from_pretrained(clip_vision_encoder_path)
        hf_processor = CLIPImageProcessor.from_pretrained(clip_vision_encoder_path)
        n_px = hf_processor.crop_size['height']
        # Use torchvision processor to be consistent with other vision encoders.
        # https://github.com/openai/CLIP/blob/main/clip/clip.py
        image_processor = Compose([
                                Resize((n_px, n_px), interpolation=BICUBIC),
                                CenterCrop(n_px),
                                _convert_image_to_rgb,
                                ToTensor(),
                                Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                            ])
        vis_hidden_dim = vision_encoder.config.hidden_size
    elif clip_vision_encoder_pretrained == 'google':
        # "google/siglip-so400m-patch14-384"
        model = AutoModel.from_pretrained(clip_vision_encoder_path)
        hf_processor = AutoProcessor.from_pretrained(clip_vision_encoder_path)
        n_px = hf_processor.image_processor.size['height']
        vision_encoder = model.vision_model
        vis_hidden_dim = vision_encoder.config.hidden_size
        
        # Define the transformation sequence
        image_processor = Compose([
            Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC, antialias=True),
            Lambda(lambda x: x.convert('RGB') if x.mode != 'RGB' else x),
            ToTensor(),
            Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
    else:
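        # Fall back to open_clip for all other (architecture, pretrained-tag) combinations.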
        vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
            clip_vision_encoder_path,
            pretrained=clip_vision_encoder_pretrained,
        )
        vision_encoder.visual.output_tokens = True
        vision_encoder = vision_encoder.visual
        vision_encoder_config = open_clip.get_model_config(clip_vision_encoder_path)
        if "SigLIP" in clip_vision_encoder_path or "EVA" in clip_vision_encoder_path: # SigLIP models have a different config format
            vis_hidden_dim = vision_encoder_config["embed_dim"]
        else:    
            vis_hidden_dim = vision_encoder_config["vision_cfg"]["width"]

    # load tokenizer and ensure there is a pad token
    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        use_fast=False, 
    )
    if text_tokenizer.pad_token is None or text_tokenizer.pad_token == text_tokenizer.eos_token:
        # add a dedicated pad token if it is missing or shared with the EOS token
        text_tokenizer.add_special_tokens({"pad_token": "<pad>"})
        added_pad_token = True
    else:
        added_pad_token = False

    # load language model
    if ('phi3' in lang_model_path.lower()) or ('phi-3' in lang_model_path.lower()):
        if 'instruct' not in lang_model_path.lower():
            raise ValueError("As of now, we only support instruct models for phi3. Please use a model with 'instruct' in the path.")
        trust_remote_code_flag = True # phi3 is not stable yet, so we trust the remote code
    else:
        trust_remote_code_flag = False  # force the use of local modeling code so that the FSDP wrapper can be applied
    lang_model = AutoModelForCausalLM.from_pretrained(
        lang_model_path,
        local_files_only=use_local_files,
        trust_remote_code=trust_remote_code_flag,
    )

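    # Some language models do not expose {get,set}_{input,output}_embeddings;
    # patch in accessors so downstream code can read and swap the embedding layers uniformly.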
    check_embedding_fns(lang_model)

    # init the model
    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_model)

    model = MODEL_FAMILY_TO_CLASS[model_family](
        vision_encoder=vision_encoder,
        lang_model=lang_model,
        vis_feature_dim=vis_hidden_dim,
        initial_tokenizer_len=len(text_tokenizer),
        gradient_checkpointing=gradient_checkpointing,
        decoder_layers_attr_name=decoder_layers_attr_name,
        pad_token_id=text_tokenizer.pad_token_id,
        **model_kwargs,
    )
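    # Optionally initialize the vision tokenizer from pretrained weights.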
    if pretrained_vision_tokenizer is not None:
        assert os.path.exists(pretrained_vision_tokenizer), f"Pretrained vision tokenizer weights not found at {pretrained_vision_tokenizer}."
        vis_tok_weight = torch.load(pretrained_vision_tokenizer)
        model.vision_tokenizer.load_state_dict(vis_tok_weight, strict=True)

    # add special tokens to the tokenizer and language models
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": list(model.special_tokens.values())}
    )
    model.lang_model.config.vocab_size = len(text_tokenizer)
    model.set_special_token_ids(
        {
            v: text_tokenizer.convert_tokens_to_ids(v)
            for v in model.special_tokens.values()
        }
    )
    # freeze appropriate parameters
    model.set_trainable()

    # log model info
    if verbose:
        print(
            f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters"
        )
        print(f"==========Trainable Parameters\n{model.num_trainable_params_per_module}")
        print(f"==========Total Parameters\n{model.num_params_per_module}\n==========")
    return model, image_processor, text_tokenizer


def _infer_decoder_layers_attr_name(model):
    """
    Infer the name of the attribute storing the decoder layers (as a ModuleList) in the model.
    """
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        f"Could not infer the attribute name of the nn.ModuleList storing the decoder's transformer blocks for {model.__class__.__name__}. Please supply decoder_layers_attr_name manually."
    )


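# Lookup table for _infer_decoder_layers_attr_name: keys are substrings matched against
# HuggingFace model class names, values are dotted paths to the decoder block ModuleList.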
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
    "gemma": "model.layers",
    "phi": "model.layers",
    "minicpm": "model.layers",
    "stablelm": "model.layers",
    "qwen": "model.layers",
    "mistral": "model.layers"
}


def check_embedding_fns(lang_model):
    """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
    if not has_fn(lang_model, "get_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            lang_model.get_input_embeddings = lambda: lang_model.model.decoder.embed_tokens
        else:
            raise ValueError(
                "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "transformer.wte", x
            )
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "model.decoder.embed_tokens", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "get_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.get_output_embeddings = lambda: lang_model.lm_head
        else:
            raise ValueError(
                "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.set_output_embeddings = lambda x: setattr_recursive(
                lang_model, "lm_head", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )


def has_fn(model, fn_name):
    """Check if model has a function fn_name"""
    return callable(getattr(model, fn_name, None))