Unverified commit 6b5dc29f, authored by Casper and committed by GitHub

Load on CPU to avoid OOM (#236)

parent 5eb1d2f0
@@ -115,26 +115,6 @@ class BaseAWQForCausalLM(nn.Module):
             self, model_path, '', safetensors, trust_remote_code=trust_remote_code
         )
-        if device_map is None:
-            with init_empty_weights():
-                model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
-            # Evenly distribute memory on GPUs
-            max_memory = get_balanced_memory(
-                model,
-                no_split_module_classes=[self.layer_type],
-                dtype=torch_dtype
-            )
-            # Get device map
-            device_map = infer_auto_device_map(
-                model,
-                max_memory=max_memory,
-                no_split_module_classes=[self.layer_type],
-                dtype=torch_dtype
-            )
-            del model
         # If not quantized, must load with AutoModelForCausalLM
         model = AutoModelForCausalLM.from_pretrained(
             model_weights_path,
...
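For context, the block removed above used Hugging Face accelerate to build a balanced GPU device map for the un-quantized FP16 model and load it straight onto the GPUs, which is what could run out of GPU memory before quantization even started. The sketch below contrasts that strategy with a plain CPU load of the kind this commit falls back to. It is a minimal illustration only: the model path and dtype are placeholders, and the exact loading arguments AutoAWQ uses after this commit may differ from what is shown here.

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map
from accelerate.utils import get_balanced_memory

model_path = "some/causal-lm"  # placeholder path, not taken from the diff

# Strategy removed by this commit: create an empty (meta-device) copy of the
# model, compute a memory budget balanced across all visible GPUs, and derive
# a device_map so from_pretrained() loads weights directly onto the GPUs.
config = AutoConfig.from_pretrained(model_path)
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
max_memory = get_balanced_memory(empty_model, dtype=torch.float16)
device_map = infer_auto_device_map(empty_model, max_memory=max_memory, dtype=torch.float16)
del empty_model

# Strategy this commit moves toward: keep the full-precision weights in system
# RAM (device_map left unset, so the model stays on CPU) and only move tensors
# to the GPU as they are needed during quantization, avoiding GPU OOM.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,  # stream shards instead of materializing weights twice
)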