Unverified Commit 83dac8cf authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Add weights_only=False for torch.load (#1374)



add weights_only=False for torch.load
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 7f5c784e
@@ -339,7 +339,7 @@ class TestFloat8Tensor:
         del x_fp8, byte_stream
         # Deserialize tensor
-        x_fp8 = torch.load(io.BytesIO(x_bytes))
+        x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False)
         del x_bytes
         # Check results
@@ -1101,7 +1101,7 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
     del block
     block = get_model(dtype, config)
-    block.load_state_dict(torch.load(path))
+    block.load_state_dict(torch.load(path, weights_only=False))
     torch.set_rng_state(_cpu_rng_state_new)
    torch.cuda.set_rng_state(_cuda_rng_state_new)
@@ -124,7 +124,7 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
     torch.save(model_in.state_dict(), tmp_filename)
     model_out = Test_TE_Export(precision, True)
-    model_out.load_state_dict(torch.load(tmp_filename))
+    model_out.load_state_dict(torch.load(tmp_filename, weights_only=False))
     model_out.eval()
     # scaling fwd
@@ -263,7 +263,7 @@ def test_fp8_model_checkpoint(
     # to load the fp8 metadata before loading tensors.
     #
     # Load checkpoint
-    model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
+    model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False))
     del model_bytes
     # Check that loaded model matches saved model
@@ -450,7 +450,7 @@ def test_sequential_model(
     torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)
     # Load checkpoint
-    model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
+    model.load_state_dict(torch.load(io.BytesIO(model_bytes), weights_only=False))
     del model_bytes
     # Check that new model's FP8 metadata matches saved model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment