# test_seed_behavior.py
# SPDX-License-Identifier: Apache-2.0
import random

import numpy as np
import torch

from vllm.platforms.interface import Platform


def test_seed_behavior():
    """Check that re-seeding with the same seed reproduces the same draws.

    ``Platform.seed_everything(42)`` is called twice; after each call one
    integer is drawn from ``random``, ``numpy.random`` and ``torch``. The
    test passes when each library yields the same value after both calls.
    """
    # First pass: seed, then draw one value from each RNG.
    Platform.seed_everything(42)
    random_value_1 = random.randint(0, 100)
    np_random_value_1 = np.random.randint(0, 100)
    torch_random_value_1 = torch.randint(0, 100, (1, )).item()

    # Second pass: re-seed with the same value and draw again in the same
    # order, so each draw is directly comparable with its first-pass twin.
    Platform.seed_everything(42)
    random_value_2 = random.randint(0, 100)
    np_random_value_2 = np.random.randint(0, 100)
    torch_random_value_2 = torch.randint(0, 100, (1, )).item()

    # Identical seeds must give identical draws for every library.
    assert random_value_1 == random_value_2
    assert np_random_value_1 == np_random_value_2
    assert torch_random_value_1 == torch_random_value_2