"web/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "a30526aec1f7c6fb563b5c49503e9d18f6fa73d2"
Commit 868de8d1 authored by thomwolf's avatar thomwolf
Browse files

updating weights loading

parent 64e0adda
...@@ -653,8 +653,13 @@ class BertPreTrainedModel(nn.Module): ...@@ -653,8 +653,13 @@ class BertPreTrainedModel(nn.Module):
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else: else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) if from_tf:
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) # Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, BERT_CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
...@@ -708,24 +713,24 @@ class BertPreTrainedModel(nn.Module): ...@@ -708,24 +713,24 @@ class BertPreTrainedModel(nn.Module):
# with tarfile.open(resolved_archive_file, 'r:gz') as archive: # with tarfile.open(resolved_archive_file, 'r:gz') as archive:
# archive.extractall(tempdir) # archive.extractall(tempdir)
# serialization_dir = tempdir # serialization_dir = tempdir
# config_file = os.path.join(serialization_dir, CONFIG_NAME)
# if not os.path.exists(config_file):
# # Backward compatibility with old naming format
# config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
# Load config # Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME) config = BertConfig.from_json_file(resolved_config_file)
if not os.path.exists(config_file):
# Backward compatibility with old naming format
config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
logger.info("Model config {}".format(config)) logger.info("Model config {}".format(config))
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf: if state_dict is None and not from_tf:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) # weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path, map_location='cpu') state_dict = torch.load(resolved_archive_file, map_location='cpu')
# if tempdir: # if tempdir:
# # Clean up temp dir # # Clean up temp dir
# shutil.rmtree(tempdir) # shutil.rmtree(tempdir)
if from_tf: if from_tf:
# Directly load from a TensorFlow checkpoint # Directly load from a TensorFlow checkpoint
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) # weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
return load_tf_weights_in_bert(model, weights_path) return load_tf_weights_in_bert(model, weights_path)
# Load from a PyTorch state_dict # Load from a PyTorch state_dict
old_keys = [] old_keys = []
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment