Unverified Commit 89766b3d authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

up (#14258)

parent bd21ed40
...@@ -153,7 +153,7 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro ...@@ -153,7 +153,7 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
feature_extractor.conv_layers[layer_id].conv.bias.data = value feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name: elif "weight" in name:
if value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape: if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError( raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
) )
...@@ -163,14 +163,14 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro ...@@ -163,14 +163,14 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
if "bias" in name: if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape: if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError( raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found." f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
) )
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name: elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape: if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError( raise ValueError(
f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found." f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
) )
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment