"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "caf1d116a62a324a2b0ccfd92ca6c095d5368dde"
Unverified Commit afa1ef09 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[modeling_utils] use less cpu memory with sharded checkpoint loading (#16844)

* less cpu memory with sharded checkpoint loading

* Trigger CI

* Trigger CI
parent e13a91fe
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import json import json
import os import os
import re import re
...@@ -2149,6 +2150,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2149,6 +2150,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else: else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
# force memory release
del state_dict
gc.collect()
if len(error_msgs) > 0: if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs) error_msg = "\n\t".join(error_msgs)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment