Commit f8273a0c authored by Casper Hansen

Multi-GPU support for quantized models

parent 7cf0d987
@@ -297,21 +297,29 @@ class BaseAWQForCausalLM(nn.Module):
         model.tie_weights()
 
+        device_map = infer_auto_device_map(
+            model,
+            no_split_module_classes=[self.layer_type],
+            dtype=torch_dtype
+        )
+
         # Load model weights
         if is_quantized:
-            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+            model = load_checkpoint_and_dispatch(
+                model,
+                model_filename,
+                device_map=device_map,
+                no_split_module_classes=[self.layer_type]
+            )
 
             if fuse_layers:
                 self.fuse_layers(model)
+
+            from awq.utils.utils import simple_dispatch_model
+            model = simple_dispatch_model(model, device_map)
         else:
             # If not quantized, must load with AutoModelForCausalLM
-            device_map = infer_auto_device_map(
-                model,
-                no_split_module_classes=[self.layer_type],
-                dtype=torch_dtype
-            )
             del model
 
             # Load model weights
...
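For context, here is a minimal sketch (not the repository's code) of the loading pattern the commit adopts: build the model on the meta device, infer a device map that keeps each decoder layer whole, then stream the quantized checkpoint onto the mapped GPUs. It assumes accelerate's init_empty_weights / infer_auto_device_map / load_checkpoint_and_dispatch; the helper name load_quantized_across_gpus and the layer_type default are illustrative only.

# Minimal sketch of the multi-GPU loading pattern (illustrative, not the repo's code).
import torch
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

def load_quantized_across_gpus(model_path, checkpoint_file, layer_type="LlamaDecoderLayer"):
    # Build the model on the meta device so no weight memory is allocated yet.
    config = AutoConfig.from_pretrained(model_path)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
    model.tie_weights()

    # Ask accelerate to split the model across the visible GPUs,
    # keeping each decoder layer (layer_type) on a single device.
    device_map = infer_auto_device_map(
        model,
        no_split_module_classes=[layer_type],
        dtype=torch.float16,
    )

    # Stream the quantized checkpoint from disk directly onto the mapped devices.
    model = load_checkpoint_and_dispatch(
        model,
        checkpoint_file,
        device_map=device_map,
        no_split_module_classes=[layer_type],
    )
    return model

In the commit itself, the quantized path additionally calls awq.utils.utils.simple_dispatch_model(model, device_map) after layer fusion, presumably so that any modules replaced during fusion end up on the devices chosen by device_map.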