Unverified Commit b8aacac3 authored by Jee Li's avatar Jee Li Committed by GitHub
Browse files

[Bugfix] Fix LoRA bug (#4032)

parent d04973ad
......@@ -32,14 +32,17 @@ if TYPE_CHECKING:
def _get_lora_device(base_layer: nn.Module) -> torch.device:
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
"""Returns the device for where to place the LoRA tensors."""
# unquantizedLinear
if hasattr(base_layer, "weight"):
return base_layer.weight.device
if hasattr(base_layer, "linear_weights") and isinstance(
base_layer.linear_weights, dict):
values = list(base_layer.linear_weights.values())
if len(values) and isinstance(values[0], torch.Tensor):
return values[0].device
raise ValueError(f"Unsupported base layer: {base_layer}")
# GPTQ/AWQ/SqueezeLLM
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
# marlin
elif hasattr(base_layer, "B"):
return base_layer.B.device
else:
raise ValueError(f"Unsupported base layer: {base_layer}")
def _apply_lora(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment