Commit 4f76ecb2 authored by s4rduk4r

Offload to cpu

parent 8eb26eb2
@@ -40,11 +40,13 @@ class AutoAWQForCausalLM:
     @classmethod
     def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        trust_remote_code=True, fuse_layers=True,
-                       batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
+                       batch_size=1, safetensors=False,
+                       max_memory=None, offload_folder=None) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
-            fuse_layers=fuse_layers, safetensors=safetensors
-        )
\ No newline at end of file
+            fuse_layers=fuse_layers, safetensors=safetensors,
+            max_memory=max_memory, offload_folder=offload_folder
+        )
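For reference, a minimal usage sketch of the two new parameters (the checkpoint path and memory budgets are illustrative; max_memory follows Accelerate's {device: budget} convention, and anything exceeding the budgets spills to offload_folder):

from awq import AutoAWQForCausalLM

# Illustrative budgets: cap GPU 0 at 8 GiB and CPU RAM at 24 GiB; modules
# that fit neither budget are offloaded to disk under offload_folder.
model = AutoAWQForCausalLM.from_quantized(
    "path/to/awq-quantized-model",   # hypothetical quantized checkpoint
    max_memory={0: "8GiB", "cpu": "24GiB"},
    offload_folder="./awq_offload",
)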
@@ -133,7 +133,8 @@ class BaseAWQForCausalLM(nn.Module):
     def from_quantized(self, model_path, model_type, model_filename='',
                        max_new_tokens=None, torch_dtype=torch.float16,
                        trust_remote_code=True, safetensors=False, is_quantized=True,
-                       fuse_layers=False, version='GEMM'):
+                       fuse_layers=False, version='GEMM',
+                       max_memory=None, offload_folder=None):
         # [STEP 1-2] Load weights path and configs
         model_weights_path, config, quant_config = self._load_config(
             self, model_path, model_filename, safetensors, version,
@@ -153,21 +154,39 @@
         device_map = infer_auto_device_map(
             model,
             no_split_module_classes=[self.layer_type],
+            max_memory=max_memory,
             dtype=torch_dtype
         )
         # Load checkpoint
         load_checkpoint_in_model(
             model,
             checkpoint=model_weights_path,
-            device_map=device_map
+            device_map=device_map,
+            offload_folder=offload_folder,
+            dtype=torch_dtype
         )
-        # Dispath to devices
-        model = simple_dispatch_model(model, device_map)
-        if fuse_layers:
-            self.fuse_layers(model, quant_config)
+        if max_memory is None:
+            # VRAM only
+            model = simple_dispatch_model(model, device_map)
+            if fuse_layers:
+                self.fuse_layers(model, quant_config)
+        else:
+            # Offloading dispatch
+            from accelerate import dispatch_model
+            model = dispatch_model(
+                model,
+                device_map=device_map,
+                # offload_buffers=offload_folder is not None,
+                offload_dir=offload_folder
+            )
+            if fuse_layers:
+                self.fuse_layers(model, quant_config)
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
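For context, a standalone sketch of how the Accelerate pieces used in this branch fit together (the toy module and memory cap are illustrative; exact device assignments depend on parameter sizes):

import torch
import torch.nn as nn
from accelerate import init_empty_weights, infer_auto_device_map

# Toy stand-in for the causal LM, built on the meta device so inferring
# the device map allocates no real memory.
with init_empty_weights():
    toy = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(16)])

# A tight budget forces infer_auto_device_map to spill modules to "cpu"
# and, once that budget is exhausted, to "disk".
device_map = infer_auto_device_map(
    toy,
    max_memory={"cpu": "16MiB"},  # illustrative; real use adds GPU keys, e.g. {0: "8GiB", "cpu": "24GiB"}
    dtype=torch.float16,
)
print(device_map)  # module names mapped to "cpu" / "disk" (or GPU indices)

In the diff above, load_checkpoint_in_model then materializes weights according to this map (writing any "disk" entries under offload_folder), and dispatch_model(..., offload_dir=offload_folder) installs forward hooks that page offloaded weights back in during inference.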