Unverified commit c1f209da, authored by Jeff Rasley, committed by GitHub
Browse files

[ZeRO] Fixes issue with embedding resize (#16093)



* gather z3 params for new_lm_head

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent ae2dd42b
......@@ -892,7 +892,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if is_deepspeed_zero3_enabled():
import deepspeed
-with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=0):
+params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
+with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
if torch.distributed.get_rank() == 0:
# Copy old lm head weights to new lm head
if not transposed:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment