Unverified Commit 12581fbd authored by laoda513, committed by GitHub

Support max_memory to specify memory usage for each GPU (#460)

parent 33af761a
@@ -83,6 +83,7 @@ class AutoAWQForCausalLM:
     batch_size=1,
     safetensors=True,
     device_map="balanced",
+    max_memory=None,
     offload_folder=None,
     download_kwargs=None,
     **config_kwargs,
@@ -108,6 +109,7 @@ class AutoAWQForCausalLM:
     use_exllama_v2=use_exllama_v2,
     safetensors=safetensors,
     device_map=device_map,
+    max_memory=max_memory,
     offload_folder=offload_folder,
     download_kwargs=download_kwargs,
     **config_kwargs,
...
@@ -393,6 +393,12 @@ class BaseAWQForCausalLM(nn.Module):
"A device map that will be passed onto the model loading method from transformers." "A device map that will be passed onto the model loading method from transformers."
), ),
] = "balanced", ] = "balanced",
+    max_memory: Annotated[
+        Dict[Union[int, str], Union[int, str]],
+        Doc(
+            'A dictionary mapping device identifiers to maximum memory, passed onto the model loading method from transformers. For example: {0: "4GB", 1: "10GB"}.'
+        ),
+    ] = None,
     offload_folder: Annotated[
         str,
         Doc("The folder to offload the model to."),
@@ -449,6 +455,7 @@ class BaseAWQForCausalLM(nn.Module):
     model,
     checkpoint=model_weights_path,
     device_map=device_map,
+    max_memory=max_memory,
     no_split_module_classes=[self.layer_type],
     offload_folder=offload_folder,
     dtype=torch_dtype,
...
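For context, a minimal usage sketch of the new argument (not part of the commit; the checkpoint path and memory limits below are illustrative). With this change, `from_quantized` forwards `max_memory` to the transformers/accelerate loading machinery, so the memory budget of each device can be capped explicitly:

```python
from awq import AutoAWQForCausalLM

# Hypothetical checkpoint; substitute any AWQ-quantized model path.
quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ"

# Cap GPU 0 at 4GB and GPU 1 at 10GB. The "cpu" entry (optional, per
# accelerate's max_memory convention) bounds how much is kept in system
# RAM before weights spill to offload_folder.
model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    device_map="balanced",
    max_memory={0: "4GB", 1: "10GB", "cpu": "24GB"},
)
```

Passing `max_memory=None` (the default) preserves the previous behavior, where `device_map` alone decides placement.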