config.py
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
"""Per-model-size configuration overrides for the language model, vision model, and vision projection."""
import torch

from megatron.training.activations import quick_gelu, squared_relu


def get_language_model_config(config):
    """Apply language-model overrides to `config` in place, keyed on `config.language_model_type`."""
    if config.language_model_type == "2b":
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = True
        config.layernorm_zero_centered_gamma = True
        config.bias_dropout_fusion = False
        config.rotary_percent = 0.5
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
    elif config.language_model_type == "8b":
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = False
        config.apply_query_key_layer_scaling = True
        config.layernorm_zero_centered_gamma = True
        config.bias_dropout_fusion = False
        config.rotary_percent = 0.5
        config.attention_dropout = 0.0
        config.apply_rope_fusion = False
        config.activation_func = squared_relu
        config.ffn_hidden_size = 16384
        config.masked_softmax_fusion = True
        config.attention_softmax_in_fp32 = True
        config.num_query_groups = 32
        config.kv_channels = 128
        config.rotary_interleaved = False
    elif config.language_model_type == "llama3_8b":
        config.activation_func = torch.nn.functional.silu
        config.add_bias_linear = False
        config.bias_activation_fusion = False
        config.gated_linear_unit = True
        config.apply_query_key_layer_scaling = True
        config.layernorm_zero_centered_gamma = (
            False  # Zero centered gamma not supported for RMSNorm
        )
        config.bias_dropout_fusion = False
        config.te_attn_mask_type = None
        config.rotary_percent = 0.5
        config.apply_rope_fusion = False
        config.attention_softmax_in_fp32 = True
        config.ffn_hidden_size = 14336

    return config


def get_vision_model_config(config, apply_query_key_layer_scaling=False):
    """Apply vision-encoder overrides (24 layers, hidden size 1024, quick-GELU activation) to `config` in place."""
    config.num_layers = 24
    config.num_attention_heads = 16
    config.add_bias_linear = True
    config.add_qkv_bias = True
    config.hidden_size = 1024
    config.hidden_dropout = 0.0
    config.attention_dropout = 0.0
    config.ffn_hidden_size = 4096
    config.gated_linear_unit = False
    config.activation_func = quick_gelu
    config.kv_channels = 64
    config.num_query_groups = 16
    config.layernorm_zero_centered_gamma = False
    config.apply_query_key_layer_scaling = apply_query_key_layer_scaling
    config.bias_activation_fusion = False
    config.bias_dropout_fusion = False
    config.attention_softmax_in_fp32 = True

    return config


def get_vision_projection_config(config, hidden_size):
    """Apply vision-projection overrides to `config` in place; `hidden_size` sets the projection output size."""
    config.gated_linear_unit = False
    config.bias_activation_fusion = False
    config.add_bias_linear = False
    config.hidden_size = hidden_size
    if config.language_model_type == "2b":
        config.ffn_hidden_size = 5440
        config.activation_func = torch.nn.functional.gelu
    elif config.language_model_type == "8b":
        config.ffn_hidden_size = 16384
        config.activation_func = squared_relu
    elif config.language_model_type == "llama3_8b":
        config.ffn_hidden_size = 14336
        config.activation_func = torch.nn.functional.silu

    return config
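

# Usage sketch (illustrative, not part of the original file): each helper above
# mutates its argument in place, so the language, vision, and projection
# sub-models are typically configured from separate copies of a base config.
# The SimpleNamespace below is a stand-in for the TransformerConfig that
# Megatron normally builds from command-line arguments.
if __name__ == "__main__":
    import copy
    from types import SimpleNamespace

    base = SimpleNamespace(language_model_type="8b")
    language_config = get_language_model_config(copy.deepcopy(base))
    vision_config = get_vision_model_config(copy.deepcopy(base))
    # 4096 is a stand-in for the language model's hidden size.
    projection_config = get_vision_projection_config(copy.deepcopy(base), hidden_size=4096)
    print(language_config.ffn_hidden_size, vision_config.hidden_size, projection_config.hidden_size)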