Commit 04629132 authored by zhuwenwen's avatar zhuwenwen
Browse files

[tests] fix tests

parent 07c69390
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import itertools import itertools
from functools import partial from functools import partial
import os
import pytest import pytest
from PIL import Image from PIL import Image
from pqdm.threads import pqdm from pqdm.threads import pqdm
...@@ -12,6 +13,7 @@ from vllm.multimodal.parse import ImageSize ...@@ -12,6 +13,7 @@ from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
from ...utils import build_model_context from ...utils import build_model_context
from ....utils import models_path_prefix
def _validate_image_max_tokens_one( def _validate_image_max_tokens_one(
...@@ -33,7 +35,7 @@ def _validate_image_max_tokens_one( ...@@ -33,7 +35,7 @@ def _validate_image_max_tokens_one(
@pytest.mark.skip("This test takes around 5 minutes to run. " @pytest.mark.skip("This test takes around 5 minutes to run. "
"Comment this out to run it manually.") "Comment this out to run it manually.")
@pytest.mark.parametrize("model_id", @pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) [os.path.join(models_path_prefix, "llava-hf/llava-onevision-qwen2-0.5b-ov-hf")])
def test_processor_max_tokens(model_id): def test_processor_max_tokens(model_id):
ctx = build_model_context( ctx = build_model_context(
model_id, model_id,
...@@ -127,7 +129,7 @@ def _test_image_prompt_replacements( ...@@ -127,7 +129,7 @@ def _test_image_prompt_replacements(
@pytest.mark.parametrize("model_id", @pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) [os.path.join(models_path_prefix, "llava-hf/llava-onevision-qwen2-0.5b-ov-hf")])
@pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(model_id, num_imgs): def test_processor_prompt_replacements_regression(model_id, num_imgs):
ctx = build_model_context( ctx = build_model_context(
...@@ -180,4 +182,4 @@ def test_processor_prompt_replacements_all(model_id, num_imgs): ...@@ -180,4 +182,4 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
processor, processor,
num_imgs=num_imgs, num_imgs=num_imgs,
image_sizes=image_sizes, image_sizes=image_sizes,
) )
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Tests for mllama's multimodal preprocessing and profiling.""" """Tests for mllama's multimodal preprocessing and profiling."""
import os
import pytest import pytest
from transformers import MllamaConfig from transformers import MllamaConfig
...@@ -7,10 +8,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -7,10 +8,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from ...utils import build_model_context from ...utils import build_model_context
from ....utils import models_path_prefix
@pytest.mark.parametrize("model_id", @pytest.mark.parametrize("model_id",
["meta-llama/Llama-3.2-11B-Vision-Instruct"]) [os.path.join(models_path_prefix, "meta-llama/Llama-3.2-11B-Vision-Instruct")])
@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072]) @pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
@pytest.mark.parametrize("max_num_seqs", [1, 2, 8]) @pytest.mark.parametrize("max_num_seqs", [1, 2, 8])
def test_profiling( def test_profiling(
...@@ -68,4 +70,4 @@ def test_profiling( ...@@ -68,4 +70,4 @@ def test_profiling(
# simulate mllama image-present prefill. # simulate mllama image-present prefill.
for actual_len, last_group_len in zip(actual_encoder_seq_lens, for actual_len, last_group_len in zip(actual_encoder_seq_lens,
encoder_seq_lens): encoder_seq_lens):
assert actual_len >= last_group_len assert actual_len >= last_group_len
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Tests for phi3v's multimodal preprocessing kwargs.""" """Tests for phi3v's multimodal preprocessing kwargs."""
import os
import pytest import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets from ....conftest import _ImageAssets
from ...utils import build_model_context from ...utils import build_model_context
from ....utils import models_path_prefix
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"]) @pytest.mark.parametrize("model_id", [os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct")])
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"), ("mm_processor_kwargs", "expected_toks_per_img"),
...@@ -50,4 +52,4 @@ def test_processor_override( ...@@ -50,4 +52,4 @@ def test_processor_override(
# Ensure we have the right number of placeholders per num_crops size # Ensure we have the right number of placeholders per num_crops size
img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs assert img_tok_count == expected_toks_per_img * num_imgs
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Tests for phi4mm's multimodal preprocessing kwargs.""" """Tests for phi4mm's multimodal preprocessing kwargs."""
import os
import pytest import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets from ....conftest import _ImageAssets
from ...utils import build_model_context from ...utils import build_model_context
from ....utils import models_path_prefix
@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"]) @pytest.mark.parametrize("model_id", [os.path.join(models_path_prefix, "microsoft/Phi-4-multimodal-instruct")])
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"), ("mm_processor_kwargs", "expected_toks_per_img"),
...@@ -56,4 +58,4 @@ def test_processor_override( ...@@ -56,4 +58,4 @@ def test_processor_override(
# Ensure we have the right number of placeholders per num_crops size # Ensure we have the right number of placeholders per num_crops size
img_tok_count = processed_inputs["prompt_token_ids"].count( img_tok_count = processed_inputs["prompt_token_ids"].count(
_IMAGE_PLACEHOLDER_TOKEN_ID) _IMAGE_PLACEHOLDER_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs assert img_tok_count == expected_toks_per_img * num_imgs
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import pytest import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets from ....conftest import _ImageAssets
from ...utils import build_model_context from ...utils import build_model_context
from ....utils import models_path_prefix
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) @pytest.mark.parametrize("model_id", [os.path.join(models_path_prefix, "Qwen/Qwen2-VL-2B-Instruct")])
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), [ ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), [
...@@ -51,4 +53,4 @@ def test_processor_override( ...@@ -51,4 +53,4 @@ def test_processor_override(
assert img_tok_count == expected_toks_per_img * num_imgs assert img_tok_count == expected_toks_per_img * num_imgs
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
assert pixel_shape[1] == expected_pixels_shape[1] assert pixel_shape[1] == expected_pixels_shape[1]
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Tests for smolvlm's multimodal preprocessing kwargs.""" """Tests for smolvlm's multimodal preprocessing kwargs."""
import os
import pytest import pytest
from transformers import SmolVLMConfig from transformers import SmolVLMConfig
...@@ -7,9 +8,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -7,9 +8,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets from ....conftest import _ImageAssets
from ...utils import build_model_context from ...utils import build_model_context
from ....utils import models_path_prefix
@pytest.mark.parametrize("model_id", ["HuggingFaceTB/SmolVLM2-2.2B-Instruct"]) @pytest.mark.parametrize("model_id", [os.path.join(models_path_prefix, "HuggingFaceTB/SmolVLM2-2.2B-Instruct")])
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"), ("mm_processor_kwargs", "expected_toks_per_img"),
...@@ -62,4 +64,4 @@ def test_processor_override( ...@@ -62,4 +64,4 @@ def test_processor_override(
# Ensure we have the right number of placeholders per num_crops size # Ensure we have the right number of placeholders per num_crops size
image_token_id = ctx.get_hf_config().image_token_id image_token_id = ctx.get_hf_config().image_token_id
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
assert img_tok_count == expected_toks_per_img * num_imgs assert img_tok_count == expected_toks_per_img * num_imgs
\ No newline at end of file
...@@ -8,7 +8,9 @@ import os ...@@ -8,7 +8,9 @@ import os
import pytest import pytest
from packaging.version import Version from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION from transformers import __version__ as TRANSFORMERS_VERSION
# from ..utils import models_path_prefix
models_path_prefix = os.getenv('VLLM_OPTEST_MODELS_PATH') or os.getenv("OPTEST_MODELS_PATH")
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -109,8 +111,6 @@ class _HfExamplesInfo: ...@@ -109,8 +111,6 @@ class _HfExamplesInfo:
pytest.skip(msg) pytest.skip(msg)
models_path_prefix = os.getenv('VLLM_OPTEST_MODELS_PATH') or os.getenv("OPTEST_MODELS_PATH")
# yapf: disable # yapf: disable
_TEXT_GENERATION_EXAMPLE_MODELS = { _TEXT_GENERATION_EXAMPLE_MODELS = {
# [Decoder-only] # [Decoder-only]
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import warnings import warnings
import os
import pytest import pytest
import torch.cuda import torch.cuda
...@@ -20,6 +21,8 @@ from vllm.platforms import current_platform ...@@ -20,6 +21,8 @@ from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test
from .registry import HF_EXAMPLE_MODELS from .registry import HF_EXAMPLE_MODELS
models_path_prefix = os.getenv('VLLM_OPTEST_MODELS_PATH') or os.getenv("OPTEST_MODELS_PATH")
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch): def test_registry_imports(model_arch):
...@@ -52,12 +55,12 @@ def test_registry_imports(model_arch): ...@@ -52,12 +55,12 @@ def test_registry_imports(model_arch):
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [ @pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [
("LlamaForCausalLM", False, False, False), (os.path.join(models_path_prefix, "LlamaForCausalLM"), False, False, False),
("MllamaForConditionalGeneration", True, False, False), (os.path.join(models_path_prefix, "MllamaForConditionalGeneration"), True, False, False),
("LlavaForConditionalGeneration", True, True, False), (os.path.join(models_path_prefix, "LlavaForConditionalGeneration"), True, True, False),
("BertForSequenceClassification", False, False, True), (os.path.join(models_path_prefix, "BertForSequenceClassification"), False, False, True),
("RobertaForSequenceClassification", False, False, True), (os.path.join(models_path_prefix, "RobertaForSequenceClassification"), False, False, True),
("XLMRobertaForSequenceClassification", False, False, True), (os.path.join(models_path_prefix, "XLMRobertaForSequenceClassification"), False, False, True),
]) ])
def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm assert ModelRegistry.is_multimodal_model(model_arch) is is_mm
...@@ -77,9 +80,9 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): ...@@ -77,9 +80,9 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [ @pytest.mark.parametrize("model_arch,is_pp,init_cuda", [
("MLPSpeculatorPreTrainedModel", False, False), (os.path.join(models_path_prefix, "MLPSpeculatorPreTrainedModel"), False, False),
("DeepseekV2ForCausalLM", True, False), (os.path.join(models_path_prefix, "DeepseekV2ForCausalLM"), True, False),
("Qwen2VLForConditionalGeneration", True, True), (os.path.join(models_path_prefix, "Qwen2VLForConditionalGeneration"), True, True),
]) ])
def test_registry_is_pp(model_arch, is_pp, init_cuda): def test_registry_is_pp(model_arch, is_pp, init_cuda):
assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp
...@@ -104,4 +107,4 @@ def test_hf_registry_coverage(): ...@@ -104,4 +107,4 @@ def test_hf_registry_coverage():
assert not untested_archs, ( assert not untested_archs, (
"Please add the following architectures to " "Please add the following architectures to "
f"`tests/models/registry.py`: {untested_archs}") f"`tests/models/registry.py`: {untested_archs}")
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Expanded quantized model tests for CPU offloading # Expanded quantized model tests for CPU offloading
# Base tests: tests/basic_correctness/test_cpu_offload.py # Base tests: tests/basic_correctness/test_cpu_offload.py
import pytest import pytest
import os import os
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from ..utils import compare_two_settings, models_path_prefix from ..utils import compare_two_settings, models_path_prefix
from vllm.platforms import current_platform from vllm.platforms import current_platform
@pytest.mark.skipif(not is_quant_method_supported("fp8") or current_platform.is_rocm(), @pytest.mark.skipif(not is_quant_method_supported("fp8") or current_platform.is_rocm(),
reason="fp8 is not supported on this GPU type.") reason="fp8 is not supported on this GPU type.")
def test_cpu_offload_fp8(): def test_cpu_offload_fp8():
# Test quantization of an unquantized checkpoint # Test quantization of an unquantized checkpoint
compare_two_settings(os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"), compare_two_settings(os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"),
["--quantization", "fp8"], ["--quantization", "fp8"],
["--quantization", "fp8", "--cpu-offload-gb", "1"], ["--quantization", "fp8", "--cpu-offload-gb", "1"],
max_wait_seconds=480) max_wait_seconds=480)
# Test loading a quantized checkpoint # Test loading a quantized checkpoint
# compare_two_settings(os.path.join(models_path_prefix, "neuralmagic/Qwen2-1.5B-Instruct-FP8"), [], # compare_two_settings(os.path.join(models_path_prefix, "neuralmagic/Qwen2-1.5B-Instruct-FP8"), [],
# ["--cpu-offload-gb", "1"], # ["--cpu-offload-gb", "1"],
# max_wait_seconds=480) # max_wait_seconds=480)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin") or current_platform.is_rocm(), @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin") or current_platform.is_rocm(),
reason="gptq_marlin is not supported on this GPU type.") reason="gptq_marlin is not supported on this GPU type.")
def test_cpu_offload_gptq(monkeypatch): def test_cpu_offload_gptq(monkeypatch):
# This quant method is sensitive to dummy weights, so we force real weights # This quant method is sensitive to dummy weights, so we force real weights
monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto') monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto')
# Test GPTQ Marlin # Test GPTQ Marlin
compare_two_settings(os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4"), [], compare_two_settings(os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4"), [],
["--cpu-offload-gb", "1"], ["--cpu-offload-gb", "1"],
max_wait_seconds=480) max_wait_seconds=480)
# Test GPTQ # Test GPTQ
compare_two_settings(os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4"), compare_two_settings(os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4"),
["--quantization", "gptq"], ["--quantization", "gptq"],
["--quantization", "gptq", "--cpu-offload-gb", "1"], ["--quantization", "gptq", "--cpu-offload-gb", "1"],
max_wait_seconds=480) max_wait_seconds=480)
@pytest.mark.skipif(not is_quant_method_supported("awq_marlin") or current_platform.is_rocm(), @pytest.mark.skipif(not is_quant_method_supported("awq_marlin") or current_platform.is_rocm(),
reason="awq_marlin is not supported on this GPU type.") reason="awq_marlin is not supported on this GPU type.")
def test_cpu_offload_awq(monkeypatch): def test_cpu_offload_awq(monkeypatch):
# This quant method is sensitive to dummy weights, so we force real weights # This quant method is sensitive to dummy weights, so we force real weights
monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto') monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto')
# Test AWQ Marlin # Test AWQ Marlin
compare_two_settings(os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct-AWQ"), [], compare_two_settings(os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct-AWQ"), [],
["--cpu-offload-gb", "1"], ["--cpu-offload-gb", "1"],
max_wait_seconds=480) max_wait_seconds=480)
# Test AWQ # Test AWQ
compare_two_settings(os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct-AWQ"), compare_two_settings(os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct-AWQ"),
["--quantization", "awq"], ["--quantization", "awq"],
["--quantization", "awq", "--cpu-offload-gb", "1"], ["--quantization", "awq", "--cpu-offload-gb", "1"],
max_wait_seconds=480) max_wait_seconds=480)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin") or current_platform.is_rocm(), @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin") or current_platform.is_rocm(),
reason="gptq_marlin is not supported on this GPU type.") reason="gptq_marlin is not supported on this GPU type.")
def test_cpu_offload_compressed_tensors(monkeypatch): def test_cpu_offload_compressed_tensors(monkeypatch):
# This quant method is sensitive to dummy weights, so we force real weights # This quant method is sensitive to dummy weights, so we force real weights
monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto') monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto')
# Test wNa16 # Test wNa16
compare_two_settings(os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w4a16-channel-v2"), [], compare_two_settings(os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w4a16-channel-v2"), [],
["--cpu-offload-gb", "1"], ["--cpu-offload-gb", "1"],
max_wait_seconds=480) max_wait_seconds=480)
# Test w4a16_marlin24 # Test w4a16_marlin24
compare_two_settings(os.path.join(models_path_prefix, "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"), compare_two_settings(os.path.join(models_path_prefix, "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"),
[], ["--cpu-offload-gb", "1"], [], ["--cpu-offload-gb", "1"],
max_wait_seconds=480) max_wait_seconds=480)
# Test w8a8 # Test w8a8
compare_two_settings( compare_two_settings(
os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"), [], os.path.join(models_path_prefix, "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"), [],
["--cpu-offload-gb", "1"], ["--cpu-offload-gb", "1"],
max_wait_seconds=480) max_wait_seconds=480)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import vllm import vllm
def test_embedded_commit_defined(): def test_embedded_commit_defined():
assert hasattr(vllm, "__version__") assert hasattr(vllm, "__version__")
assert hasattr(vllm, "__version_tuple__") assert hasattr(vllm, "__version_tuple__")
assert vllm.__version__ != "dev" assert vllm.__version__ != "dev"
assert vllm.__version_tuple__ != (0, 0, "dev") assert vllm.__version_tuple__ != (0, 0, "dev")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
import numpy as np import numpy as np
import torch import torch
from vllm.platforms.interface import Platform from vllm.platforms.interface import Platform
def test_seed_behavior(): def test_seed_behavior():
# Test with a specific seed # Test with a specific seed
Platform.seed_everything(42) Platform.seed_everything(42)
random_value_1 = random.randint(0, 100) random_value_1 = random.randint(0, 100)
np_random_value_1 = np.random.randint(0, 100) np_random_value_1 = np.random.randint(0, 100)
torch_random_value_1 = torch.randint(0, 100, (1, )).item() torch_random_value_1 = torch.randint(0, 100, (1, )).item()
Platform.seed_everything(42) Platform.seed_everything(42)
random_value_2 = random.randint(0, 100) random_value_2 = random.randint(0, 100)
np_random_value_2 = np.random.randint(0, 100) np_random_value_2 = np.random.randint(0, 100)
torch_random_value_2 = torch.randint(0, 100, (1, )).item() torch_random_value_2 = torch.randint(0, 100, (1, )).item()
assert random_value_1 == random_value_2 assert random_value_1 == random_value_2
assert np_random_value_1 == np_random_value_2 assert np_random_value_1 == np_random_value_2
assert torch_random_value_1 == torch_random_value_2 assert torch_random_value_1 == torch_random_value_2
File mode changed from 100755 to 100644
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment