Unverified Commit 00bf4427 authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

[FIX] `offload_weight()` takes from 3 to 4 positional arguments but 5 were given (#29457)

* use require_torch_gpu

* enable on XPU

* fix
parent 7b01579f
...@@ -796,7 +796,7 @@ def _load_state_dict_into_meta_model( ...@@ -796,7 +796,7 @@ def _load_state_dict_into_meta_model(
if not is_safetensors: if not is_safetensors:
offload_index = offload_weight(param, param_name, offload_folder, offload_index) offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None: elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, model, state_dict_folder, state_dict_index) state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif ( elif (
hf_quantizer is None hf_quantizer is None
or (not hf_quantizer.requires_parameters_quantization) or (not hf_quantizer.requires_parameters_quantization)
......
...@@ -765,7 +765,7 @@ class ModelUtilsTest(TestCasePlus): ...@@ -765,7 +765,7 @@ class ModelUtilsTest(TestCasePlus):
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
@require_torch_accelerator @require_torch_gpu
def test_from_pretrained_disk_offload_task_model(self): def test_from_pretrained_disk_offload_task_model(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2") model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
device_map = { device_map = {
...@@ -808,7 +808,7 @@ class ModelUtilsTest(TestCasePlus): ...@@ -808,7 +808,7 @@ class ModelUtilsTest(TestCasePlus):
@require_accelerate @require_accelerate
@mark.accelerate_tests @mark.accelerate_tests
@require_torch_accelerator @require_torch_gpu
def test_from_pretrained_disk_offload_derived_to_base_model(self): def test_from_pretrained_disk_offload_derived_to_base_model(self):
derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") derived_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment