import torch
from nanotron.config import (
    AllForwardAllBackwardPipelineEngine,
    CheckpointsArgs,
    Config,
    DataArgs,
    DatasetStageArgs,
    GeneralArgs,
    LlamaConfig,
    LoggingArgs,
    LRSchedulerArgs,
    ModelArgs,
    OptimizerArgs,
    ParallelismArgs,
    TensorParallelLinearMode,
    TokenizerArgs,
    TokensArgs,
)
from nanotron.config.config import PretrainDatasetsArgs
from nanotron.models import build_model
from nanotron.models.llama import LlamaForTraining
from nanotron.parallel.context import ParallelContext
from nanotron.trainer import mark_tied_parameters

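# A deliberately tiny Llama configuration (4 hidden layers, hidden_size=16, vocab_size=4096)
# so unit tests can build and exercise the model quickly on limited hardware.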
TINY_LLAMA_CONFIG = LlamaConfig(
    **{
        "bos_token_id": 1,
        "eos_token_id": 2,
        "hidden_act": "silu",
        "hidden_size": 16,
        "initializer_range": 0.02,
        "intermediate_size": 32,
        "is_llama_config": True,
        "max_position_embeddings": 128,
        "num_attention_heads": 8,
        "num_hidden_layers": 4,
        "num_key_value_heads": 4,
        "pad_token_id": None,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-06,
        "rope_scaling": None,
        "tie_word_embeddings": False,
        "use_cache": True,
        "vocab_size": 4096,
    }
)


def get_llama_training_config(model_config: ModelArgs):
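    """Builds a minimal nanotron training `Config` around the given `ModelArgs` for unit
    tests: dp=1, pp=1, tp=2, 10 training steps on a tiny Hugging Face test dataset."""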
    return Config(
        model=model_config,
        general=GeneralArgs(project="unittest", run="sanity_llama", seed=42),
        checkpoints=CheckpointsArgs(
            checkpoints_path="./checkpoints",
            checkpoint_interval=10,
        ),
        parallelism=ParallelismArgs(
            dp=1,
            pp=1,
            tp=2,
            expert_parallel_size=2,
            pp_engine="1f1b",
            tp_mode="ALL_REDUCE",
            tp_linear_async_communication=False,
        ),
        tokenizer=TokenizerArgs("gpt2"),
        optimizer=OptimizerArgs(
            zero_stage=0,
            weight_decay=0.01,
            clip_grad=1.0,
            accumulate_grad_in_fp32=False,
            adam_eps=1e-08,
            adam_beta1=0.9,
            adam_beta2=0.95,
            torch_adam_is_fused=True,
            learning_rate_scheduler=LRSchedulerArgs(
                learning_rate=3e-4,
                lr_warmup_steps=100,
                lr_warmup_style="linear",
                lr_decay_style="cosine",
                min_decay_lr=1e-5,
            ),
        ),
        logging=LoggingArgs(),
        tokens=TokensArgs(sequence_length=16, train_steps=10, micro_batch_size=16, batch_accumulation_per_replica=1),
        data_stages=[
            DatasetStageArgs(
                name="train",
                start_training_step=1,
                data=DataArgs(
                    seed=42,
                    num_loading_workers=1,
                    dataset=PretrainDatasetsArgs(
                        hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small",
                        hf_dataset_splits="train",
                        text_column_name="completion",
                        dataset_processing_num_proc_per_process=12,
                    ),
                ),
            )
        ],
    )


def create_llama_from_config(
    model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext
) -> LlamaForTraining:
    """
    Creates and returns a nanotron `LlamaForTraining` model built from `model_config`.

    The model is instantiated on `device` in bfloat16 with random (untrained) weights,
    using the parallel sizes taken from `parallel_context`; its tied parameters are
    marked before it is returned.
    """

    parallel_config = ParallelismArgs(
        dp=parallel_context.data_parallel_size,
        pp=parallel_context.pipeline_parallel_size,
        tp=parallel_context.tensor_parallel_size,
        pp_engine=AllForwardAllBackwardPipelineEngine(),
        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
        tp_linear_async_communication=False,
    )
    model = build_model(
        model_builder=lambda: LlamaForTraining(
            config=model_config,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=None,
        ),
        parallel_context=parallel_context,
        dtype=torch.bfloat16,
        device=device,
    )
    mark_tied_parameters(model=model, parallel_context=parallel_context)
    return model
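

# Illustrative usage (a sketch, not part of the original helpers): a test that already has a
# `ParallelContext` from its distributed harness could combine the helpers above like this.
# `RandomInit` and the `ModelArgs` field names are assumptions based on nanotron's config
# module and may differ across versions.
#
#     from nanotron.config import RandomInit
#
#     model_args = ModelArgs(init_method=RandomInit(std=0.02), model_config=TINY_LLAMA_CONFIG)
#     training_config = get_llama_training_config(model_args)
#     model = create_llama_from_config(
#         model_config=TINY_LLAMA_CONFIG,
#         device=torch.device("cuda"),
#         parallel_context=parallel_context,  # provided by the test harness
#     )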