Unverified commit 62c60a30, authored by Wing Lian and committed by GitHub
Browse files

fixes to properly shard FSDP across cpu and meta for cpu_efficient_loading for...

fixes to properly shard FSDP across cpu and meta for cpu_efficient_loading for prequantized 4bit (#32276)
parent 16271080
...@@ -932,6 +932,8 @@ def _load_state_dict_into_meta_model( ...@@ -932,6 +932,8 @@ def _load_state_dict_into_meta_model(
) )
) )
): ):
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
# For backward compatibility with older versions of `accelerate` and for non-quantized params # For backward compatibility with older versions of `accelerate` and for non-quantized params
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else: else:
...@@ -942,7 +944,10 @@ def _load_state_dict_into_meta_model( ...@@ -942,7 +944,10 @@ def _load_state_dict_into_meta_model(
if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, tensor_name = get_module_from_name(model, param_name) module, tensor_name = get_module_from_name(model, param_name)
value = getattr(module, tensor_name) value = getattr(module, tensor_name)
value = type(value)(value.data.to("cpu"), **value.__dict__) param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
value = type(value)(value.data.to(param_to), **value.__dict__)
setattr(module, tensor_name, value) setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return # TODO: consider removing used param_parts from state_dict before return
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib import importlib
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from packaging import version from packaging import version
...@@ -199,11 +200,16 @@ class Bnb4BitHfQuantizer(HfQuantizer): ...@@ -199,11 +200,16 @@ class Bnb4BitHfQuantizer(HfQuantizer):
if unexpected_keys is not None and k in unexpected_keys: if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k) unexpected_keys.remove(k)
param_kwargs = {}
sig = inspect.signature(bnb.nn.Params4bit.from_prequantized)
if "module" in sig.parameters:
param_kwargs["module"] = module
new_value = bnb.nn.Params4bit.from_prequantized( new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value, data=param_value,
quantized_stats=quantized_stats, quantized_stats=quantized_stats,
requires_grad=False, requires_grad=False,
device=target_device, device=target_device,
**param_kwargs,
) )
else: else:
new_value = param_value.to("cpu") new_value = param_value.to("cpu")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment