# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import pytest
import torch

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class TestGPTModel:

    def setup_method(self, method):
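        # Each test runs with a 1-way tensor / 1-way pipeline model-parallel group
        # and a fixed model-parallel RNG seed so the model initialization is deterministic.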
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
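        # Deliberately tiny model (2 layers, hidden size 12, 4 heads, CPU-initialized
        # weights) so construction and forward passes stay cheap in unit tests.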
        transformer_config = TransformerConfig(
            num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
        )
        self.gpt_model = GPTModel(
            config=transformer_config,
            transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(),
            vocab_size=100,
            max_sequence_length=4,
        )

    def teardown_method(self, method):
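        # Tear down the model-parallel state so it does not leak into other tests.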
        Utils.destroy_model_parallel()

    def test_constructor(self):
        assert isinstance(self.gpt_model, GPTModel)

        assert self.gpt_model.max_sequence_length == 4

        num_weights = sum(p.numel() for p in self.gpt_model.parameters())
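        # 6240 is the total for this config; roughly: word embeddings (100*12) plus
        # position embeddings (4*12), two transformer layers, a final layernorm, and
        # an untied output projection (100*12). The exact split depends on the layer spec.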
        assert num_weights == 6240

    def test_set_input_tensor(self):
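        # set_input_tensor() is how pipeline parallelism hands the previous stage's
        # activation to this stage; the decoder should simply hold on to the tensor.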
        config: TransformerConfig = self.gpt_model.config
        sequence_length = self.gpt_model.max_sequence_length
        micro_batch_size = 2

        # [sequence length, batch size, hidden size]
        input_tensor = torch.ones((sequence_length, micro_batch_size, config.hidden_size))

        self.gpt_model.set_input_tensor(input_tensor)

        assert self.gpt_model.decoder.input_tensor.shape[0] == sequence_length
        assert self.gpt_model.decoder.input_tensor.shape[1] == micro_batch_size
        assert self.gpt_model.decoder.input_tensor.shape[2] == config.hidden_size

    def test_post_process_forward(self):
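        # Forward pass on GPU with toy inputs; with post-processing the model is
        # expected to return logits of shape [batch, sequence length, vocab size].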
        config: TransformerConfig = self.gpt_model.config
        sequence_length = self.gpt_model.max_sequence_length
        micro_batch_size = 2

        self.gpt_model.cuda()

        data = list(range(sequence_length))
        input_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
        position_ids = torch.tensor(data, dtype=torch.int64).repeat((micro_batch_size, 1)).cuda()
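        # All-True boolean mask of shape [batch, 1, seq, seq]; this test only checks
        # output shapes, so the mask contents are not important here.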
        attention_mask = torch.ones(
            (micro_batch_size, 1, sequence_length, sequence_length), dtype=torch.bool
        ).cuda()

        logits = self.gpt_model.forward(
            input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask
        )

        assert logits.shape[0] == micro_batch_size
        assert logits.shape[1] == sequence_length
        assert logits.shape[2] == self.gpt_model.vocab_size

    def test_no_post_process_forward(self):
        pass

    def test_no_preprocess_forward(self):
        pass

    def test_state_dict_for_save_checkpoint(self):
        pass

    def test_load_state_dict(self):
        pass