Unverified Commit 12581fbd authored by laoda513, committed by GitHub

support max_memory to specify mem usage for each GPU (#460)

parent 33af761a
@@ -83,6 +83,7 @@ class AutoAWQForCausalLM:
batch_size=1,
safetensors=True,
device_map="balanced",
max_memory=None,
offload_folder=None,
download_kwargs=None,
**config_kwargs,
@@ -108,6 +109,7 @@ class AutoAWQForCausalLM:
use_exllama_v2=use_exllama_v2,
safetensors=safetensors,
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
download_kwargs=download_kwargs,
**config_kwargs,
......
@@ -393,6 +393,12 @@ class BaseAWQForCausalLM(nn.Module):
"A device map that will be passed onto the model loading method from transformers."
),
] = "balanced",
max_memory: Annotated[
Dict[Union[int, str], Union[int, str]],
Doc(
'A dictionary mapping device identifiers to maximum memory, which will be passed onto the model loading method from transformers. For example: {0: "4GB", 1: "10GB"}.'
),
] = None,
offload_folder: Annotated[
str,
Doc("The folder ot offload the model to."),
@@ -449,6 +455,7 @@ class BaseAWQForCausalLM(nn.Module):
model,
checkpoint=model_weights_path,
device_map=device_map,
max_memory=max_memory,
no_split_module_classes=[self.layer_type],
offload_folder=offload_folder,
dtype=torch_dtype,
......
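The new `max_memory` argument is simply threaded from `AutoAWQForCausalLM.from_quantized` into the underlying checkpoint loading call, so callers can cap per-device memory at load time. A minimal usage sketch, assuming the `awq` package from this repository; the model id and the memory limits below are illustrative, not part of the change:

```python
from awq import AutoAWQForCausalLM

# Cap GPU 0 at 4 GB and GPU 1 at 10 GB; layers that do not fit are placed on
# CPU RAM (here capped at 32 GB) or spilled to offload_folder by accelerate.
model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Mistral-7B-Instruct-v0.2-AWQ",  # illustrative model id
    device_map="balanced",
    max_memory={0: "4GB", 1: "10GB", "cpu": "32GB"},
    offload_folder="./offload",
)
```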