Unverified Commit 72298178 authored by Marc Sun, committed by GitHub

fix max_memory for bnb (#25842)

parent f73c2097
@@ -96,6 +96,7 @@ if is_accelerate_available():
         check_tied_parameters_on_same_device,
         find_tied_parameters,
         get_balanced_memory,
+        get_max_memory,
         load_offloaded_weights,
         offload_weight,
         save_offload_index,
@@ -3093,7 +3094,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     max_memory=max_memory,
                     **device_map_kwargs,
                 )
+            else:
+                max_memory = get_max_memory(max_memory)
+            if getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+                # need more space for buffers that are created during quantization
+                max_memory = {key: val * 0.90 for key, val in max_memory.items()}
             device_map_kwargs["max_memory"] = max_memory
         # Make sure tied weights are tied before creating the device map.
         model.tie_weights()
         device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
...
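In short: on the branch where `get_balanced_memory` is skipped, `max_memory` is now normalized with accelerate's `get_max_memory`, and for bitsandbytes-quantized models each per-device budget is scaled by 0.90 to leave headroom for the temporary buffers created during quantization. A minimal sketch of that adjustment, assuming only that `accelerate` is installed (the resolved budgets are illustrative, not part of the change):

```python
from accelerate.utils import get_max_memory

# Resolve a user-supplied (possibly None or partial) max_memory argument into
# a complete {device: bytes} mapping, e.g. {0: 8_000_000_000, "cpu": ...}.
max_memory = get_max_memory(None)

# Reserve ~10% headroom on every device for the temporary buffers that
# bitsandbytes creates while quantizing the weights, as in the diff above.
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
print(max_memory)
```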