Commit d430694b authored by Casper Hansen

save_quantized working in all cases, from_quantized adapted.

parent af4e0622
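
For orientation, a minimal usage sketch of the revised save/load API touched by this commit. It is an illustration only, not part of the diff: it assumes AutoAWQForCausalLM is importable from the awq package and uses the default weight file name written by save_quantized; the paths are examples.

# Hypothetical sketch; import path and file names are assumptions,
# the call signatures follow the hunks below.
from transformers import AutoTokenizer
from awq import AutoAWQForCausalLM

quant_path = 'mpt-7b-8k-chat-awq'      # directory produced by save_quantized (config, tokenizer, weights)
quant_file = 'awq_model_w4_g128.pt'    # default quantized weight file name

# from_quantized now takes the model directory plus the weight file inside it;
# an empty q_config falls back to the class default (default_q_config).
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, w_bit=4, q_config={})
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)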
@@ -29,6 +29,9 @@ def run_search(model_path, dump_path, w_bit, q_config):
     # Save search results
     model.save_quantized(dump_path)
 
+    # Save tokenizer
+    tokenizer.save_pretrained(dump_path)
+
 def run_quant(model_path, search_path, dump_path, w_bit, q_config):
     """
     Step 2/2: Use the search results to quantize model weights
@@ -43,16 +46,16 @@ def run_quant(model_path, search_path, dump_path, w_bit, q_config):
     # Save quantized model
     model.save_quantized(dump_path)
 
-def run_perplexity(model_path, quant_path, w_bit, q_config, device):
+def run_perplexity(quant_path, quant_file, w_bit, q_config, device):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
-    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, w_bit, q_config)
+    tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 
     # Load adapter
-    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
+    lm_eval_model = LMEvalAdaptor(quant_path, model, tokenizer, device, batch_size=1)
 
     # Evaluate perplexity of quantized model
     results = evaluator.simple_evaluate(
@@ -68,15 +71,16 @@ def run_perplexity(model_path, quant_path, w_bit, q_config, device):
 if __name__ == '__main__':
     """
     python -m awq.entry --entry_type search --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq
-    python -m awq.entry --entry_type quant --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq/pytorch_model.bin --quant_path mpt-7b-8k-chat-awq
-    python -m awq.entry --entry_type perplexity --model_path mosaicml/mpt-7b-8k-chat --quant_path mpt-7b-8k-chat-awq
+    python -m awq.entry --entry_type quant --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq/awq_model_search_result.pt --quant_path mpt-7b-8k-chat-awq
+    python -m awq.entry --entry_type perplexity --quant_path mpt-7b-8k-chat-awq --quant_file awq_model_w4_g128.pt
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|perplexity)')
     parser.add_argument('--model_path', type=str, help='Path to hf model')
     parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
-    parser.add_argument('--quant_path', type=str, help='Path to save/load AWQ quant model')
-    parser.add_argument('--device', type=str, default='balanced', help='Device to load model to')
+    parser.add_argument('--quant_path', type=str, help='Path to AWQ model directory')
+    parser.add_argument('--quant_file', type=str, help='Path to quantized AWQ model file')
+    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
     parser.add_argument('--w_bit', type=int, default=4)
     parser.add_argument('--q_group_size', type=int, default=128)
@@ -88,6 +92,6 @@ if __name__ == '__main__':
     elif args.entry_type == 'quant':
         run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
     elif args.entry_type == 'perplexity':
-        run_perplexity(args.model_path, args.quant_path, args.w_bit, q_config, args.device)
+        run_perplexity(args.quant_path, args.quant_file, args.w_bit, q_config, args.device)
     else:
         raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
@@ -29,11 +29,11 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, model_path, quant_file, w_bit=4, q_config={},
+    def from_quantized(self, quant_path, quant_filename, w_bit=4, q_config={},
                              device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
-        model_type = check_and_get_model_type(model_path, trust_remote_code)
+        model_type = check_and_get_model_type(quant_path, trust_remote_code)
         q_config = q_config if q_config else self.default_q_config
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            model_path, model_type, quant_file, w_bit, q_config, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, w_bit, q_config, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
@@ -23,6 +23,12 @@ class BaseAWQForCausalLM:
         self.is_quantized:bool = is_quantized
         self.search_result = None
 
+    def to(self, device: str):
+        return self.model.to(device)
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
     @torch.no_grad()
     def quantize(self, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                        auto_scale=True, mse_range=True, run_search=False, run_quant=True,
@@ -170,19 +176,29 @@ class BaseAWQForCausalLM:
         return awq_results
 
     def save_quantized(self, save_dir):
+        def _save_files(save_dir, model_name, model):
+            class EmptyModule(nn.Module):
+                def __init__(self): super(EmptyModule, self).__init__()
+                def forward(self, x): return x
+
+            # Save model files without search results
+            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
+
+            # Remove empty module
+            os.remove(f'{save_dir}/pytorch_model.bin')
+
+            # Save search results
+            torch.save(model, f'{save_dir}/{model_name}')
+
         save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
 
         # Save model
         if self.search_result is None:
-            self.model.save_pretrained(save_dir, state_dict=self.model.state_dict())
+            model_name = 'awq_model_w4_g128.pt'
+            _save_files(save_dir, model_name, self.model.state_dict())
         else:
-            self.model.save_pretrained(save_dir, state_dict=self.search_result)
-
-        # TODO: Rename model name & save quant_config
-        if self.search_result is not None:
             model_name = 'awq_model_search_result.pt'
-        else:
-            model_name = 'awq_model_w4_g128.pt'
+            _save_files(save_dir, model_name, self.search_result)
 
     @classmethod
     def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
@@ -190,7 +206,7 @@ class BaseAWQForCausalLM:
         return self.from_quantized(
             model_path,
             model_type,
-            quant_file='',
+            model_filename='',
             device='balanced',
             torch_dtype=torch_dtype,
             trust_remote_code=trust_remote_code,
@@ -198,11 +214,14 @@ class BaseAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, quant_file, w_bit=4, q_config={},
+    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, q_config={},
                              device='balanced', torch_dtype=torch.float16, trust_remote_code=True, is_quantized=True):
-        # Download model
-        model_path = snapshot_download(model_path)
-        quant_path = model_path + f'/{quant_file}' if is_quantized else model_path
+        # Download model if path is not a directory
+        if not os.path.isdir(model_path):
+            model_path = snapshot_download(model_path)
+
+        # TODO: Better naming, model_filename becomes a directory
+        model_filename = model_path + f'/{model_filename}'
 
         # Load config
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
@@ -219,7 +238,7 @@ class BaseAWQForCausalLM:
             model.tie_weights()
 
         # Load model weights
-        model = load_checkpoint_and_dispatch(model, quant_path, device_map=device, no_split_module_classes=[self.layer_type])
+        model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
 
         return self(model, model_type, is_quantized=is_quantized)
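
The reworked save_quantized rests on a small trick: save_pretrained is called with an empty state dict so the config files are still written, the placeholder pytorch_model.bin is then deleted, and the real tensors (quantized weights or search results) are written under a custom name with torch.save. A standalone sketch of the same idea, assuming a Hugging Face PreTrainedModel; the function and variable names here are illustrative, not from the commit:

import os

import torch
import torch.nn as nn

def save_awq_checkpoint(hf_model, save_dir, weight_name, state_dict):
    # A parameter-free module yields an empty state dict
    class EmptyModule(nn.Module):
        def forward(self, x):
            return x

    # Writes config.json and related metadata, plus a placeholder
    # pytorch_model.bin that contains no tensors
    hf_model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

    # Drop the placeholder weight file...
    os.remove(os.path.join(save_dir, 'pytorch_model.bin'))

    # ...and store the real tensors (quantized weights or search results)
    # under the custom file name
    torch.save(state_dict, os.path.join(save_dir, weight_name))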