"tests/test_modeling_deberta.py" did not exist on "2a667b1eb979a406642eb03fd234211b0bf0aa41"
Commit 6a583c2f authored by chenych's avatar chenych
Browse files

update dtk to 24.04.1 and modify README

parent 7d576a9a
from .modeling import AutoGPTQForCausalLM, BaseQuantizeConfig
from .utils.exllama_utils import exllama_set_max_input_length
from .utils.peft_utils import get_gptq_peft_model
__version__ = "0.8.0.dev0"
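
# Illustrative usage sketch for the exports above; the checkpoint path is a placeholder,
# not a reference to a real repository.
from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length

model = AutoGPTQForCausalLM.from_quantized("path/to/your-4bit-gptq-model", device="cuda:0")
# The exllama kernels pre-allocate buffers sized for EXLLAMA_DEFAULT_MAX_INPUT_LENGTH (2048)
# tokens; enlarge them when prompts are longer than that.
model = exllama_set_max_input_length(model, max_input_length=4096)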
from .language_modeling_task import LanguageModelingTask
from .sequence_classification_task import SequenceClassificationTask, get_predictions
from .text_summarization_task import TextSummarizationTask
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

from ..modeling import BaseGPTQForCausalLM
from ..utils.data_utils import get_dataloader


class BaseTask:
    def __init__(
        self,
        model: Union[BaseGPTQForCausalLM, PreTrainedModel],
        tokenizer: PreTrainedTokenizer,
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs,
    ):
        self.model = model
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.model.config.pad_token_id = self.tokenizer.eos_token_id
        self.dl = get_dataloader(
            data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            tokenizer=tokenizer,
            **kwargs,
        )
        self.device = device
        if not self.device:
            self.device = self.model.device
        if isinstance(self.device, str):
            self.device = torch.device(self.device)

    @abstractmethod
    def _predict(self, batch_data: Dict[str, Any], **kwargs) -> List[Any]:
        pass

    @abstractmethod
    def _parse_labels(self, label_ids: torch.LongTensor) -> List[Any]:
        pass

    @abstractmethod
    def _metric(self, pred: List[Any], label: List[Any]) -> Dict[str, float]:
        pass

    def run(self, **predict_kwargs) -> Dict[str, float]:
        with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
            predictions = []
            labels = []
            for batch_data in self.dl:
                for k, v in batch_data.items():
                    if isinstance(v, torch.Tensor):
                        batch_data[k] = v.to(self.device)
                labels += self._parse_labels(batch_data["labels"])
                predictions += self._predict(batch_data, **predict_kwargs)
        return self._metric(predictions, labels)
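
# Minimal sketch of a custom task built on BaseTask: implementing the three abstract hooks is
# the whole contract. The decoding choices below are illustrative, not prescribed by the class.
from typing import Any, Dict, List

import torch


class ExactMatchTask(BaseTask):
    def _predict(self, batch_data: Dict[str, Any], **kwargs) -> List[Any]:
        # Greedy-decode a short continuation and return the stripped strings.
        output_ids = self.model.generate(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            max_new_tokens=32,
        )
        generated = output_ids[:, batch_data["input_ids"].size(1):]
        return [t.strip() for t in self.tokenizer.batch_decode(generated, skip_special_tokens=True)]

    def _parse_labels(self, label_ids: torch.LongTensor) -> List[Any]:
        # Drop the -100 prompt positions, then decode the remaining label tokens.
        label_ids = [ids[(ids == -100).sum():] for ids in label_ids]
        return [t.strip() for t in self.tokenizer.batch_decode(label_ids, skip_special_tokens=True)]

    def _metric(self, pred: List[Any], label: List[Any]) -> Dict[str, float]:
        hits = sum(p == l for p, l in zip(pred, label))
        return {"exact_match": hits / max(len(label), 1)}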
import sys
from typing import List, Sequence

import numpy as np


def levenshtein_distance(seq1: Sequence, seq2: Sequence):
    if seq1 == seq2:
        return 0
    num_rows = len(seq1) + 1
    num_cols = len(seq2) + 1
    dp_matrix = np.empty((num_rows, num_cols))
    dp_matrix[0, :] = range(num_cols)
    dp_matrix[:, 0] = range(num_rows)
    for i in range(1, num_rows):
        for j in range(1, num_cols):
            if seq1[i - 1] == seq2[j - 1]:
                dp_matrix[i, j] = dp_matrix[i - 1, j - 1]
            else:
                dp_matrix[i, j] = (
                    min(
                        dp_matrix[i - 1, j - 1],
                        dp_matrix[i - 1, j],
                        dp_matrix[i, j - 1],
                    )
                    + 1
                )
    return dp_matrix[num_rows - 1, num_cols - 1]


def get_closest_label(pred: Sequence, classes: List[Sequence]) -> int:
    min_id = sys.maxsize
    min_edit_distance = sys.maxsize
    for i, class_label in enumerate(classes):
        edit_distance = levenshtein_distance(pred, class_label)
        if edit_distance < min_edit_distance:
            min_id = i
            min_edit_distance = edit_distance
    return min_id


__all__ = ["levenshtein_distance", "get_closest_label"]
from typing import List, Optional, Union

from torch import LongTensor
from transformers import PreTrainedTokenizer


def postprocess_generation_ids(
    input_ids: LongTensor,
    output_ids: LongTensor,
    num_return_sequences: int,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    pad_token_ids: Optional[int] = None,
) -> List[List[Union[str, List[int]]]]:
    outputs = []
    for idx, start in enumerate(range(0, len(output_ids), num_return_sequences)):
        sub_output_ids = output_ids[start : start + num_return_sequences]
        # Strip the prompt tokens so only the newly generated tokens remain.
        sub_generated_ids = sub_output_ids[..., input_ids[idx].size(0) :]
        if tokenizer:
            decoded_batch = tokenizer.batch_decode(sub_generated_ids, clean_up_tokenization_spaces=True)
            outputs.append(decoded_batch)
        else:
            sub_generated_ids = sub_output_ids.cpu().numpy().tolist()
            for i, one_sub_generated_ids in enumerate(sub_generated_ids):
                if pad_token_ids is not None and pad_token_ids in one_sub_generated_ids:
                    one_sub_generated_ids = one_sub_generated_ids[: one_sub_generated_ids.index(pad_token_ids)]
                sub_generated_ids[i] = one_sub_generated_ids
            outputs.append(sub_generated_ids)
    return outputs


__all__ = ["postprocess_generation_ids"]
import math
from typing import Any, Dict, List, Optional

from torch import LongTensor

from ._base import BaseTask


class LanguageModelingTask(BaseTask):
    def __init__(
        self,
        model,
        tokenizer,
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs,
    ):
        kwargs["merge_prompt_label"] = True
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_name_or_path=data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            device=device,
            **kwargs,
        )

    def _predict(self, batch_data: Dict[str, Any], *args, **kwargs) -> List[float]:
        outputs = self.model(**batch_data)
        loss = outputs.loss.cpu().item()
        return [loss]

    def _parse_labels(self, label_ids: LongTensor) -> List[Any]:
        return []

    def _metric(self, pred: List[Any], label: List[Any]) -> Dict[str, float]:
        return {"ppl": math.exp(sum(pred) / len(pred))}

    def run(self) -> Dict[str, float]:
        return super().run()


__all__ = ["LanguageModelingTask"]
from collections import Counter
from typing import Any, Dict, List, Optional

import numpy as np
from torch import LongTensor
from transformers import GenerationConfig, PreTrainedTokenizer

from ._base import BaseTask
from ._utils.classification_utils import get_closest_label
from ._utils.generation_utils import postprocess_generation_ids


def get_predictions(
    input_ids: LongTensor,
    output_ids: LongTensor,
    num_return_sequences: int,
    tokenizer: PreTrainedTokenizer,
    classes: List[str],
) -> List[int]:
    predictions = []
    generated_texts = postprocess_generation_ids(
        input_ids=input_ids,
        output_ids=output_ids,
        num_return_sequences=num_return_sequences,
        tokenizer=tokenizer,
    )
    for sub_generated_texts in generated_texts:
        sub_predictions = []
        for gen_text in sub_generated_texts:
            sub_predictions.append(get_closest_label(gen_text.lower().strip(), classes))
        # Majority vote over the returned sequences for this prompt.
        predictions.append(Counter(sub_predictions).most_common(1)[0][0])
    return predictions


class SequenceClassificationTask(BaseTask):
    def __init__(
        self,
        model,
        tokenizer: PreTrainedTokenizer,
        classes: List[str],
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs,
    ):
        kwargs["merge_prompt_label"] = False
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_name_or_path=data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            device=device,
            **kwargs,
        )
        self.classes = [each.lower().strip() for each in classes]
        # Tokenize the class names so generation is capped at the longest label.
        classes_ids = self.tokenizer(classes)["input_ids"]
        self.max_new_tokens = max([len(each) for each in classes_ids])

    def _predict(self, batch_data: Dict[str, Any], *args, **kwargs) -> List[int]:
        generation_config = kwargs["generation_config"]
        output_ids = self.model.generate(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            generation_config=generation_config,
        )
        return get_predictions(
            batch_data["input_ids"],
            output_ids,
            generation_config.num_return_sequences,
            self.tokenizer,
            self.classes,
        )

    def _parse_labels(self, label_ids: LongTensor) -> List[int]:
        labels = []
        for one_label_ids in label_ids:
            one_label_ids = one_label_ids[(one_label_ids == -100).sum() :]
            label = self.tokenizer.decode(one_label_ids, clean_up_tokenization_spaces=True).lower().strip()
            label = get_closest_label(label, self.classes)
            labels.append(label)
        return labels

    def _metric(self, pred: List[int], label: List[int]) -> Dict[str, float]:
        pred = np.array(pred)
        label = np.array(label)
        acc = (pred == label).mean()
        return {"acc": acc}

    def run(self, generation_config: Optional[GenerationConfig] = None) -> Dict[str, float]:
        if not generation_config:
            generation_config = GenerationConfig(num_beams=1, do_sample=False, num_return_sequences=1)
        generation_config.max_new_tokens = self.max_new_tokens
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        generation_config.pad_token_id = self.tokenizer.pad_token_id
        return super().run(generation_config=generation_config)


__all__ = ["SequenceClassificationTask"]
from typing import Any, Dict, List, Optional

import rouge
from torch import LongTensor
from transformers import GenerationConfig

from ._base import BaseTask
from ._utils.generation_utils import postprocess_generation_ids


class TextSummarizationTask(BaseTask):
    def __init__(
        self,
        model,
        tokenizer,
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs,
    ):
        kwargs["merge_prompt_label"] = False
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_name_or_path=data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            device=device,
            **kwargs,
        )

    def _predict(self, batch_data: Dict[str, Any], *args, **kwargs) -> List[str]:
        generation_config = kwargs["generation_config"]
        output_ids = self.model.generate(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            generation_config=generation_config,
        )
        return [
            each[0].lower().strip()
            for each in postprocess_generation_ids(
                input_ids=batch_data["input_ids"],
                output_ids=output_ids,
                num_return_sequences=generation_config.num_return_sequences,
                tokenizer=self.tokenizer,
            )
        ]

    def _parse_labels(self, label_ids: LongTensor) -> List[str]:
        labels = []
        for one_label_ids in label_ids:
            one_label_ids = one_label_ids[(one_label_ids == -100).sum() :]
            label = self.tokenizer.decode(one_label_ids).lower().strip()
            labels.append(label)
        return labels

    def _metric(self, pred: List[Any], label: List[Any]) -> Dict[str, Dict[str, float]]:
        metric = rouge.Rouge()
        return metric.get_scores(hyps=pred, refs=label, avg=True)

    def run(self, generation_config: Optional[GenerationConfig] = None) -> Dict[str, float]:
        if not generation_config:
            generation_config = GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=128)
        generation_config.num_return_sequences = 1
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        generation_config.pad_token_id = self.tokenizer.pad_token_id
        return super().run(generation_config=generation_config)


__all__ = ["TextSummarizationTask"]
from ._base import BaseGPTQForCausalLM, BaseQuantizeConfig
from .auto import GPTQ_CAUSAL_LM_MODEL_MAP, AutoGPTQForCausalLM
from .baichuan import BaiChuanGPTQForCausalLM
from .bloom import BloomGPTQForCausalLM
from .codegen import CodeGenGPTQForCausalLM
from .cohere import CohereGPTQForCausalLM
from .decilm import DeciLMGPTQForCausalLM
from .gemma import GemmaGPTQForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
from .gptj import GPTJGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .longllama import LongLlamaGPTQForCausalLM
from .mistral import MistralGPTQForCausalLM
from .mixtral import MixtralGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .mpt import MPTGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .phi import PhiGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM
from .qwen2 import Qwen2GPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .stablelmepoch import StableLMEpochGPTQForCausalLM
from .starcoder2 import Starcoder2GPTQForCausalLM
from .xverse import XverseGPTQForCausalLM
from .yi import YiGPTQForCausalLM
from .yuan import YuanGPTQForCausalLM
from torch import device

from ..utils.import_utils import compare_transformers_version

CPU = device("cpu")
CUDA_0 = device("cuda:0")

SUPPORTED_MODELS = [
    "bloom",
    "gptj",
    "gpt2",
    "gpt_neox",
    "opt",
    "moss",
    "gpt_bigcode",
    "codegen",
    "RefinedWebModel",
    "RefinedWeb",
    "baichuan",
    "internlm",
    "qwen",
    "xverse",
    "deci",
    "stablelm_epoch",
    "mpt",
    "cohere",
    "yuan",
]
if compare_transformers_version("v4.28.0", op="ge"):
    SUPPORTED_MODELS.append("llama")
if compare_transformers_version("v4.30.0", op="ge"):
    SUPPORTED_MODELS.append("longllama")
if compare_transformers_version("v4.33.0", op="ge"):
    SUPPORTED_MODELS.append("falcon")
if compare_transformers_version("v4.34.0", op="ge"):
    SUPPORTED_MODELS.append("mistral")
    SUPPORTED_MODELS.append("Yi")
if compare_transformers_version("v4.36.0", op="ge"):
    SUPPORTED_MODELS.append("mixtral")
if compare_transformers_version("v4.37.0", op="ge"):
    SUPPORTED_MODELS.append("qwen2")
    SUPPORTED_MODELS.append("phi")
if compare_transformers_version("v4.38.0", op="ge"):
    SUPPORTED_MODELS.append("gemma")
if compare_transformers_version("v4.39.0.dev0", op="ge"):
    SUPPORTED_MODELS.append("starcoder2")

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

__all__ = ["CPU", "CUDA_0", "SUPPORTED_MODELS", "EXLLAMA_DEFAULT_MAX_INPUT_LENGTH"]
from inspect import signature
from typing import Dict, Optional, Union
from ._base import BaseGPTQForCausalLM, BaseQuantizeConfig
from ._utils import check_and_get_model_type
from .baichuan import BaiChuanGPTQForCausalLM
from .bloom import BloomGPTQForCausalLM
from .codegen import CodeGenGPTQForCausalLM
from .cohere import CohereGPTQForCausalLM
from .decilm import DeciLMGPTQForCausalLM
from .gemma import GemmaGPTQForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
from .gptj import GPTJGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .longllama import LongLlamaGPTQForCausalLM
from .mistral import MistralGPTQForCausalLM
from .mixtral import MixtralGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .mpt import MPTGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .phi import PhiGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM
from .qwen2 import Qwen2GPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .stablelmepoch import StableLMEpochGPTQForCausalLM
from .starcoder2 import Starcoder2GPTQForCausalLM
from .xverse import XverseGPTQForCausalLM
from .yi import YiGPTQForCausalLM
from .yuan import YuanGPTQForCausalLM
GPTQ_CAUSAL_LM_MODEL_MAP = {
    "bloom": BloomGPTQForCausalLM,
    "gpt_neox": GPTNeoXGPTQForCausalLM,
    "gptj": GPTJGPTQForCausalLM,
    "gpt2": GPT2GPTQForCausalLM,
    "llama": LlamaGPTQForCausalLM,
    "opt": OPTGPTQForCausalLM,
    "moss": MOSSGPTQForCausalLM,
    "gpt_bigcode": GPTBigCodeGPTQForCausalLM,
    "codegen": CodeGenGPTQForCausalLM,
    "cohere": CohereGPTQForCausalLM,
    "RefinedWebModel": RWGPTQForCausalLM,
    "RefinedWeb": RWGPTQForCausalLM,
    "falcon": RWGPTQForCausalLM,
    "baichuan": BaiChuanGPTQForCausalLM,
    "internlm": InternLMGPTQForCausalLM,
    "qwen": QwenGPTQForCausalLM,
    "mistral": MistralGPTQForCausalLM,
    "Yi": YiGPTQForCausalLM,
    "xverse": XverseGPTQForCausalLM,
    "deci": DeciLMGPTQForCausalLM,
    "stablelm_epoch": StableLMEpochGPTQForCausalLM,
    "starcoder2": Starcoder2GPTQForCausalLM,
    "mixtral": MixtralGPTQForCausalLM,
    "qwen2": Qwen2GPTQForCausalLM,
    "longllama": LongLlamaGPTQForCausalLM,
    "gemma": GemmaGPTQForCausalLM,
    "phi": PhiGPTQForCausalLM,
    "mpt": MPTGPTQForCausalLM,
    "yuan": YuanGPTQForCausalLM,
}
class AutoGPTQForCausalLM:
    def __init__(self):
        raise EnvironmentError(
            "AutoGPTQForCausalLM is designed to be instantiated\n"
            "using `AutoGPTQForCausalLM.from_pretrained` if you want to quantize a pretrained model, or\n"
            "using `AutoGPTQForCausalLM.from_quantized` if you want to run inference with a quantized model."
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        max_memory: Optional[dict] = None,
        trust_remote_code: bool = False,
        **model_init_kwargs,
    ) -> BaseGPTQForCausalLM:
        model_type = check_and_get_model_type(pretrained_model_name_or_path, trust_remote_code)
        return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            quantize_config=quantize_config,
            max_memory=max_memory,
            trust_remote_code=trust_remote_code,
            **model_init_kwargs,
        )

    @classmethod
    def from_quantized(
        cls,
        model_name_or_path: Optional[str],
        device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        low_cpu_mem_usage: bool = False,
        use_triton: bool = False,
        inject_fused_attention: bool = False,
        inject_fused_mlp: bool = False,
        use_cuda_fp16: bool = True,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,
        use_safetensors: bool = True,
        trust_remote_code: bool = False,
        warmup_triton: bool = False,
        trainable: bool = False,
        disable_exllama: Optional[bool] = None,
        disable_exllamav2: bool = False,
        use_marlin: bool = False,
        use_tritonv2: bool = False,
        **kwargs,
    ) -> BaseGPTQForCausalLM:
        # If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
        if disable_exllama is None:
            if disable_exllamav2:
                disable_exllama = False
            else:
                disable_exllama = True
        model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
        quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized

        # A static list of kwargs needed for huggingface_hub
        huggingface_kwargs = [
            "cache_dir",
            "force_download",
            "proxies",
            "resume_download",
            "local_files_only",
            "use_auth_token",
            "revision",
            "subfolder",
            "_raise_exceptions_for_missing_entries",
            "_commit_hash",
        ]
        # TODO: do we need this filtering of kwargs? @PanQiWei is there a reason we can't just pass all kwargs?
        keywords = {
            key: kwargs[key]
            for key in list(signature(quant_func).parameters.keys()) + huggingface_kwargs
            if key in kwargs
        }
        return quant_func(
            model_name_or_path=model_name_or_path,
            device_map=device_map,
            max_memory=max_memory,
            device=device,
            low_cpu_mem_usage=low_cpu_mem_usage,
            use_triton=use_triton,
            inject_fused_attention=inject_fused_attention,
            inject_fused_mlp=inject_fused_mlp,
            use_cuda_fp16=use_cuda_fp16,
            quantize_config=quantize_config,
            model_basename=model_basename,
            use_safetensors=use_safetensors,
            trust_remote_code=trust_remote_code,
            warmup_triton=warmup_triton,
            trainable=trainable,
            disable_exllama=disable_exllama,
            disable_exllamav2=disable_exllamav2,
            use_marlin=use_marlin,
            use_tritonv2=use_tritonv2,
            **keywords,
        )


__all__ = ["AutoGPTQForCausalLM"]
from ._base import BaseGPTQForCausalLM


class BaiChuanGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "DecoderLayer"
    layers_block_name = "model.layers"
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    inside_layer_modules = [
        ["self_attn.W_pack"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]


__all__ = ["BaiChuanGPTQForCausalLM"]
from ._base import BaseGPTQForCausalLM


class BloomGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "BloomBlock"
    layers_block_name = "transformer.h"
    outside_layer_modules = [
        "transformer.word_embeddings",
        "transformer.word_embeddings_layernorm",
        "transformer.ln_f",
    ]
    inside_layer_modules = [
        ["self_attention.query_key_value"],
        ["self_attention.dense"],
        ["mlp.dense_h_to_4h"],
        ["mlp.dense_4h_to_h"],
    ]


__all__ = ["BloomGPTQForCausalLM"]
from ._base import BaseGPTQForCausalLM


class CodeGenGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "CodeGenBlock"
    layers_block_name = "transformer.h"
    outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
    inside_layer_modules = [
        ["attn.qkv_proj"],
        ["attn.out_proj"],
        ["mlp.fc_in"],
        ["mlp.fc_out"],
    ]


__all__ = ["CodeGenGPTQForCausalLM"]
from logging import getLogger

from ._base import BaseGPTQForCausalLM

logger = getLogger(__name__)


class CohereGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "CohereDecoderLayer"
    layers_block_name = "model.layers"
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]


__all__ = ["CohereGPTQForCausalLM"]
from transformers.configuration_utils import PretrainedConfig


class YuanConfig(PretrainedConfig):
    model_type = "yuan"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=135040,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_attention_heads=32,
        hidden_act="silu",
        model_max_length=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=77185,
        bos_token_id=77185,
        eos_token_id=77185,
        tie_word_embeddings=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.model_max_length = model_max_length
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
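
# Small illustration of the config above: unset fields fall back to the Yuan defaults, and the
# token-id/tying arguments are forwarded to PretrainedConfig via super().__init__.
config = YuanConfig(num_hidden_layers=2, hidden_size=512, num_attention_heads=8)
print(config.model_type)                         # "yuan"
print(config.pad_token_id, config.eos_token_id)  # 77185 77185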
from logging import getLogger

from ..utils.import_utils import compare_transformers_version
from ._base import BaseGPTQForCausalLM

if compare_transformers_version("v4.28.0", op="ge"):
    from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
    from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
    FusedLlamaAttentionForQuantizedModel = None
    FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


class DeciLMGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "DeciLMDecoderLayer"
    layers_block_name = "model.layers"
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]

    fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
    fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["DeciLMGPTQForCausalLM"]