"docs/vscode:/vscode.git/clone" did not exist on "d0697cc7b6825da0ba92aff93b05ea85b4725018"
Commit e35a9e99 authored by zhuwenwen
Browse files

update linear to matmul

parent 7e5aa0c5
...@@ -268,7 +268,6 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -268,7 +268,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
if self.use_llama_nn: if self.use_llama_nn:
# if os.environ['GEMM_PAD'] == '1' and gemm_bank_conf(layer.weight.shape[1] - 32): # if os.environ['GEMM_PAD'] == '1' and gemm_bank_conf(layer.weight.shape[1] - 32):
# layer.weight = layer.weight[:,:-32] # layer.weight = layer.weight[:,:-32]
if bias is not None: if bias is not None:
if len(x.shape) == 2: if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight) return torch.addmm(bias, x, layer.weight)
...@@ -277,8 +276,16 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -277,8 +276,16 @@ class UnquantizedLinearMethod(LinearMethodBase):
else: else:
return torch.matmul(x, layer.weight) return torch.matmul(x, layer.weight)
else: else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]: # if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias) # return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias)
if envs.VLLM_USE_NN:
if bias is not None:
if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight)
else:
return torch.matmul(x, layer.weight) + bias
else:
return torch.matmul(x, layer.weight)
else: else:
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
......
...@@ -75,9 +75,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -75,9 +75,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
else: else:
return torch.matmul(x, layer.weight) return torch.matmul(x, layer.weight)
else: else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]: # if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias) # return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias)
else: # else:
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment