Unverified Commit fef3a6b6 authored by Ben Barsdell's avatar Ben Barsdell Committed by GitHub
Browse files

Restore torch defaults between sgl-kernel tests (#11131)

parent 173e0f70
import pytest
import torch
# This ensures the torch defaults don't get left in modified states between
# tests (e.g., when a test fails before restoring the original value), which
# can cause subsequent tests to fail.
@pytest.fixture(autouse=True)
def reset_torch_defaults():
orig_default_device = torch.get_default_device()
orig_default_dtype = torch.get_default_dtype()
yield
torch.set_default_dtype(orig_default_dtype)
torch.set_default_device(orig_default_device)
import pytest
import torch
def test_change_torch_defaults():
torch.set_default_device("cpu:0")
torch.set_default_dtype(torch.float16)
def test_check_torch_defaults():
assert torch.get_default_device() == torch.device("cpu")
assert torch.get_default_dtype() == torch.float32
if __name__ == "__main__":
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment