Commit d430694b authored by Casper Hansen

save_quantized working in all cases, from_quantized adapted.

parent af4e0622
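
For orientation, a minimal sketch of the load-and-evaluate flow after this change, assuming AutoAWQForCausalLM is importable from the awq package as it is in awq/entry.py and that q_config only needs the group size used at quantization time; the paths and filenames mirror the defaults written by save_quantized in this commit:

    from transformers import AutoTokenizer
    from awq import AutoAWQForCausalLM  # assumed import path

    quant_path = 'mpt-7b-8k-chat-awq'        # directory written by save_quantized
    quant_file = 'awq_model_w4_g128.pt'      # default filename for w_bit=4, group size 128
    q_config = {'q_group_size': 128}         # assumed q_config shape; see --q_group_size below

    # from_quantized now takes the AWQ model directory plus the checkpoint filename
    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, w_bit=4, q_config=q_config)
    tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)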
......@@ -29,6 +29,9 @@ def run_search(model_path, dump_path, w_bit, q_config):
# Save search results
model.save_quantized(dump_path)
# Save tokenizer
tokenizer.save_pretrained(dump_path)
def run_quant(model_path, search_path, dump_path, w_bit, q_config):
"""
Step 2/2: Use the search results to quantize model weights
......@@ -43,16 +46,16 @@ def run_quant(model_path, search_path, dump_path, w_bit, q_config):
# Save quantized model
model.save_quantized(dump_path)
def run_perplexity(model_path, quant_path, w_bit, q_config, device):
def run_perplexity(quant_path, quant_file, w_bit, q_config, device):
"""
Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
"""
# Load model
model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, w_bit, q_config)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
# Load adapter
lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
lm_eval_model = LMEvalAdaptor(quant_path, model, tokenizer, device, batch_size=1)
# Evaluate perplexity of quantized model
results = evaluator.simple_evaluate(
......@@ -68,15 +71,16 @@ def run_perplexity(model_path, quant_path, w_bit, q_config, device):
if __name__ == '__main__':
"""
python -m awq.entry --entry_type search --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq
python -m awq.entry --entry_type quant --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq/pytorch_model.bin --quant_path mpt-7b-8k-chat-awq
python -m awq.entry --entry_type perplexity --model_path mosaicml/mpt-7b-8k-chat --quant_path mpt-7b-8k-chat-awq
python -m awq.entry --entry_type quant --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq/awq_model_search_result.pt --quant_path mpt-7b-8k-chat-awq
python -m awq.entry --entry_type perplexity --quant_path mpt-7b-8k-chat-awq --quant_file awq_model_w4_g128.pt
"""
parser = argparse.ArgumentParser()
parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|perplexity)')
parser.add_argument('--model_path', type=str, help='Path to hf model')
parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
parser.add_argument('--quant_path', type=str, help='Path to save/load AWQ quant model')
parser.add_argument('--device', type=str, default='balanced', help='Device to load model to')
parser.add_argument('--quant_path', type=str, help='Path to AWQ model directory')
parser.add_argument('--quant_file', type=str, help='Path to quantized AWQ model file')
parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
parser.add_argument('--w_bit', type=int, default=4)
parser.add_argument('--q_group_size', type=int, default=128)
args = parser.parse_args()
......@@ -88,6 +92,6 @@ if __name__ == '__main__':
elif args.entry_type == 'quant':
run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
elif args.entry_type == 'perplexity':
run_perplexity(args.model_path, args.quant_path, args.w_bit, q_config, args.device)
run_perplexity(args.quant_path, args.quant_file, args.w_bit, q_config, args.device)
else:
raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
......@@ -29,11 +29,11 @@ class AutoAWQForCausalLM:
)
@classmethod
def from_quantized(self, model_path, quant_file, w_bit=4, q_config={},
def from_quantized(self, quant_path, quant_filename, w_bit=4, q_config={},
device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(model_path, trust_remote_code)
model_type = check_and_get_model_type(quant_path, trust_remote_code)
q_config = q_config if q_config else self.default_q_config
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
model_path, model_type, quant_file, w_bit, q_config, device, trust_remote_code=trust_remote_code
quant_path, model_type, quant_filename, w_bit, q_config, device, trust_remote_code=trust_remote_code
)
\ No newline at end of file
......@@ -23,6 +23,12 @@ class BaseAWQForCausalLM:
self.is_quantized:bool = is_quantized
self.search_result = None
def to(self, device: str):
return self.model.to(device)
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
@torch.no_grad()
def quantize(self, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=False, run_quant=True,
......@@ -170,19 +176,29 @@ class BaseAWQForCausalLM:
return awq_results
def save_quantized(self, save_dir):
def _save_files(save_dir, model_name, model):
class EmptyModule(nn.Module):
def __init__(self): super(EmptyModule, self).__init__()
def forward(self, x): return x
# Save model files without search results
self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
# Remove empty module
os.remove(f'{save_dir}/pytorch_model.bin')
# Save model weights or search results under the chosen filename
torch.save(model, f'{save_dir}/{model_name}')
save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
# Save model
if self.search_result is None:
self.model.save_pretrained(save_dir, state_dict=self.model.state_dict())
model_name = 'awq_model_w4_g128.pt'
_save_files(save_dir, model_name, self.model.state_dict())
else:
self.model.save_pretrained(save_dir, state_dict=self.search_result)
# TODO: Rename model name & save quant_config
if self.search_result is not None:
model_name = 'awq_model_search_result.pt'
else:
model_name = 'awq_model_w4_g128.pt'
_save_files(save_dir, model_name, self.search_result)
@classmethod
def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
......@@ -190,7 +206,7 @@ class BaseAWQForCausalLM:
return self.from_quantized(
model_path,
model_type,
quant_file='',
model_filename='',
device='balanced',
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
......@@ -198,11 +214,14 @@ class BaseAWQForCausalLM:
)
@classmethod
def from_quantized(self, model_path, model_type, quant_file, w_bit=4, q_config={},
def from_quantized(self, model_path, model_type, model_filename, w_bit=4, q_config={},
device='balanced', torch_dtype=torch.float16, trust_remote_code=True, is_quantized=True):
# Download model
# Download model if path is not a directory
if not os.path.isdir(model_path):
model_path = snapshot_download(model_path)
quant_path = model_path + f'/{quant_file}' if is_quantized else model_path
# TODO: Better naming, model_filename becomes a directory
model_filename = model_path + f'/{model_filename}'
# Load config
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
......@@ -219,7 +238,7 @@ class BaseAWQForCausalLM:
model.tie_weights()
# Load model weights
model = load_checkpoint_and_dispatch(model, quant_path, device_map=device, no_split_module_classes=[self.layer_type])
model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
return self(model, model_type, is_quantized=is_quantized)
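
The save path above relies on a small trick: save_pretrained is called with an empty module's state dict so that only the config and auxiliary model files are written, the stub pytorch_model.bin is then deleted, and the real tensors are stored separately under an AWQ-specific filename. A minimal, self-contained sketch of that pattern (the helper name and argument names are illustrative, not from the repository):

    import os
    import torch
    import torch.nn as nn

    def save_awq_files(hf_model, save_dir, model_name, tensors):
        # Empty module whose state_dict() is {}, used only to skip weight serialization
        class EmptyModule(nn.Module):
            def forward(self, x): return x

        # Writes config.json and related files, but an effectively empty checkpoint
        hf_model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Drop the empty checkpoint that save_pretrained produced
        os.remove(os.path.join(save_dir, 'pytorch_model.bin'))

        # Persist the actual tensors (quantized state dict or AWQ search results)
        torch.save(tensors, os.path.join(save_dir, model_name))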