Unverified commit 1e0395e7, authored by Sayak Paul and committed by GitHub

[LoRA] ensure different LoRA ranks for text encoders can be properly handled (#4669)

* debugging starts

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging ends, but does it?

* more robustness.
parent 9141c1f9
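
For context, the core idea of this change is to stop assuming a single LoRA rank for the whole text encoder and instead record one rank per patched module. Below is a minimal, hypothetical sketch (not the diffusers implementation; the key names and tensor shapes are illustrative) of how such a per-module rank mapping can be derived from the shapes of the LoRA up-projection weights:

```python
import torch

# Illustrative state dict: layer 0 was trained with rank 8, layer 1 with rank 16.
text_encoder_lora_state_dict = {
    "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight": torch.zeros(768, 8),
    "text_model.encoder.layers.1.self_attn.out_proj.lora_linear_layer.up.weight": torch.zeros(768, 16),
}

# Collect one rank per module instead of reading a single rank from layer 0.
rank = {}
for key, weight in text_encoder_lora_state_dict.items():
    if key.endswith(".lora_linear_layer.up.weight"):
        # The up-projection weight has shape (out_features, rank).
        rank[key] = weight.shape[1]

print(rank)  # each key now maps to its own rank (8 and 16 here)
```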
```diff
@@ -1245,6 +1245,7 @@ class LoraLoaderMixin:
         if len(text_encoder_lora_state_dict) > 0:
             logger.info(f"Loading {prefix}.")
+            rank = {}
             if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
                 # Convert from the old naming convention to the new naming convention.
@@ -1283,10 +1284,17 @@ class LoraLoaderMixin:
                         f"{name}.out_proj.lora_linear_layer.down.weight"
                     ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
-            rank = text_encoder_lora_state_dict[
-                "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
-            ].shape[1]
+            for name, _ in text_encoder_attn_modules(text_encoder):
+                rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
+                rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
             patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+            if patch_mlp:
+                for name, _ in text_encoder_mlp_modules(text_encoder):
+                    rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
+                    rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
+                    rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
+                    rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
             if network_alphas is not None:
                 alpha_keys = [
@@ -1344,7 +1352,7 @@ class LoraLoaderMixin:
         text_encoder,
         lora_scale=1,
         network_alphas=None,
-        rank=4,
+        rank: Union[Dict[str, int], int] = 4,
         dtype=None,
         patch_mlp=False,
     ):
@@ -1365,23 +1373,28 @@ class LoraLoaderMixin:
             value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
             out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
+            if isinstance(rank, dict):
+                current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
+            else:
+                current_rank = rank
             attn_module.q_proj = PatchedLoraProjection(
-                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
+                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
             attn_module.k_proj = PatchedLoraProjection(
-                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
+                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
             attn_module.v_proj = PatchedLoraProjection(
-                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
+                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
@@ -1389,14 +1402,16 @@ class LoraLoaderMixin:
             for name, mlp_module in text_encoder_mlp_modules(text_encoder):
                 fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
                 fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")
+                current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
+                current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
                 mlp_module.fc1 = PatchedLoraProjection(
-                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
                 )
                 lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
                 mlp_module.fc2 = PatchedLoraProjection(
-                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
+                    mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
                 )
                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
```
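
From the user side, the API does not change; the fix matters when loading a checkpoint whose text-encoder LoRA modules were trained with different ranks. A hedged sketch of that scenario (the LoRA path below is a placeholder, not a real checkpoint):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Before this change, loading assumed one rank for every patched text-encoder
# module; with the per-module rank dict, mixed-rank adapters load without
# shape mismatches. "path/to/mixed-rank-lora" is a placeholder.
pipe.load_lora_weights("path/to/mixed-rank-lora")
```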