Unverified Commit 9c749713 authored by Lucas Tucker's avatar Lucas Tucker Committed by GitHub
Browse files

[mypy] Forward pass function type hints in lora (#11740)


Signed-off-by: default avatarlucast2021 <lucast2021@headroyce.org>
Co-authored-by: default avatarlucast2021 <lucast2021@headroyce.org>
parent 022c5c69
......@@ -405,7 +405,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
self.output_size = self.base_layer.output_size
self.n_slices = 1
def forward(self, input_):
def forward(
self, input_: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Forward of ReplicatedLinearWithLoRA
Args:
......@@ -496,7 +498,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
bias = bias[start_idx:end_idx]
return bias
def forward(self, input_):
def forward(
self, input_: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Forward of ColumnParallelLinear
Args:
......@@ -833,7 +837,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias
def forward(self, input_):
def forward(
self, input_: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Forward of RowParallelLinear
Args:
......
......@@ -4,7 +4,7 @@ import math
import os
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
import safetensors.torch
import torch
......@@ -219,6 +219,7 @@ class LoRAModel(AdapterModel):
config["vllm_max_position_embeddings"] = max_position_embeddings
peft_helper = PEFTHelper.from_dict(config)
unexpected_modules: List[Union[list[str], str]]
if os.path.isfile(lora_tensor_path):
tensors: Dict[str, torch.Tensor] = {}
# Find unexpected modules.
......
......@@ -238,7 +238,9 @@ class ReplicatedLinear(LinearBase):
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self, x: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment