Commit d9bab50c authored by Casper Hansen

Implement loading the correct sequence length based on the model config, plus support for a custom max_new_tokens

parent b53a9be2
@@ -32,10 +32,10 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename,
+    def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
                        device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
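For reference, a minimal usage sketch of the new parameter follows; the top-level import, model path, and weights filename are illustrative assumptions, not values from this commit. Leaving max_new_tokens as None lets the loader take the sequence length from the model config, while passing a value overrides it.

from awq import AutoAWQForCausalLM

# Hypothetical quantized model location and weights filename, for illustration only.
quant_path = "path/to/quantized-model"
quant_filename = "awq_model_w4_g128.pt"

# max_new_tokens=None -> sequence length is read from the model config;
# an explicit value (e.g. 512) overrides the config-derived default.
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_filename, max_new_tokens=512)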
@@ -239,6 +239,7 @@ class BaseAWQForCausalLM(nn.Module):
             model_path,
             model_type,
             model_filename='',
+            max_new_tokens=None,
             device='balanced',
             torch_dtype=torch_dtype,
             trust_remote_code=trust_remote_code,
@@ -247,7 +248,7 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename,
+    def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                        safetensors=False, is_quantized=True):
         # [STEP 1] Download model if path is not a directory
@@ -263,7 +264,7 @@ class BaseAWQForCausalLM(nn.Module):
             # TODO: Better naming, model_filename becomes a directory
             model_filename = model_path + f'/{model_filename}'
 
-        # [STEP 2] Load config
+        # [STEP 2] Load config and set sequence length
         # TODO: Create BaseAWQConfig class
         quant_config_path = f'{model_path}/quant_config.json'
         if os.path.exists(quant_config_path):
@@ -273,8 +274,15 @@ class BaseAWQForCausalLM(nn.Module):
             # Default config that works for most models
             quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
 
-        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+        # Load model config and set max generation length
+        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
+            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
+        else:
+            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
+            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+            config.max_new_tokens = max_new_tokens
 
         # [STEP 3] Load model
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
...
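Read in isolation, the new branch amounts to the small resolver sketched below; resolve_max_new_tokens is a hypothetical helper name used only to restate the logic added above under the same assumptions (a config loadable via AutoConfig and an optional per-model key).

from transformers import AutoConfig

def resolve_max_new_tokens(model_path, max_new_tokens_key=None, max_new_tokens=None,
                           trust_remote_code=True):
    # Hypothetical helper mirroring the branch added to from_quantized.
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
    if max_new_tokens is None and max_new_tokens_key is not None:
        # No override given: reuse the model's own context length
        # (e.g. max_position_embeddings for Llama/OPT, max_seq_len for MPT).
        config.max_new_tokens = getattr(config, max_new_tokens_key)
    else:
        # Explicit override, or a 2048-token fallback when no key is defined.
        config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
    return config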
@@ -3,6 +3,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
 
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "LlamaDecoderLayer"
+    max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
...
@@ -2,6 +2,7 @@ from .base import BaseAWQForCausalLM
 
 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
+    max_new_tokens_key = "max_seq_len"
 
     @staticmethod
     def get_model_layers(model):
...
@@ -3,6 +3,7 @@ from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer
 
 class OptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "OPTDecoderLayer"
+    max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
     def get_model_layers(model: OPTForCausalLM):
...
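With the per-model max_new_tokens_key attribute in place, supporting another architecture only requires naming the config attribute that stores its context length. The class below is a hypothetical example of such an adapter, not part of this commit; its name, layer_type, and layer lookup are placeholders.

from .base import BaseAWQForCausalLM

class MyModelAWQForCausalLM(BaseAWQForCausalLM):
    # Hypothetical adapter: max_new_tokens_key names whichever config
    # attribute holds this model's maximum context length.
    layer_type = "MyModelDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def get_model_layers(model):
        # Placeholder layer lookup; real adapters return the decoder layer list.
        return model.model.layers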