"examples/offline_inference/basic.py" did not exist on "0b98ba15c744f1dfb0ea4f2135e85ca23d572ae1"
Unverified Commit 5d7e3d01 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[mis][ci/test] fix flaky test in test_sharded_state_loader.py (#5361)

[mis][ci/test] fix flaky test in tests/test_sharded_state_loader.py (#5361)
parent 0373e183
...@@ -39,7 +39,8 @@ def test_filter_subtensors(): ...@@ -39,7 +39,8 @@ def test_filter_subtensors():
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
for key, tensor in filtered_state_dict.items(): for key, tensor in filtered_state_dict.items():
assert tensor.equal(state_dict[key]) # NOTE: don't use `euqal` here, as the tensor might contain NaNs
assert tensor is state_dict[key]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment