# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pathlib
import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch import (
    DotProductAttention,
    TransformerLayer,
    Linear,
    GroupedLinear,
    NVFP4Quantizer,
    autocast,
    is_nvfp4_available,
)
from transformer_engine.common import recipe

# Append the grandparent directory of this file to the import path so that the
# `utils` module imported below can be resolved (presumably it lives there —
# confirm against the repository layout).
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig

# Named model configurations that tests select via their "model" parameter.
# NOTE(review): the positional arguments are presumably
# (batch_size, max_seqlen_q, num_heads, head_dim_qk) — confirm against
# utils.ModelConfig.
model_configs = {
    "small": ModelConfig(2, 10, 2, 16),
}

# Availability flag for NVFP4 and the accompanying reason string; both are
# consumed by the `skipif` marker on the NVFP4 test in this file.
nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)


def _random_cu_seqlens(batch_size, max_seqlen, device):
    """Build random cumulative sequence lengths on ``device``.

    Draws ``batch_size`` lengths from ``[1, max_seqlen)`` and returns an int32
    tensor with ``batch_size + 1`` entries: a leading zero followed by the
    running sum of the drawn lengths.
    """
    seqlens = torch.randint(1, max_seqlen, [batch_size], dtype=torch.int32, device=device)
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    # cumsum promotes to int64; the slice assignment casts back to int32.
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    return cu_seqlens


def _thd_attention_kwargs(config, device):
    """Return ``(num_tokens, kwargs)`` for a packed ("thd") attention call.

    ``kwargs`` carries the cumulative sequence lengths for queries and
    keys/values plus the configured maximum sequence lengths. ``num_tokens``
    is the last entry of the query cumulative lengths (a 0-dim tensor).
    """
    # Query lengths are drawn before key/value lengths so the random stream is
    # consumed in a fixed order.
    cu_seqlens_q = _random_cu_seqlens(config.batch_size, config.max_seqlen_q, device)
    cu_seqlens_kv = _random_cu_seqlens(config.batch_size, config.max_seqlen_kv, device)
    kwargs = {
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_kv": cu_seqlens_kv,
        "max_seqlen_q": config.max_seqlen_q,
        "max_seqlen_kv": config.max_seqlen_kv,
    }
    return cu_seqlens_q[-1], kwargs


@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize(
    "module", ["TransformerLayer", "DotProductAttention", "Linear", "GroupedLinear"]
)
def test_current_device(model, module):
    """Test cases where current device is different from tensor device.

    Parameters and inputs are placed on the last visible GPU while the current
    CUDA device is left untouched. After a forward and backward pass the test
    checks that the current device did not change and that both the output and
    the input gradient live on the tensor device.

    Args:
        model: Key into ``model_configs`` selecting the configuration.
        module: Name of the module under test.

    Raises:
        ValueError: If ``module`` is not one of the supported names.
    """

    num_devices = torch.cuda.device_count()
    assert num_devices > 1, "This test requires more than one GPU!"
    tensor_device = num_devices - 1
    device = f"cuda:{tensor_device}"
    dtype = torch.bfloat16
    config = model_configs[model]

    kwargs = {}
    bwd_args = []
    if module == "TransformerLayer":
        layer = TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            params_dtype=dtype,
            attn_input_format="thd",
            self_attn_mask_type="padding",
            device=device,
        )
        # NOTE(review): the input is sized by the query token count while
        # cu_seqlens_kv is drawn independently — confirm this is intended.
        num_tokens, kwargs = _thd_attention_kwargs(config, device)
        args = [
            torch.randn(
                (num_tokens, config.hidden_size),
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
        ]
    elif module == "DotProductAttention":
        layer = DotProductAttention(
            config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
        )
        # NOTE(review): q, k and v are all sized by the query token count while
        # cu_seqlens_kv is drawn independently — confirm this is intended.
        num_tokens, kwargs = _thd_attention_kwargs(config, device)
        # Three identically shaped tensors: query, key and value.
        args = [
            torch.randn(
                num_tokens,
                config.num_heads,
                config.head_dim_qk,
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
            for _ in range(3)
        ]
        # Explicit upstream gradient used for the backward pass of this module.
        bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=device)]
    elif module == "Linear":
        layer = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            params_dtype=dtype,
            device=device,
        )
        args = [
            torch.randn(
                (config.max_seqlen_q, config.batch_size, config.hidden_size),
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
        ]
    elif module == "GroupedLinear":
        num_gemms = 4
        rows_per_split = config.max_seqlen_q * config.batch_size
        layer = GroupedLinear(
            num_gemms,
            config.hidden_size,
            4 * config.hidden_size,
            params_dtype=dtype,
            device=device,
        )
        args = [
            torch.randn(
                (rows_per_split * (num_gemms - 1), config.hidden_size),
                dtype=dtype,
                device=device,
                requires_grad=True,
            ),
            [0] + [rows_per_split] * (num_gemms - 1),  # Empty first split.
        ]
    else:
        raise ValueError(f"Unsupported module: {module}")

    current_device_before = torch.cuda.current_device()
    out = layer(*args, **kwargs)
    if module == "DotProductAttention":
        out.backward(*bwd_args)
    else:
        loss = out.sum()
        loss.backward()
    current_device_after = torch.cuda.current_device()
    tensor_device_out = out.get_device()
    tensor_device_grad = args[0].grad.get_device()

    assert (
        current_device_after == current_device_before
    ), "The current device should not have changed!"
    assert (
        tensor_device_out == tensor_device
    ), "The output tensor should be the same as the input tensors!"
    assert (
        tensor_device_grad == tensor_device
    ), "The gradient tensor should be the same as the input tensors!"


@pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4)
def test_nvfp4_rht_cache():
    """Ensure correct RHT cache for NVFP4.

    An ``NVFP4Quantizer`` is first constructed while the last GPU is the
    current device; afterwards a ``Linear`` forward pass is run under an NVFP4
    block-scaling recipe with the input on the (restored) current device. The
    test passes if that forward pass completes without error.
    """

    device_count = torch.cuda.device_count()
    assert device_count > 1, "This test requires more than one GPU!"

    # Populate cache on last device.
    last_device = device_count - 1
    with torch.cuda.device(last_device):
        warmup_quantizer = NVFP4Quantizer()

    size = 128
    param_dtype = torch.bfloat16

    linear = Linear(size, size, params_dtype=param_dtype)
    x = torch.randn(size, size, device=torch.cuda.current_device(), dtype=param_dtype)
    nvfp4_recipe = recipe.NVFP4BlockScaling()
    with autocast(recipe=nvfp4_recipe):
        linear(x)