Commit 7e5aa0c5 authored by zhuwenwen
Browse files

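update linear apply: pass `layer` as the first argument to the function returned by `dispatch_unquantized_gemm()`, comment out the `GEMM_PAD` weight trimming, and comment out the ROCm-specific branch in `dispatch_unquantized_gemm()`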

parent b35a518a
...@@ -266,8 +266,8 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -266,8 +266,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn: if self.use_llama_nn:
if gemm_bank_conf(layer.weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1': # if os.environ['GEMM_PAD'] == '1' and gemm_bank_conf(layer.weight.shape[1] - 32):
layer.weight = layer.weight[:,:-32] # layer.weight = layer.weight[:,:-32]
if bias is not None: if bias is not None:
if len(x.shape) == 2: if len(x.shape) == 2:
...@@ -278,9 +278,9 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -278,9 +278,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
return torch.matmul(x, layer.weight) return torch.matmul(x, layer.weight)
else: else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]: if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(x, layer.weight.t(), bias) return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias)
else: else:
return dispatch_unquantized_gemm()(x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
class LinearBase(CustomOp): class LinearBase(CustomOp):
......
...@@ -187,10 +187,10 @@ def cpu_unquantized_gemm(layer: torch.nn.Module, ...@@ -187,10 +187,10 @@ def cpu_unquantized_gemm(layer: torch.nn.Module,
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
if current_platform.is_rocm(): # if current_platform.is_rocm():
# return rocm_unquantized_gemm # return rocm_unquantized_gemm
return torch.nn.functional.linear # return torch.nn.functional.linear
elif current_platform.is_cpu(): if current_platform.is_cpu():
return cpu_unquantized_gemm return cpu_unquantized_gemm
else: else:
return default_unquantized_gemm return default_unquantized_gemm
...@@ -76,9 +76,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -76,9 +76,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
return torch.matmul(x, layer.weight) return torch.matmul(x, layer.weight)
else: else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]: if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(x, layer.weight.t(), bias) return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias)
else: else:
return dispatch_unquantized_gemm()(x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module, def embedding(self, layer: torch.nn.Module,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment