Unverified Commit 87bc0c49 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix ReplicatedLinearWithLoRA (#27065)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent fe3b9372
......@@ -56,3 +56,15 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is ReplicatedLinear
def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora a if splitting for tensor parallelism."""
return lora_a
def slice_lora_b(
self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora b if splitting with tensor parallelism."""
return lora_b
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment