Commit e35a9e99 authored by zhuwenwen's avatar zhuwenwen
Browse files

update linear to matmul

parent 7e5aa0c5
......@@ -268,7 +268,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
if self.use_llama_nn:
# if os.environ['GEMM_PAD'] == '1' and gemm_bank_conf(layer.weight.shape[1] - 32):
# layer.weight = layer.weight[:,:-32]
if bias is not None:
if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight)
......@@ -277,8 +276,16 @@ class UnquantizedLinearMethod(LinearMethodBase):
else:
return torch.matmul(x, layer.weight)
else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias)
# if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
# return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias)
if envs.VLLM_USE_NN:
if bias is not None:
if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight)
else:
return torch.matmul(x, layer.weight) + bias
else:
return torch.matmul(x, layer.weight)
else:
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
......
......@@ -75,10 +75,10 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
else:
return torch.matmul(x, layer.weight)
else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias)
else:
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
# if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
# return dispatch_unquantized_gemm()(layer, x, layer.weight.t(), bias)
# else:
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment