Unverified Commit 84c87877 authored by Casper's avatar Casper Committed by GitHub
Browse files

Add config to Base model (#207)

parent 0e77dbc1
......@@ -13,7 +13,12 @@ from awq.quantize.quantizer import AwqQuantizer
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from transformers import (
AutoModelForCausalLM,
AutoConfig,
PreTrainedModel,
PretrainedConfig,
)
from accelerate.big_modeling import (
init_empty_weights,
infer_auto_device_map,
......@@ -22,12 +27,13 @@ from accelerate.big_modeling import (
from accelerate.utils import get_balanced_memory
class BaseAWQForCausalLM(nn.Module):
def __init__(self, model, model_type, is_quantized, quant_config):
def __init__(self, model, model_type, is_quantized, config, quant_config):
super().__init__()
self.model:PreTrainedModel = model
self.model_type:str = model_type
self.is_quantized:bool = is_quantized
self.search_result = None
self.config: PretrainedConfig = config
self.quant_config: AwqConfig = quant_config
def to(self, device: str):
......@@ -141,7 +147,7 @@ class BaseAWQForCausalLM(nn.Module):
model.eval()
return self(model, model_type, is_quantized=False, quant_config=quant_config)
return self(model, model_type, is_quantized=False, config=config, quant_config=quant_config)
@classmethod
def from_quantized(self, model_path, model_type, model_filename='',
......@@ -181,7 +187,7 @@ class BaseAWQForCausalLM(nn.Module):
if fuse_layers:
self.fuse_layers(model)
return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
return self(model, model_type, is_quantized=is_quantized, config=config, quant_config=quant_config)
def _load_config(self, model_path, model_filename, safetensors=True,
version="GEMM", trust_remote_code=True, max_new_tokens=4096,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment