Commit d9bab50c authored by Casper Hansen

Implement loading the correct sequence length from the model config + a custom max_new_tokens override

parent b53a9be2
@@ -32,10 +32,10 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename,
+    def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
                        device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
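
With this change, callers can either let the loader pick the sequence length from the model config or override it explicitly. A minimal usage sketch (the checkpoint paths are placeholders, and the "from awq import AutoAWQForCausalLM" import path is an assumption, not part of this commit):

    from awq import AutoAWQForCausalLM  # assumed import path

    # Placeholder paths; point these at a real AWQ-quantized checkpoint.
    quant_path = 'my-llama-awq'
    quant_filename = 'awq_model_w4_g128.pt'

    # Omit max_new_tokens to read the sequence length from the model config,
    # or pass a value to override it.
    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_filename, max_new_tokens=512)
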
@@ -239,6 +239,7 @@ class BaseAWQForCausalLM(nn.Module):
             model_path,
             model_type,
             model_filename='',
+            max_new_tokens=None,
             device='balanced',
             torch_dtype=torch_dtype,
             trust_remote_code=trust_remote_code,
@@ -247,7 +248,7 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename,
+    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                        safetensors=False, is_quantized=True):
         # [STEP 1] Download model if path is not a directory
@@ -263,7 +264,7 @@ class BaseAWQForCausalLM(nn.Module):
             # TODO: Better naming, model_filename becomes a directory
             model_filename = model_path + f'/{model_filename}'
 
-        # [STEP 2] Load config
+        # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
         quant_config_path = f'{model_path}/quant_config.json'
         if os.path.exists(quant_config_path):
@@ -273,7 +274,14 @@ class BaseAWQForCausalLM(nn.Module):
             # Default config that works for most models
             quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
 
+        # Load model config and set max generation length
+        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
+            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
+        else:
+            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
+            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+            config.max_new_tokens = max_new_tokens
+
         # [STEP 3] Load model
         with init_empty_weights():
......
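
In isolation, the new step-2 logic amounts to: if the caller did not pass max_new_tokens and the model class declares a max_new_tokens_key, read the limit from the Hugging Face config; otherwise use the explicit value, or 2048 when nothing is provided. A standalone sketch of that resolution (the function name and argument layout are illustrative, not the committed API):

    from transformers import AutoConfig

    def resolve_max_new_tokens(model_path, max_new_tokens=None,
                               max_new_tokens_key=None, trust_remote_code=True):
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and max_new_tokens_key is not None:
            # Use the model's own limit, e.g. max_position_embeddings or max_seq_len.
            config.max_new_tokens = getattr(config, max_new_tokens_key)
        else:
            # Explicit override, or 2048 as the default.
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
        return config
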
@@ -3,6 +3,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
 
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "LlamaDecoderLayer"
+    max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
......
@@ -2,6 +2,7 @@ from .base import BaseAWQForCausalLM
 
 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
+    max_new_tokens_key = "max_seq_len"
 
     @staticmethod
     def get_model_layers(model):
......
@@ -3,6 +3,7 @@ from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer
 
 class OptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "OPTDecoderLayer"
+    max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
    def get_model_layers(model: OPTForCausalLM):
......
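
Each model class only needs to point max_new_tokens_key at the config attribute that stores its maximum sequence length (max_position_embeddings for Llama and OPT, max_seq_len for MPT). A hypothetical sketch of how another model family would opt in (class name, layer type, and key below are illustrative only):

    from .base import BaseAWQForCausalLM

    class NewModelAWQForCausalLM(BaseAWQForCausalLM):
        layer_type = "NewModelDecoderLayer"              # hypothetical decoder layer class
        max_new_tokens_key = "max_position_embeddings"   # config attribute holding the sequence limit
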