Unverified Commit f220ccf1 authored by Casper's avatar Casper Committed by GitHub
Browse files

Merge pull request #77 from s4rduk4r/offloading

Offloading to cpu and disk
parents 8793a9f7 841a2313
......@@ -41,11 +41,13 @@ class AutoAWQForCausalLM:
@classmethod
def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
trust_remote_code=True, fuse_layers=True,
batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
batch_size=1, safetensors=False,
max_memory=None, offload_folder=None) -> BaseAWQForCausalLM:
os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
model_type = check_and_get_model_type(quant_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
fuse_layers=fuse_layers, safetensors=safetensors
)
\ No newline at end of file
fuse_layers=fuse_layers, safetensors=safetensors,
max_memory=max_memory, offload_folder=offload_folder
)
......@@ -133,7 +133,8 @@ class BaseAWQForCausalLM(nn.Module):
def from_quantized(self, model_path, model_type, model_filename='',
max_new_tokens=None, torch_dtype=torch.float16,
trust_remote_code=True, safetensors=False, is_quantized=True,
fuse_layers=False, version='GEMM'):
fuse_layers=False, version='GEMM',
max_memory=None, offload_folder=None):
# [STEP 1-2] Load weights path and configs
model_weights_path, config, quant_config = self._load_config(
self, model_path, model_filename, safetensors, version,
......@@ -153,6 +154,7 @@ class BaseAWQForCausalLM(nn.Module):
device_map = infer_auto_device_map(
model,
no_split_module_classes=[self.layer_type],
max_memory=max_memory,
dtype=torch_dtype
)
......@@ -160,15 +162,24 @@ class BaseAWQForCausalLM(nn.Module):
load_checkpoint_in_model(
model,
checkpoint=model_weights_path,
device_map=device_map
device_map=device_map,
offload_folder=offload_folder,
dtype=torch_dtype
)
# Dispath to devices
model = simple_dispatch_model(model, device_map)
if fuse_layers:
self.fuse_layers(model, quant_config)
# Offloading dispatch
from accelerate import dispatch_model
model = dispatch_model(
model,
device_map=device_map,
offload_dir=offload_folder
)
return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
def _load_config(self, model_path, model_filename, safetensors=False,
......
......@@ -36,7 +36,7 @@ def apply_rotary_emb(
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment