Commit 321d74ff authored by Casper Hansen

Set full quant as default. Always load unquantized models with AutoModel instead of accelerate.

parent bd899094
@@ -81,7 +81,7 @@ if __name__ == '__main__':
     python -m awq.entry --entry_type quant --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq/awq_model_search_result.pt --quant_path vicuna-7b-v1.5-awq
 - Run perplexity of quantized model:
-    python -m awq.entry --entry_type eval --model_path casperhansen/vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
+    python -m awq.entry --entry_type eval --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
 - Run perplexity unquantized FP16 model:
     python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
@@ -38,7 +38,7 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
-                 auto_scale=True, mse_range=True, run_search=False, run_quant=True,
+                 auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                  calib_data="pileval"):
         self.quant_config = quant_config
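With run_search now defaulting to True, a bare quantize() call runs both the scale/clip search and the weight quantization in one pass. A minimal usage sketch, assuming `model` is an already-instantiated BaseAWQForCausalLM subclass and that the quant_config keys shown (w_bit, q_group_size) are the ones this version expects:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
    quant_config = {"w_bit": 4, "q_group_size": 128}  # assumed keys, matching awq_model_w4_g128.pt

    # run_search=True and run_quant=True are now the defaults, so one call
    # performs the full search + quantize pipeline.
    model.quantize(tokenizer, quant_config=quant_config)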
@@ -250,7 +250,7 @@ class BaseAWQForCausalLM(nn.Module):
     def from_quantized(self, model_path, model_type, model_filename,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                        safetensors=False, is_quantized=True):
-        # Download model if path is not a directory
+        # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
             if safetensors:
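The hunk cuts off before the actual download call, so the following is only a sketch of what [STEP 1] likely does with ignore_patterns; the use of huggingface_hub's snapshot_download and the safetensors/.bin filtering are assumptions, not code visible in this diff:

    import os
    from huggingface_hub import snapshot_download

    if not os.path.isdir(model_path):
        # Skip flax/tf checkpoints; which torch weight format to skip is assumed here.
        ignore_patterns = ["*msgpack*", "*h5*"]
        if safetensors:
            ignore_patterns.append("*.bin")
        else:
            ignore_patterns.append("*.safetensors")
        model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)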
@@ -263,7 +263,8 @@ class BaseAWQForCausalLM(nn.Module):
         # TODO: Better naming, model_filename becomes a directory
         model_filename = model_path + f'/{model_filename}'
-        # Load config
+        # [STEP 2] Load config
+        # TODO: Create BaseAWQConfig class
         quant_config_path = f'{model_path}/quant_config.json'
         if os.path.exists(quant_config_path):
             with open(quant_config_path, 'r') as file:
@@ -274,7 +275,7 @@ class BaseAWQForCausalLM(nn.Module):
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-        # Load empty weights
+        # [STEP 3] Load model
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
@@ -286,12 +287,11 @@ class BaseAWQForCausalLM(nn.Module):
         model.tie_weights()
         # Load model weights
-        try:
+        if is_quantized:
             model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
-        except Exception as ex:
-            # Fallback to auto model if load_checkpoint_and_dispatch is not working
-            print(f'{ex} - falling back to AutoModelForCausalLM.from_pretrained')
+        else:
+            # If not quantized, must load with AutoModelForCausalLM
             device_map = infer_auto_device_map(
                 model,
                 no_split_module_classes=[self.layer_type],
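Read together, the new branch replaces the old try/except fallback with an explicit is_quantized switch: quantized checkpoints still go through accelerate's dispatch loader, while unquantized weights always go through transformers. The diff is truncated after infer_auto_device_map, so the remaining kwargs and the AutoModelForCausalLM.from_pretrained call below are assumptions inferred from the commit message, not code visible in this hunk:

    if is_quantized:
        # Quantized checkpoints keep using accelerate's dispatch loader.
        model = load_checkpoint_and_dispatch(
            model, model_filename, device_map=device,
            no_split_module_classes=[self.layer_type])
    else:
        # Unquantized models are loaded through transformers, with accelerate
        # only used to infer a device map.
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type],
            dtype=torch_dtype)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map=device_map, torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code)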