Commit 11c49ed2 authored by Harsh Trivedi, committed by GitHub

Load the state dict on CPU to prevent unnecessary GPU memory surge (#20920)

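Load each shard's state dict on the CPU. By default, torch.load restores tensors to the device they were saved from, so a checkpoint saved from a GPU would be loaded back onto the GPU.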
parent 0b686a8a
@@ -382,7 +382,7 @@ def load_sharded_checkpoint(model, folder, strict=True):
         raise RuntimeError(error_message)
 
     for shard_file in shard_files:
-        state_dict = torch.load(os.path.join(folder, shard_file))
+        state_dict = torch.load(os.path.join(folder, shard_file), map_location="cpu")
         model.load_state_dict(state_dict, strict=False)
 
         # Make sure memory is freed before we load the next state dict.
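The effect of the change can be tried in isolation. The following is a minimal sketch, not the library's code: the helper `load_shards_on_cpu`, the `demo` function, and the shard file names are hypothetical and exist only for illustration. It writes a small model's parameters as two shards, then loads them one at a time with `map_location="cpu"`, releasing each shard before reading the next.

```python
import gc
import os
import tempfile

import torch
from torch import nn


def load_shards_on_cpu(model, folder, shard_files):
    """Hypothetical helper: load each shard into `model` via CPU memory."""
    for shard_file in shard_files:
        # map_location="cpu" remaps every saved tensor to CPU memory,
        # whatever device it was saved from.
        state_dict = torch.load(os.path.join(folder, shard_file), map_location="cpu")
        # strict=False because one shard holds only part of the parameters.
        model.load_state_dict(state_dict, strict=False)
        # Release this shard before reading the next one.
        del state_dict
        gc.collect()
    return model


def demo():
    """Save a tiny model as two shards (illustrative names) and reload it."""
    source = nn.Linear(4, 2)
    target = nn.Linear(4, 2)
    with tempfile.TemporaryDirectory() as folder:
        full = source.state_dict()
        shards = {
            "shard-0.bin": {"weight": full["weight"]},
            "shard-1.bin": {"bias": full["bias"]},
        }
        for name, tensors in shards.items():
            torch.save(tensors, os.path.join(folder, name))
        load_shards_on_cpu(target, folder, list(shards))
    return source, target
```

`load_state_dict` copies values into the model's existing parameters, so the model stays on whatever device it already occupies; only the temporary shard copy is kept on the CPU.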