Commit c5eae25b authored by xuxzh1's avatar xuxzh1 🎱
Browse files

Replace all Triton interfaces with the ExLlamaV2 implementation

parent bb9e670a
...@@ -171,7 +171,7 @@ class GPTQWeightsLoader(WeightsLoader): ...@@ -171,7 +171,7 @@ class GPTQWeightsLoader(WeightsLoader):
g_idx=g_idx, g_idx=g_idx,
bits=self.bits, bits=self.bits,
groupsize=self.groupsize, groupsize=self.groupsize,
use_exllama=use_exllama, use_exllama=True,
) )
def get_weights_col_packed( def get_weights_col_packed(
...@@ -227,7 +227,7 @@ class GPTQWeightsLoader(WeightsLoader): ...@@ -227,7 +227,7 @@ class GPTQWeightsLoader(WeightsLoader):
bits=self.bits, bits=self.bits,
groupsize=self.groupsize, groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq", use_awq_kernel=self.quantize == "awq",
use_exllama=False, use_exllama=True,
) )
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
...@@ -294,7 +294,7 @@ class GPTQWeightsLoader(WeightsLoader): ...@@ -294,7 +294,7 @@ class GPTQWeightsLoader(WeightsLoader):
bits=self.bits, bits=self.bits,
groupsize=self.groupsize, groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq", use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama, use_exllama=True,
) )
def get_weights_row(self, weights: Weights, prefix: str): def get_weights_row(self, weights: Weights, prefix: str):
...@@ -394,7 +394,7 @@ class GPTQWeightsLoader(WeightsLoader): ...@@ -394,7 +394,7 @@ class GPTQWeightsLoader(WeightsLoader):
bits=self.bits, bits=self.bits,
groupsize=self.groupsize, groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq", use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama, use_exllama=True,
) )
def _get_gptq_params(self, weights: Weights): def _get_gptq_params(self, weights: Weights):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment