"torchvision/csrc/cpu/decoder/cc_stream.cpp" did not exist on "f2600c2e6ac0d3f0dee2345c821bade90b2d9328"
Commit f8273a0c authored by Casper Hansen's avatar Casper Hansen
Browse files

Multi-GPU support for quantized models

parent 7cf0d987
......@@ -297,21 +297,29 @@ class BaseAWQForCausalLM(nn.Module):
model.tie_weights()
device_map = infer_auto_device_map(
model,
no_split_module_classes=[self.layer_type],
dtype=torch_dtype
)
# Load model weights
if is_quantized:
model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
model = load_checkpoint_and_dispatch(
model,
model_filename,
device_map=device_map,
no_split_module_classes=[self.layer_type]
)
if fuse_layers:
self.fuse_layers(model)
from awq.utils.utils import simple_dispatch_model
model = simple_dispatch_model(model, device_map)
else:
# If not quantized, must load with AutoModelForCausalLM
device_map = infer_auto_device_map(
model,
no_split_module_classes=[self.layer_type],
dtype=torch_dtype
)
del model
# Load model weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment