"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "b8be1bc7b663294a121194e51aeebad40c31d60e"
Unverified Commit 29a2c5d1 authored by Charchit Sharma's avatar Charchit Sharma Committed by GitHub
Browse files

Resolves [BUG] 'GatheredParameters' object is not callable (#9614)



* gatherparams bug

* calling context lib object

* fix

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
parent 0d935df6
...@@ -24,6 +24,9 @@ from .utils import ( ...@@ -24,6 +24,9 @@ from .utils import (
if is_transformers_available(): if is_transformers_available():
import transformers import transformers
if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
import deepspeed
if is_peft_available(): if is_peft_available():
from peft import set_peft_model_state_dict from peft import set_peft_model_state_dict
...@@ -442,15 +445,13 @@ class EMAModel: ...@@ -442,15 +445,13 @@ class EMAModel:
self.cur_decay_value = decay self.cur_decay_value = decay
one_minus_decay = 1 - decay one_minus_decay = 1 - decay
context_manager = contextlib.nullcontext context_manager = contextlib.nullcontext()
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
import deepspeed
if self.foreach: if self.foreach:
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None) context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
with context_manager(): with context_manager:
params_grad = [param for param in parameters if param.requires_grad] params_grad = [param for param in parameters if param.requires_grad]
s_params_grad = [ s_params_grad = [
s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
...@@ -472,7 +473,7 @@ class EMAModel: ...@@ -472,7 +473,7 @@ class EMAModel:
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
with context_manager(): with context_manager:
if param.requires_grad: if param.requires_grad:
s_param.sub_(one_minus_decay * (s_param - param)) s_param.sub_(one_minus_decay * (s_param - param))
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment