"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8e777b3ba4f5f06b826f172c522ddd82e293f405"
Unverified Commit 11c49ed2 authored by Harsh Trivedi's avatar Harsh Trivedi Committed by GitHub
Browse files

Load the state dict on CPU to prevent unnecessary GPU memory surge (#20920)

load the state dict on cpu.
parent 0b686a8a
...@@ -382,7 +382,7 @@ def load_sharded_checkpoint(model, folder, strict=True): ...@@ -382,7 +382,7 @@ def load_sharded_checkpoint(model, folder, strict=True):
raise RuntimeError(error_message) raise RuntimeError(error_message)
for shard_file in shard_files: for shard_file in shard_files:
state_dict = torch.load(os.path.join(folder, shard_file)) state_dict = torch.load(os.path.join(folder, shard_file), map_location="cpu")
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
# Make sure memory is fred before we load the next state dict. # Make sure memory is fred before we load the next state dict.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment