Unverified Commit ac2707e8 authored by Matthew Douglas, committed by GitHub

Revert "fixes to properly shard FSDP across cpu and meta for...

Revert "fixes to properly shard FSDP across cpu and meta for cpu_effcient_loading for prequantized 4bit (#32276)" (#32477)

* Revert "fixes to properly shard FSDP across cpu and meta for cpu_efficient_loading for prequantized 4bit (#32276)"

This reverts commit 62c60a30.

We uncovered an issue with this change that caused our training runs to hang.

* `is_torchdynamo_compiling` -- cast a wide exception net (#32476); a sketch of the pattern follows the commit metadata below

* cast a wide net

* make fix-copies with a few manual changes

* add copied from

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent 4fdc7020
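
The `is_torchdynamo_compiling` bullet above refers to hardening the compiler probe so that any failure is treated as "not compiling" rather than raised. A minimal sketch of that pattern, assuming the probe falls back through `torch.compiler.is_compiling()` and `torch._dynamo.is_compiling()`; this is an illustration, not the exact transformers implementation:

```python
def is_torchdynamo_compiling() -> bool:
    # Wide exception net: any failure of the probe (missing torch, renamed private
    # modules, unexpected error types) is treated as "not compiling" instead of raising.
    try:
        import torch

        return bool(torch.compiler.is_compiling())
    except Exception:
        try:
            import torch._dynamo as dynamo

            return bool(dynamo.is_compiling())
        except Exception:
            return False
```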
@@ -933,8 +933,6 @@ def _load_state_dict_into_meta_model(
                 )
             )
         ):
-            if is_fsdp_enabled():
-                param_device = "cpu" if is_local_dist_rank_0() else "meta"
             # For backward compatibility with older versions of `accelerate` and for non-quantized params
             set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
         else:
@@ -945,10 +943,7 @@ def _load_state_dict_into_meta_model(
             if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
                 module, tensor_name = get_module_from_name(model, param_name)
                 value = getattr(module, tensor_name)
-                param_to = "cpu"
-                if is_fsdp_enabled() and not is_local_dist_rank_0():
-                    param_to = "meta"
-                value = type(value)(value.data.to(param_to), **value.__dict__)
+                value = type(value)(value.data.to("cpu"), **value.__dict__)
                 setattr(module, tensor_name, value)

     # TODO: consider removing used param_parts from state_dict before return
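For context, the hunks above remove the rank-aware placement introduced in #32276 and restore unconditional CPU materialization. A minimal sketch of the two strategies, assuming only the `is_fsdp_enabled` and `is_local_dist_rank_0` helpers visible in the diff; the function names below are illustrative, not part of the library:

```python
def reverted_param_device(is_fsdp_enabled, is_local_dist_rank_0, param_device="cpu"):
    # Behavior removed by this commit: only local rank 0 materializes weights on CPU;
    # every other rank keeps empty "meta" tensors and relies on FSDP to re-shard them.
    if is_fsdp_enabled():
        return "cpu" if is_local_dist_rank_0() else "meta"
    return param_device


def restored_param_device(param_device="cpu"):
    # Behavior after the revert: the FSDP-specific branch is gone, so every rank
    # uses the previously computed device (typically "cpu").
    return param_device


def rewrap_param(value, device):
    # Mirrors the type(value)(value.data.to(...), **value.__dict__) pattern in the hunk:
    # rebuild the same tensor subclass around data moved to the chosen device.
    return type(value)(value.data.to(device), **value.__dict__)
```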
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-import inspect
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from packaging import version
@@ -208,16 +207,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
             if unexpected_keys is not None and k in unexpected_keys:
                 unexpected_keys.remove(k)

-            param_kwargs = {}
-            sig = inspect.signature(bnb.nn.Params4bit.from_prequantized)
-            if "module" in sig.parameters:
-                param_kwargs["module"] = module
             new_value = bnb.nn.Params4bit.from_prequantized(
                 data=param_value,
                 quantized_stats=quantized_stats,
                 requires_grad=False,
                 device=target_device,
-                **param_kwargs,
             )
         else:
             new_value = param_value.to("cpu")
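The lines removed here used `inspect.signature` to forward a `module` keyword to `bnb.nn.Params4bit.from_prequantized` only when the installed bitsandbytes version declares it. A generic sketch of that feature-detection pattern; the helper name is illustrative and not part of either library:

```python
import inspect
from typing import Any, Callable, Dict


def supported_kwargs(fn: Callable, **candidates: Any) -> Dict[str, Any]:
    # Keep only the keyword arguments that fn's signature actually declares, so
    # newer optional parameters can be forwarded without breaking older versions.
    params = inspect.signature(fn).parameters
    return {name: value for name, value in candidates.items() if name in params}


# Usage mirroring the removed quantizer code (assumes bnb, param_value, quantized_stats,
# target_device, and module are in scope as in the original):
#   extra = supported_kwargs(bnb.nn.Params4bit.from_prequantized, module=module)
#   new_value = bnb.nn.Params4bit.from_prequantized(
#       data=param_value,
#       quantized_stats=quantized_stats,
#       requires_grad=False,
#       device=target_device,
#       **extra,
#   )
```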