"tests/vscode:/vscode.git/clone" did not exist on "84371daf75507c849a38a9a44b2fb2af89e96dd3"
Unverified Commit 9c749713 authored by Lucas Tucker's avatar Lucas Tucker Committed by GitHub
Browse files

[mypy] Forward pass function type hints in lora (#11740)


Signed-off-by: default avatarlucast2021 <lucast2021@headroyce.org>
Co-authored-by: default avatarlucast2021 <lucast2021@headroyce.org>
parent 022c5c69
...@@ -405,7 +405,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -405,7 +405,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
self.output_size = self.base_layer.output_size self.output_size = self.base_layer.output_size
self.n_slices = 1 self.n_slices = 1
def forward(self, input_): def forward(
self, input_: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Forward of ReplicatedLinearWithLoRA """Forward of ReplicatedLinearWithLoRA
Args: Args:
...@@ -496,7 +498,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -496,7 +498,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
bias = bias[start_idx:end_idx] bias = bias[start_idx:end_idx]
return bias return bias
def forward(self, input_): def forward(
self, input_: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Forward of ColumnParallelLinear """Forward of ColumnParallelLinear
Args: Args:
...@@ -833,7 +837,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): ...@@ -833,7 +837,9 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias return bias
def forward(self, input_): def forward(
self, input_: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Forward of RowParallelLinear """Forward of RowParallelLinear
Args: Args:
......
...@@ -4,7 +4,7 @@ import math ...@@ -4,7 +4,7 @@ import math
import os import os
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Type from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -219,6 +219,7 @@ class LoRAModel(AdapterModel): ...@@ -219,6 +219,7 @@ class LoRAModel(AdapterModel):
config["vllm_max_position_embeddings"] = max_position_embeddings config["vllm_max_position_embeddings"] = max_position_embeddings
peft_helper = PEFTHelper.from_dict(config) peft_helper = PEFTHelper.from_dict(config)
unexpected_modules: List[Union[list[str], str]]
if os.path.isfile(lora_tensor_path): if os.path.isfile(lora_tensor_path):
tensors: Dict[str, torch.Tensor] = {} tensors: Dict[str, torch.Tensor] = {}
# Find unexpected modules. # Find unexpected modules.
......
...@@ -238,7 +238,9 @@ class ReplicatedLinear(LinearBase): ...@@ -238,7 +238,9 @@ class ReplicatedLinear(LinearBase):
assert param.size() == loaded_weight.size() assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight) param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(
self, x: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias) output = self.quant_method.apply(self, x, bias)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment