Unverified Commit 7c976752 authored by Andrey Glushenkov's avatar Andrey Glushenkov Committed by GitHub
Browse files

Pass arguments to AutoConfig (#97)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent c5581b27
...@@ -45,12 +45,13 @@ class AutoAWQForCausalLM: ...@@ -45,12 +45,13 @@ class AutoAWQForCausalLM:
def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None, def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
trust_remote_code=True, fuse_layers=True, trust_remote_code=True, fuse_layers=True,
batch_size=1, safetensors=True, batch_size=1, safetensors=True,
max_memory=None, offload_folder=None) -> BaseAWQForCausalLM: max_memory=None, offload_folder=None, **config_kwargs) -> BaseAWQForCausalLM:
os.environ["AWQ_BATCH_SIZE"] = str(batch_size) os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
model_type = check_and_get_model_type(quant_path, trust_remote_code) model_type = check_and_get_model_type(quant_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized( return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code, quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
fuse_layers=fuse_layers, safetensors=safetensors, fuse_layers=fuse_layers, safetensors=safetensors,
max_memory=max_memory, offload_folder=offload_folder max_memory=max_memory, offload_folder=offload_folder,
**config_kwargs
) )
...@@ -135,11 +135,13 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -135,11 +135,13 @@ class BaseAWQForCausalLM(nn.Module):
max_new_tokens=None, torch_dtype=torch.float16, max_new_tokens=None, torch_dtype=torch.float16,
trust_remote_code=True, safetensors=True, is_quantized=True, trust_remote_code=True, safetensors=True, is_quantized=True,
fuse_layers=False, version='GEMM', fuse_layers=False, version='GEMM',
max_memory=None, offload_folder=None): max_memory=None, offload_folder=None,
**config_kwargs):
# [STEP 1-2] Load weights path and configs # [STEP 1-2] Load weights path and configs
model_weights_path, config, quant_config = self._load_config( model_weights_path, config, quant_config = self._load_config(
self, model_path, model_filename, safetensors, version, self, model_path, model_filename, safetensors, version,
trust_remote_code, max_new_tokens=max_new_tokens trust_remote_code, max_new_tokens=max_new_tokens,
**config_kwargs
) )
# [STEP 3] Load model # [STEP 3] Load model
...@@ -184,7 +186,8 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -184,7 +186,8 @@ class BaseAWQForCausalLM(nn.Module):
return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config) return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
def _load_config(self, model_path, model_filename, safetensors=True, def _load_config(self, model_path, model_filename, safetensors=True,
version="GEMM", trust_remote_code=True, max_new_tokens=4096): version="GEMM", trust_remote_code=True, max_new_tokens=4096,
**config_kwargs):
# [STEP 1] Download model if path is not a directory # [STEP 1] Download model if path is not a directory
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"] ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
...@@ -206,11 +209,11 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -206,11 +209,11 @@ class BaseAWQForCausalLM(nn.Module):
# Load model config and set max generation length # Load model config and set max generation length
if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'): if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
config.max_new_tokens = getattr(config, self.max_new_tokens_key) config.max_new_tokens = getattr(config, self.max_new_tokens_key)
else: else:
max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
config.max_new_tokens = max_new_tokens config.max_new_tokens = max_new_tokens
return model_weights_path, config, quant_config return model_weights_path, config, quant_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment