Unverified commit afa1ef09, authored by Stas Bekman and committed by GitHub
Browse files

[modeling_utils] use less cpu memory with sharded checkpoint loading (#16844)

* less cpu memory with sharded checkpoint loading

* Trigger CI

* Trigger CI
parent e13a91fe
......@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import re
......@@ -2149,6 +2150,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
# force memory release
del state_dict
gc.collect()
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment