# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.modelopt_support.gpt.model_specs import get_gpt_layer_modelopt_spec
from megatron.core.inference.modelopt_support.gpt.state_dict_hooks import (
    mcore_gpt_load_te_state_dict_pre_hook,
)
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class TestModelOptGPTModel:
    """Smoke tests: a GPTModel built with the ModelOpt layer spec can load
    weights produced by a Transformer-Engine-spec GPTModel."""

    def setup_method(self, method):
        """Build a TE-spec and a ModelOpt-spec GPTModel on a 1x1 parallel grid."""
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        config = TransformerConfig(
            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
        )
        # Both models share everything except the transformer layer spec.
        common_kwargs = dict(config=config, vocab_size=100, max_sequence_length=4)
        self.gpt_model = GPTModel(
            transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(),
            **common_kwargs,
        )
        # Ensure that a GPTModel can be built with the modelopt spec.
        self.modelopt_gpt_model = GPTModel(
            transformer_layer_spec=get_gpt_layer_modelopt_spec(),
            **common_kwargs,
        )

    def test_load_te_state_dict_pre_hook(self):
        """TE-format weights load into the ModelOpt model via the pre-hook."""
        hook_handle = self.modelopt_gpt_model._register_load_state_dict_pre_hook(
            mcore_gpt_load_te_state_dict_pre_hook
        )
        self.modelopt_gpt_model.load_state_dict(self.gpt_model.state_dict())
        hook_handle.remove()

    def teardown_method(self, method):
        """Release model-parallel state between tests."""
        Utils.destroy_model_parallel()