# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pathlib
import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear, GroupedLinear

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig

model_configs = {
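    # Arguments inferred from usage below: batch size, max sequence length, attention heads, head dim.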
    "small": ModelConfig(2, 10, 2, 16),
}


@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize(
    "module", ["TransformerLayer", "DotProductAttention", "Linear", "GroupedLinear"]
)
def test_current_device(model, module):
    """Test cases where current device is different from tensor device"""

    num_devices = torch.cuda.device_count()
    if num_devices < 2:
        pytest.skip("This test requires more than one GPU!")
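    # Place all tensors on the last GPU so they differ from the current device (cuda:0 by default).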
    tensor_device = num_devices - 1
    dtype = torch.bfloat16
    config = model_configs[model]

    args = []
    kwargs = {}
    bwd_args = []
    if module == "TransformerLayer":
        model = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            params_dtype=dtype,
            attn_input_format="thd",
            self_attn_mask_type="padding",
            device=f"cuda:{tensor_device}",
        )
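        # Random per-sequence lengths plus their cumulative offsets, as required
        # by the packed "thd" input format.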
        seqlens_q = torch.randint(
            1,
            config.max_seqlen_q,
            [config.batch_size],
            dtype=torch.int32,
            device=f"cuda:{tensor_device}",
        )
        cu_seqlens_q = torch.zeros(
            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
        )
        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
        seqlens_kv = torch.randint(
            1,
            config.max_seqlen_kv,
            [config.batch_size],
            dtype=torch.int32,
            device=f"cuda:{tensor_device}",
        )
        cu_seqlens_kv = torch.zeros(
            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
        )
        cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
        num_tokens = cu_seqlens_q[-1]
        args = [
            torch.randn(
                (num_tokens, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
    elif module == "DotProductAttention":
        model = DotProductAttention(
            config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
        )
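        # Same "thd" sequence-length setup as in the TransformerLayer branch above.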
        seqlens_q = torch.randint(
            1,
            config.max_seqlen_q,
            [config.batch_size],
            dtype=torch.int32,
            device=f"cuda:{tensor_device}",
        )
        cu_seqlens_q = torch.zeros(
            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
        )
        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
        seqlens_kv = torch.randint(
            1,
            config.max_seqlen_kv,
            [config.batch_size],
            dtype=torch.int32,
            device=f"cuda:{tensor_device}",
        )
        cu_seqlens_kv = torch.zeros(
            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
        )
        cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
        num_tokens = cu_seqlens_q[-1]
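        # Packed query, key, and value tensors, each of shape (total_tokens, num_heads, head_dim).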
        args = [
            torch.randn(
                num_tokens,
                config.num_heads,
                config.head_dim_qk,
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
            for _ in range(3)
        ]
        kwargs["cu_seqlens_q"] = cu_seqlens_q
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
        bwd_args = [
            torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=f"cuda:{tensor_device}")
        ]
    elif module == "Linear":
        model = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            params_dtype=dtype,
            device=f"cuda:{tensor_device}",
        )
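        # A dense (seqlen, batch, hidden) activation tensor.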
        args = [
            torch.randn(
                (config.max_seqlen_q, config.batch_size, config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            )
        ]
    elif module == "GroupedLinear":
        num_gemms = 4
        model = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            params_dtype=dtype,
            device=f"cuda:{tensor_device}",
        )
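        # GroupedLinear takes a flat token tensor and a list of per-GEMM split sizes.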
        args = [
            torch.randn(
                (config.max_seqlen_q * config.batch_size * (num_gemms - 1), config.hidden_size),
                dtype=dtype,
                device=f"cuda:{tensor_device}",
                requires_grad=True,
            ),
            [0] + [config.max_seqlen_q * config.batch_size] * (num_gemms - 1),  # Empty first split.
        ]

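    # Run forward and backward, then verify that the current device is unchanged
    # and that the output and gradient remained on the tensor device.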
    current_device_before = torch.cuda.current_device()
    out = model(*args, **kwargs)
    if module == "DotProductAttention":
        out.backward(*bwd_args)
    else:
        loss = out.sum()
        loss.backward()
    current_device_after = torch.cuda.current_device()
    tensor_device_out = out.get_device()
    tensor_device_grad = args[0].grad.get_device()

    assert (
        current_device_after == current_device_before
    ), "The current device should not have changed!"
    assert (
        tensor_device_out == tensor_device
    ), "The output tensor should be the same as the input tensors!"
    assert (
        tensor_device_grad == tensor_device
    ), "The gradient tensor should be the same as the input tensors!"