# config.py

from dataclasses import dataclass
from typing import Optional, Union

import torch
from nanotron.config import Config, ExistingCheckpointInit, NanotronConfigs
from nanotron.config.utils_config import cast_str_to_torch_dtype


@dataclass
class MambaInit:
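    """Weight-initialization settings for the Mamba model.

    Presumably consumed by the model's weight init (Mamba/GPT-2-style scheme):
    `initializer_range` as the std of the normal init, and
    `rescale_prenorm_residual` to rescale residual output projections by
    1 / sqrt(n_residuals_per_layer * num_hidden_layers).
    """
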
    initializer_range: float = 0.02
    rescale_prenorm_residual: bool = True
    n_residuals_per_layer: int = 1  # Change to 2 if the block also contains an MLP


@dataclass
class ModelArgs:
    """Arguments related to model architecture"""

    model_config: NanotronConfigs
    init_method: Union[MambaInit, ExistingCheckpointInit]
    dtype: Optional[torch.dtype] = None
    make_vocab_size_divisible_by: int = 1
    ddp_bucket_cap_mb: int = 25

    def __post_init__(self):
        if self.dtype is None:
            self.dtype = torch.bfloat16
        if isinstance(self.dtype, str):
            self.dtype = cast_str_to_torch_dtype(self.dtype)

        # if self.model_config.max_position_embeddings is None:
        #     self.model_config.max_position_embeddings = 0


@dataclass(kw_only=True)  # pylint: disable=unexpected-keyword-arg
class MambaConfig(Config):
    """Main configuration class"""

    model: ModelArgs


@dataclass
class MambaModelConfig:
    """Configuration for a Mamba model

    Be careful to keep the typing coherent, as it is used to reconstruct the model from YAML.
    """

    is_mamba_config: bool = True  # We use this to help differentiate models during YAML/Python conversion
    d_model: int = 2560
    num_hidden_layers: int = 64
    vocab_size: int = 50277
    ssm_cfg: Optional[dict] = None
    rms_norm: bool = True
    fused_add_norm: bool = True
    residual_in_fp32: bool = True
    pad_vocab_size_multiple: int = 8
    # ===== Custom =====
    dtype: str = "float32"
    rms_norm_eps: float = 1e-5
    pad_token_id: Optional[int] = None
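

# A minimal usage sketch (assumed, not part of the upstream file): it shows how the
# dataclasses above fit together when built directly in Python. In a real run the
# full MambaConfig is normally loaded from a YAML file and also carries the fields
# inherited from nanotron's Config (e.g. parallelism, tokens, optimizer).
if __name__ == "__main__":
    # Architecture hyperparameters; anything left out keeps the defaults above.
    model_config = MambaModelConfig(d_model=1024, num_hidden_layers=12, vocab_size=50277)

    # ModelArgs.__post_init__ casts a string dtype to a torch.dtype and falls back
    # to bfloat16 when dtype is left as None.
    model_args = ModelArgs(
        model_config=model_config,
        init_method=MambaInit(initializer_range=0.02),
        dtype="bfloat16",
    )
    assert model_args.dtype == torch.bfloat16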