Commit 2dada8f8 authored by Casper Hansen

Improve model loading

parent 72f954ce
@@ -27,21 +27,23 @@ class AutoAWQForCausalLM:
             'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
 
     @classmethod
-    def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False) -> BaseAWQForCausalLM:
+    def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False,
+                        device_map=None, **model_init_kwargs) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(model_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
-            model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors
+            model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors,
+            device_map=device_map, **model_init_kwargs
         )
 
     @classmethod
     def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
-                       device='balanced', trust_remote_code=True, fuse_layers=True,
+                       trust_remote_code=True, fuse_layers=True,
                        batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code,
+            quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
             fuse_layers=fuse_layers, safetensors=safetensors
         )
\ No newline at end of file
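
Usage sketch (editor's note, not part of the commit): after this change, AutoAWQForCausalLM.from_pretrained accepts an optional device_map and forwards any extra **model_init_kwargs to Transformers, while from_quantized no longer takes a device argument. The import path matches the AutoAWQ package; the model paths and quantized filename below are illustrative placeholders.

    from awq import AutoAWQForCausalLM

    # Unquantized model: device_map may be left as None (it is then inferred
    # downstream); extra kwargs are forwarded to AutoModelForCausalLM.from_pretrained.
    model = AutoAWQForCausalLM.from_pretrained(
        "path/to/fp16-model",                    # placeholder
        safetensors=True,
        device_map=None,
    )

    # Quantized model: the old device='balanced' argument is gone; placement
    # now comes from infer_auto_device_map inside from_quantized.
    quant_model = AutoAWQForCausalLM.from_quantized(
        "path/to/awq-model",                     # placeholder
        quant_filename="awq_model_w4_g128.pt",   # placeholder
        fuse_layers=True,
    )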
@@ -97,24 +97,83 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
-                        trust_remote_code=True, safetensors=False):
-        return self.from_quantized(
-            model_path,
-            model_type,
-            model_filename='',
-            max_new_tokens=None,
-            device='balanced',
-            torch_dtype=torch_dtype,
-            trust_remote_code=trust_remote_code,
-            safetensors=safetensors,
-            is_quantized=False
+                        trust_remote_code=True, safetensors=False, device_map=None,
+                        **model_init_kwargs):
+        # Get weights path and quant config
+        model_weights_path, config, quant_config = self._load_config(
+            self, model_path, '', safetensors, trust_remote_code=trust_remote_code
         )
+
+        if device_map is None:
+            with init_empty_weights():
+                model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
+
+            # Get device map
+            device_map = infer_auto_device_map(
+                model,
+                no_split_module_classes=[self.layer_type],
+                dtype=torch_dtype
+            )
+            del model
+
+        # If not quantized, must load with AutoModelForCausalLM
+        model = AutoModelForCausalLM.from_pretrained(
+            model_weights_path,
+            trust_remote_code=trust_remote_code,
+            torch_dtype=torch_dtype,
+            use_safetensors=safetensors,
+            low_cpu_mem_usage=True,
+            **model_init_kwargs
+        )
+
+        model.eval()
+
+        return self(model, model_type, is_quantized=False, quant_config=quant_config)
 
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename='',
-                       max_new_tokens=None, device='balanced', torch_dtype=torch.float16,
+                       max_new_tokens=None, torch_dtype=torch.float16,
                        trust_remote_code=True, safetensors=False, is_quantized=True,
                        fuse_layers=False, version='GEMM'):
+        # [STEP 1-2] Load weights path and configs
+        model_weights_path, config, quant_config = self._load_config(
+            self, model_path, model_filename, safetensors, version,
+            trust_remote_code, max_new_tokens=max_new_tokens
+        )
+
+        # [STEP 3] Load model
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
+
+        # Prepare WQLinear layers, replace nn.Linear
+        self._load_quantized_modules(self, model, quant_config, quant_config["version"])
+
+        model.tie_weights()
+
+        # Get device map
+        device_map = infer_auto_device_map(
+            model,
+            no_split_module_classes=[self.layer_type],
+            dtype=torch_dtype
+        )
+
+        # Load checkpoint
+        load_checkpoint_in_model(
+            model,
+            checkpoint=model_weights_path,
+            device_map=device_map
+        )
+
+        # Dispatch to devices
+        model = simple_dispatch_model(model, device_map)
+
+        if fuse_layers:
+            self.fuse_layers(model, quant_config)
+
+        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
+
+    def _load_config(self, model_path, model_filename, safetensors=False,
+                     version="GEMM", trust_remote_code=True, max_new_tokens=4096):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
@@ -152,53 +211,7 @@ class BaseAWQForCausalLM(nn.Module):
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
         config.max_new_tokens = max_new_tokens
 
-        # [STEP 3] Load model
-        with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
-
-        # Only need to replace layers if a model is AWQ quantized
-        if is_quantized:
-            # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, quant_config, quant_config["version"])
-
-        model.tie_weights()
-
-        device_map = infer_auto_device_map(
-            model,
-            no_split_module_classes=[self.layer_type],
-            dtype=torch_dtype
-        )
-
-        # Load model weights
-        if is_quantized:
-            load_checkpoint_in_model(
-                model,
-                checkpoint=model_weights_path,
-                device_map=device_map
-            )
-
-            model = simple_dispatch_model(model, device_map)
-
-            if fuse_layers:
-                self.fuse_layers(model, quant_config)
-        else:
-            # If not quantized, must load with AutoModelForCausalLM
-            del model
-
-            # Load model weights
-            model = AutoModelForCausalLM.from_pretrained(
-                model_weights_path,
-                device_map=device_map,
-                trust_remote_code=trust_remote_code,
-                offload_folder="offload",
-                offload_state_dict=True,
-                torch_dtype=torch_dtype,
-                use_safetensors=safetensors
-            )
-
-        model.eval()
-
-        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
+        return model_weights_path, config, quant_config
 
     def _load_quantized_modules(self, model, quant_config, version):
         # Real quantization of weights
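
Editor's note (not part of the commit): the rewritten from_quantized follows the Accelerate meta-device pattern: build the model skeleton under init_empty_weights, replace nn.Linear with WQLinear, infer a device_map, stream the checkpoint onto those devices with load_checkpoint_in_model, then dispatch. Below is a generic, self-contained sketch of that pattern using a plain Transformers model; the default no_split_class value and the use of accelerate.dispatch_model (rather than AutoAWQ's internal simple_dispatch_model and quantized-module replacement) are assumptions for illustration only.

    import torch
    from accelerate import (
        init_empty_weights,
        infer_auto_device_map,
        load_checkpoint_in_model,
        dispatch_model,
    )
    from transformers import AutoConfig, AutoModelForCausalLM

    def load_with_device_map(checkpoint_dir: str, no_split_class: str = "LlamaDecoderLayer"):
        # Build the skeleton on the meta device: no weight memory is allocated,
        # so even very large models are cheap to instantiate.
        config = AutoConfig.from_pretrained(checkpoint_dir)
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
        model.tie_weights()

        # Decide where each module lives (GPUs first, then CPU), never splitting
        # a single decoder layer across devices.
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[no_split_class],
            dtype=torch.float16,
        )

        # Materialize real weights from the checkpoint directly onto the chosen
        # devices, then attach the hooks that route activations between them.
        load_checkpoint_in_model(model, checkpoint=checkpoint_dir, device_map=device_map)
        model = dispatch_model(model, device_map=device_map)
        return model.eval()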