"awq/vscode:/vscode.git/clone" did not exist on "18712d00d7a420c0c8e774cff96131d151e6878c"
model.py 12.3 KB
Newer Older
silencealiang's avatar
silencealiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
import warnings
from copy import deepcopy

import torch
from config import get_language_model_config, get_vision_model_config, get_vision_projection_config
from layer_specs import get_layer_spec, get_layer_spec_te, get_mlp_module_spec, get_norm_mlp_module_spec_te

from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN, LLaVAModel
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.training import get_args, get_tokenizer, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args


def model_provider(
    pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True
) -> LLaVAModel:
    """Assemble the LLaVA multimodal model from the global Megatron arguments.

    Note that this function mutates the global args: ``seq_length`` and
    ``encoder_seq_length`` are overwritten with the per-image embedding count, and
    ``max_position_embeddings`` is raised to ``decoder_seq_length`` when it is smaller.

    Args:
        pre_process (bool): Include the embedding layer in the gpt decoder (used with
            pipeline parallelism). Defaults to True.
        post_process (bool): Include an output layer and a layernorm in the gpt decoder
            (used with pipeline parallelism). Defaults to True.
        add_encoder (bool): Construct the encoder module (used with pipeline parallelism).
            Defaults to True. With pipelining, the encoder lives on only a subset of the
            pipeline stages (specifically, only the first stage).
        add_decoder (bool): Construct the decoder module (used with pipeline parallelism).
            Defaults to True. With pipelining, the decoder lives on only a subset of the
            pipeline stages (specifically, every stage after the first one).
        parallel_output (bool): Enable parallel model output.

    Returns:
        model: A multimodal model.

    Raises:
        AssertionError: If the parsed arguments describe an unsupported configuration.
        RuntimeError: If ``args.vision_model_type`` is not one of the handled types.
    """
    args = get_args()
    assert args.encoder_pipeline_model_parallel_size <= 1, "LLaVA does not support pp>1 for encoder on it's own pipeline rank"

    use_te = args.use_te

    print_rank_0('building a multimodal model ...')

    # ------------------------------------------------------------------
    # Sequence lengths.
    # ------------------------------------------------------------------
    # Embedding count for a single image; it becomes the vision model sequence length.
    num_image_embeddings = get_num_image_embeddings(
        args.img_h,
        args.img_w,
        args.patch_dim,
        args.vision_model_type,
        args.disable_vision_class_token,
        1,
        args.pixel_shuffle,
        args.use_tile_tags,
    )
    old_seq_length = args.seq_length
    args.seq_length = num_image_embeddings
    args.encoder_seq_length = num_image_embeddings
    if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length:
        warnings.warn(
            f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})"
        )

    # Every tile, plus the optional thumbnail, contributes one full set of embeddings.
    tiles_per_sample = args.max_num_tiles + int(args.use_thumbnail)
    max_num_image_embeddings = tiles_per_sample * num_image_embeddings

    assert (
        args.decoder_seq_length is not None
    ), "Please provide --decoder-seq-length to set the language model sequence length"
    assert (
        args.decoder_seq_length > max_num_image_embeddings
    ), "Language model sequence length must be greater than the maximum number of image embeddings"
    if args.max_position_embeddings < args.decoder_seq_length:
        args.max_position_embeddings = args.decoder_seq_length
        warnings.warn(
            f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length"
        )

    # ------------------------------------------------------------------
    # Shared base config; each sub-model works on its own deep copy.
    # ------------------------------------------------------------------
    base_config = core_transformer_config_from_args(get_args())
    base_config.language_model_type = args.language_model_type
    base_config.vision_model_type = args.vision_model_type
    base_config.calculate_per_token_loss = True

    # ------------------------------------------------------------------
    # Language model config and layer spec.
    # ------------------------------------------------------------------
    language_config = get_language_model_config(deepcopy(base_config))

    if use_te:
        # Padding mask needed for SP/CP. TENorm detects LayerNorm/RMS automatically.
        language_transformer_layer_spec = get_layer_spec_te(
            is_vit=False,
            padding=args.context_parallel_size > 1 and args.sequence_parallel,
        )
    else:
        language_transformer_layer_spec = get_layer_spec(
            is_vit=False, normalization=language_config.normalization
        )

    # ------------------------------------------------------------------
    # Vision model config and layer spec.
    # ------------------------------------------------------------------
    vision_model_type = args.vision_model_type
    vision_config = get_vision_model_config(
        deepcopy(base_config), apply_query_key_layer_scaling=args.apply_query_key_layer_scaling
    )

    if vision_model_type.startswith("huggingface"):
        assert args.encoder_tensor_model_parallel_size < 2, "Huggingface vision encoders do not support --encoder-tensor-model-parallel-size > 1"
        assert args.encoder_pipeline_model_parallel_size == 0, "Huggingface vision encoders do not support --encoder-pipeline-model-parallel-size > 0"
        assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel"
        assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1"
        assert args.vision_huggingface_model_name_or_path is not None, "Providing --vision-huggingface-model-name-or-path is necessary when using huggingface vision model"

        vision_config.huggingface_model_name_or_path = args.vision_huggingface_model_name_or_path

        # The hidden size is taken from the Hugging Face config of the chosen encoder.
        from transformers import AutoConfig
        hf_vision_config = AutoConfig.from_pretrained(vision_config.huggingface_model_name_or_path)
        vision_config.hidden_size = hf_vision_config.hidden_size

        # No Megatron layer spec is passed for Hugging Face vision encoders.
        vision_transformer_layer_spec = None
    elif vision_model_type in ("clip", "siglip", "radio"):
        if use_te:
            # TENorm detects LayerNorm/RMS automatically.
            vision_transformer_layer_spec = get_layer_spec_te(is_vit=True)
        else:
            vision_transformer_layer_spec = get_layer_spec(
                is_vit=True, normalization=vision_config.normalization
            )
    elif vision_model_type == "internvit":
        from nvlm.internvit import get_internvit_layer_spec
        vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te)
    else:
        raise RuntimeError("unsupported vision model type", vision_model_type)

    # ------------------------------------------------------------------
    # Vision projection config.
    # ------------------------------------------------------------------
    vision_projection_config = deepcopy(base_config)

    if base_config.language_model_type.startswith("huggingface"):
        assert args.tensor_model_parallel_size == 1, "Huggingface models do not support --tensor-model-parallel-size > 1"
        assert args.pipeline_model_parallel_size < 2, "Huggingface models do not support --pipeline-model-parallel-size > 1"
        assert not args.sequence_parallel, "Huggingface models do not support --sequence-parallel"
        assert args.context_parallel_size < 2, "Huggingface models do not support --context-parallel-size > 1"
        assert args.language_huggingface_model_name_or_path is not None, "Providing --language-huggingface-model-name-or-path is necessary when using huggingface language model"

        hf_language_path = args.language_huggingface_model_name_or_path
        language_config.huggingface_model_name_or_path = hf_language_path
        # Pass to vision projection config so can choose the correct ffn hidden size
        vision_projection_config.huggingface_model_name_or_path = hf_language_path

    vision_projection_config = get_vision_projection_config(
        vision_projection_config, language_config.hidden_size
    )

    # ------------------------------------------------------------------
    # Encoder-side parallelism overrides.
    # ------------------------------------------------------------------
    # --encoder-pipeline-model-parallel-size 1 will enable a separate pipeline stage for the vision model.
    if args.encoder_pipeline_model_parallel_size > 0:
        assert (
            args.encoder_pipeline_model_parallel_size == 1
        ), "vision model and projection can only live on 1 pipeline stage."

        encoder_tp_size = args.encoder_tensor_model_parallel_size
        if encoder_tp_size > 0:
            for encoder_side_config in (vision_config, vision_projection_config):
                encoder_side_config.tensor_model_parallel_size = encoder_tp_size

    # Make sure vision model pipeline parallel size is not inherited from the language model pipeline parallel size.
    # 0 is not a valid value for the config, hence max(1, ).
    encoder_pp_size = max(1, args.encoder_pipeline_model_parallel_size)
    vision_config.pipeline_model_parallel_size = encoder_pp_size
    vision_projection_config.pipeline_model_parallel_size = encoder_pp_size

    # Make sure the vision model does not inherit first and last pipeline num layers from the language model.
    vision_config.first_pipeline_num_layers = None
    vision_config.last_pipeline_num_layers = None

    # A projection config with normalization set selects the TE norm+MLP spec; otherwise the plain MLP spec.
    projection_module_spec = (
        get_norm_mlp_module_spec_te()
        if vision_projection_config.normalization
        else get_mlp_module_spec(use_te=use_te)
    )
    vision_projection_layer_spec = projection_module_spec.submodules

    # ------------------------------------------------------------------
    # Toggle --recompute* for the vision and language model separately.
    # ------------------------------------------------------------------
    if not args.recompute_vision:
        vision_config.recompute_granularity = None
        vision_config.recompute_method = None
        vision_config.recompute_num_layers = None
    elif vision_config.recompute_method is not None and vision_config.recompute_granularity is not None:
        # Recompute across all vision layers when recompute is configured.
        vision_config.recompute_num_layers = vision_config.num_layers

    # Recompute is always switched off for the projection.
    vision_projection_config.recompute_granularity = None
    vision_projection_config.recompute_method = None
    vision_projection_config.recompute_num_layers = None

    # TODO: Vision model and projection do not use SP/CP yet.
    for encoder_side_config in (vision_config, vision_projection_config):
        encoder_side_config.sequence_parallel = False
        encoder_side_config.context_parallel_size = 1
        encoder_side_config.tp_comm_overlap = False

    # ------------------------------------------------------------------
    # Tokenizer-derived inputs.
    # ------------------------------------------------------------------
    tokenizer = get_tokenizer()
    image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    assert image_token_index is not None, f"IMAGE_TOKEN={IMAGE_TOKEN} needs to be added using the --special-tokens arg."

    tile_tags = _get_tile_tags(args, tokenizer)

    # ------------------------------------------------------------------
    # Build the model and apply the requested freezing.
    # ------------------------------------------------------------------
    model = LLaVAModel(
        # Language model.
        language_transformer_config=language_config,
        language_transformer_layer_spec=language_transformer_layer_spec,
        language_vocab_size=args.padded_vocab_size,
        language_max_sequence_length=args.decoder_seq_length,
        language_position_embedding_type=args.position_embedding_type,
        language_rotary_percent=args.rotary_percent,
        language_rotary_base=args.rotary_base,
        language_rope_scaling=args.use_rope_scaling,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        parallel_output=parallel_output,
        # Vision model.
        vision_transformer_config=vision_config,
        vision_transformer_layer_spec=vision_transformer_layer_spec,
        drop_vision_class_token=args.disable_vision_class_token,
        img_h=args.img_h,
        img_w=args.img_w,
        patch_dim=args.patch_dim,
        pixel_shuffle=args.pixel_shuffle,
        # Vision projection.
        vision_projection_config=vision_projection_config,
        vision_projection_layer_spec=vision_projection_layer_spec,
        vision_projection_type="mlp",
        allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint,
        # Pipeline placement.
        pre_process=pre_process,
        post_process=post_process,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        # Special tokens.
        image_token_index=image_token_index,
        tile_tags=tile_tags,
    )

    # The projection is never frozen; LM and ViT freezing follow the CLI flags.
    model.freeze(
        freeze_language_model=args.freeze_LM,
        freeze_vision_model=args.freeze_ViT,
        freeze_vision_projection=False,
    )

    return model


def _get_tile_tags(args, tokenizer):
    """Return the tokenized text tags NVLM uses to surround image tiles.

    Args:
        args: Parsed arguments; reads ``use_tile_tags``, ``tokenizer_prompt_format``
            and ``max_num_tiles``.
        tokenizer: Tokenizer providing ``tokenize(text)`` and ``_prompt_config.has_bos``.

    Returns:
        None when tile tags are disabled. Otherwise a list with one token sequence per
        tag, ordered ``<tile_1>`` ... ``<tile_N>`` (N = ``args.max_num_tiles``) followed
        by the thumbnail tag.
    """
    if not args.use_tile_tags:
        return None

    # All tags are expected to tokenize to the same length.
    if args.tokenizer_prompt_format == "nvlm-yi-34b":
        thumbnail_tag_text = "<tile_global>"
    else:
        thumbnail_tag_text = "<tile_global_thumbnail>"

    assert args.max_num_tiles <= 6, "Up to 6 tile tags used"
    tag_texts = [f"<tile_{tile_number}>" for tile_number in range(1, args.max_num_tiles + 1)]
    tag_texts.append(thumbnail_tag_text)

    # Drop the first token of each tokenized tag when the prompt config reports a BOS token.
    first_kept = 1 if tokenizer._prompt_config.has_bos else 0

    # Convert to tokens [num_tiles, tile_seq_len].
    tile_tags = []
    for tag_text in tag_texts:
        tile_tags.append(tokenizer.tokenize(tag_text)[first_kept:])

    return tile_tags