Commit 4f76ecb2 authored by s4rduk4r

Offload to cpu

parent 8eb26eb2
@@ -40,11 +40,13 @@ class AutoAWQForCausalLM:
     @classmethod
     def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        trust_remote_code=True, fuse_layers=True,
-                       batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
+                       batch_size=1, safetensors=False,
+                       max_memory=None, offload_folder=None) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
-            fuse_layers=fuse_layers, safetensors=safetensors
+            fuse_layers=fuse_layers, safetensors=safetensors,
+            max_memory=max_memory, offload_folder=offload_folder
         )
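
For context, a minimal usage sketch of the two new arguments as they surface on AutoAWQForCausalLM.from_quantized; the model path, filename and memory budgets below are placeholders, not part of this commit:

from awq import AutoAWQForCausalLM

# Cap GPU 0 at roughly 6 GiB and let the remaining weights spill to CPU RAM,
# using "offload" as the scratch directory for anything pushed out to disk.
model = AutoAWQForCausalLM.from_quantized(
    "path/to/quantized-model",               # placeholder path
    quant_filename="model.safetensors",      # placeholder filename
    safetensors=True,
    max_memory={0: "6GiB", "cpu": "24GiB"},
    offload_folder="offload",
)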
@@ -133,7 +133,8 @@ class BaseAWQForCausalLM(nn.Module):
     def from_quantized(self, model_path, model_type, model_filename='',
                        max_new_tokens=None, torch_dtype=torch.float16,
                        trust_remote_code=True, safetensors=False, is_quantized=True,
-                       fuse_layers=False, version='GEMM'):
+                       fuse_layers=False, version='GEMM',
+                       max_memory=None, offload_folder=None):
         # [STEP 1-2] Load weights path and configs
         model_weights_path, config, quant_config = self._load_config(
             self, model_path, model_filename, safetensors, version,
@@ -153,6 +154,7 @@ class BaseAWQForCausalLM(nn.Module):
         device_map = infer_auto_device_map(
             model,
             no_split_module_classes=[self.layer_type],
+            max_memory=max_memory,
             dtype=torch_dtype
         )
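
The max_memory value added above follows accelerate's infer_auto_device_map convention: a dict of per-device budgets, where integer keys are GPU indices and "cpu" is host RAM; modules that do not fit in any listed budget end up mapped to "disk" in the resulting device_map. An illustrative value (sizes are arbitrary):

max_memory = {0: "6GiB", 1: "6GiB", "cpu": "30GiB"}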
@@ -160,14 +162,31 @@ class BaseAWQForCausalLM(nn.Module):
         load_checkpoint_in_model(
             model,
             checkpoint=model_weights_path,
-            device_map=device_map
+            device_map=device_map,
+            offload_folder=offload_folder,
+            dtype=torch_dtype
         )
 
         # Dispatch to devices
-        model = simple_dispatch_model(model, device_map)
-
-        if fuse_layers:
-            self.fuse_layers(model, quant_config)
+        if max_memory is None:
+            # VRAM only
+            model = simple_dispatch_model(model, device_map)
+
+            if fuse_layers:
+                self.fuse_layers(model, quant_config)
+        else:
+            if fuse_layers:
+                self.fuse_layers(model, quant_config)
+
+            # Offloading dispatch
+            from accelerate import dispatch_model
+            model = dispatch_model(
+                model,
+                device_map=device_map,
+                # offload_buffers=offload_folder is not None,
+                offload_dir=offload_folder
+            )
 
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
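
Taken together, the offload branch follows the usual accelerate big-model flow. A simplified, self-contained sketch of that flow is below; it is not the AutoAWQ code itself, and the config path, checkpoint and budgets are placeholders:

import torch
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
from accelerate.utils import load_checkpoint_in_model
from transformers import AutoConfig, AutoModelForCausalLM

# 1) Build the model skeleton on the meta device so no weights are allocated yet.
config = AutoConfig.from_pretrained("path/to/quantized-model")  # placeholder
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# 2) Plan placement under explicit budgets; overflow goes to "cpu" and then "disk".
device_map = infer_auto_device_map(
    model,
    max_memory={0: "6GiB", "cpu": "24GiB"},
    dtype=torch.float16,
)

# 3) Materialise real weights according to the plan, writing disk-bound tensors
#    into the offload folder.
load_checkpoint_in_model(
    model,
    checkpoint="path/to/quantized-model",   # placeholder checkpoint dir or file
    device_map=device_map,
    offload_folder="offload",
    dtype=torch.float16,
)

# 4) Attach accelerate's hooks so offloaded modules are streamed back in on demand
#    during forward passes.
model = dispatch_model(model, device_map=device_map, offload_dir="offload")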