test_base_model.py
import pytest
import torch
import torch.distributed as dist
from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config
from helpers.utils import init_distributed, rerun_if_address_is_in_use
from nanotron.config import Config, ModelArgs, RandomInit
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from torch import nn


@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 2)])
@pytest.mark.skip
@rerun_if_address_is_in_use()
def test_get_named_modules_in_pp_rank(tp: int, dp: int, pp: int):
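    """Check that ``named_modules_in_pp_rank`` only exposes modules owned by the current pipeline-parallel rank."""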
    model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG)
    config = get_llama_training_config(model_args)

    init_distributed(tp=tp, dp=dp, pp=pp)(_test_get_named_modules_in_pp_rank)(config=config)


def _test_get_named_modules_in_pp_rank(
    parallel_context: ParallelContext,
    config: Config,
):
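    """Body of the test above, run inside each process spawned by ``init_distributed``."""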
    model = create_llama_from_config(
        model_config=config.model.model_config,
        device=torch.device("cuda"),
        parallel_context=parallel_context,
    )
    model.init_model_randomly(config=config)

    # Collect the PipelineBlock modules that are assigned to other pipeline-parallel ranks.
    modules_not_in_current_pp_rank = {}
    current_pp_rank = dist.get_rank(group=parallel_context.pp_pg)
    for name, module in model.named_modules():
        if isinstance(module, PipelineBlock) and module.rank != current_pp_rank:
            modules_not_in_current_pp_rank[name] = module

    named_modules_in_pp_rank = model.named_modules_in_pp_rank

    for name, module in named_modules_in_pp_rank.items():
        # NOTE: a module owned by the current pipeline-parallel rank should be an
        # initialized nn.Module, not a PipelineBlock assigned to another rank
        assert isinstance(module, nn.Module)
        assert name not in modules_not_in_current_pp_rank