Commit 321d74ff authored by Casper Hansen

Set full quant as default. Always load unquantized with AutoModel instead of accelerate.

parent bd899094
@@ -81,7 +81,7 @@ if __name__ == '__main__':
 python -m awq.entry --entry_type quant --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq/awq_model_search_result.pt --quant_path vicuna-7b-v1.5-awq
 - Run perplexity of quantized model:
-python -m awq.entry --entry_type eval --model_path casperhansen/vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
+python -m awq.entry --entry_type eval --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
 - Run perplexity unquantized FP16 model:
 python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
...
@@ -38,7 +38,7 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
-                 auto_scale=True, mse_range=True, run_search=False, run_quant=True,
+                 auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                  calib_data="pileval"):
         self.quant_config = quant_config
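With this change, run_search defaults to True, so a bare call to quantize() performs both the activation-aware scale search and the weight quantization ("full quant"). A minimal sketch of the difference, assuming model is an already loaded FP16 instance of a BaseAWQForCausalLM subclass with its tokenizer, and assuming the quant_config keys shown (zero_point, q_group_size, w_bit) as typical AWQ settings rather than values confirmed by this diff:

# Assumed quant_config keys; adjust to whatever the repo actually expects.
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

# New default after this commit: search + quantization run in one pass.
model.quantize(tokenizer, quant_config=quant_config)

# Equivalent call with the flags spelled out. Before this commit, run_search
# defaulted to False, so only the quantization step ran unless you passed it.
model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=True)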
@@ -250,7 +250,7 @@ class BaseAWQForCausalLM(nn.Module):
     def from_quantized(self, model_path, model_type, model_filename,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                        safetensors=False, is_quantized=True):
-        # Download model if path is not a directory
+        # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
             if safetensors:
@@ -263,7 +263,8 @@ class BaseAWQForCausalLM(nn.Module):
         # TODO: Better naming, model_filename becomes a directory
         model_filename = model_path + f'/{model_filename}'
-        # Load config
+        # [STEP 2] Load config
+        # TODO: Create BaseAWQConfig class
         quant_config_path = f'{model_path}/quant_config.json'
         if os.path.exists(quant_config_path):
             with open(quant_config_path, 'r') as file:
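For reference, a hedged sketch of what [STEP 2] amounts to when reproduced outside the class. The quant_config.json file name comes from the diff; the fallback keys (zero_point, q_group_size, w_bit) are typical AWQ settings assumed for illustration, not confirmed here:

import json
import os

model_path = "vicuna-7b-v1.5-awq"  # example path taken from the usage text above
quant_config_path = f"{model_path}/quant_config.json"

if os.path.exists(quant_config_path):
    # Read the quantization settings saved next to the checkpoint.
    with open(quant_config_path, "r") as file:
        quant_config = json.load(file)
else:
    # Assumed fallback values, not confirmed by this diff.
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}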
@@ -274,7 +275,7 @@ class BaseAWQForCausalLM(nn.Module):
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-        # Load empty weights
+        # [STEP 3] Load model
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
@@ -286,12 +287,11 @@ class BaseAWQForCausalLM(nn.Module):
         model.tie_weights()
         # Load model weights
-        try:
+        if is_quantized:
             model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
-        except Exception as ex:
-            # Fallback to auto model if load_checkpoint_and_dispatch is not working
-            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')
+        else:
+            # If not quantized, must load with AutoModelForCausalLM
             device_map = infer_auto_device_map(
                 model,
                 no_split_module_classes=[self.layer_type],
...
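Taken together with [STEP 3], the branch above means quantized checkpoints are still dispatched through accelerate, while unquantized weights now always go through AutoModelForCausalLM rather than accelerate's checkpoint loader. A standalone sketch of that logic, assuming layer_type is a decoder-layer class name such as "LlamaDecoderLayer" (standing in for the repo's self.layer_type) and "balanced" as the device_map string; this is an illustration, not the repo's code:

import torch
from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

def load_weights(model_path, model_filename, is_quantized,
                 layer_type="LlamaDecoderLayer", torch_dtype=torch.float16, device="balanced"):
    # Build an empty (meta-device) model purely from the config, as in [STEP 3].
    config = AutoConfig.from_pretrained(model_path)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype)
    model.tie_weights()

    if is_quantized:
        # Quantized checkpoint: dispatch the saved state dict across devices with accelerate.
        model = load_checkpoint_and_dispatch(
            model, model_filename, device_map=device,
            no_split_module_classes=[layer_type],
        )
    else:
        # Unquantized weights: use the empty model only to infer a device map,
        # then load through AutoModelForCausalLM instead of accelerate.
        device_map = infer_auto_device_map(model, no_split_module_classes=[layer_type])
        model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map=device_map, torch_dtype=torch_dtype,
        )
    return model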