Unverified commit fc185200 authored by Jan Bielak, committed by GitHub
Browse files

Use public API instead of removed private function in `te_llama.py` (#1856)

Use public API instead of removed private function
* replaced use of _load_state_dict_into_model with model.load_state_dict because the private function _load_state_dict_into_model was removed in https://github.com/huggingface/transformers/pull/36335

Signed-off-by: Jan Bielak <jbielak@nvidia.com>
parent f519e6e0
......@@ -19,7 +19,7 @@ from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaConfig,
)
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
from transformers.modeling_utils import _add_variant, load_state_dict
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files
......@@ -148,8 +148,8 @@ class TELlamaForCausalLM:
state_dict = load_state_dict(shard_file)
# replace_params copies parameters relevant only to TransformerEngine
replace_params(state_dict, vanilla_model.state_dict(), config)
# _load_state_dict_into_model copies parameters other than those in TransformerEngine
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
# load_state_dict copies parameters other than those in TransformerEngine
vanilla_model.load_state_dict(state_dict, strict=False)
# Force mem release. Taken from huggingface code
del state_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment