"...resnet50_tensorflow.git" did not exist on "fb996cbb3a007e091a62f9111f2b425f70a8debb"
Unverified Commit 22fc93c4 authored by cavdard, committed by GitHub

Changes in create_optimizer to support tensor parallelism with SMP (#16880)



* changes in create optimizer to support tensor parallelism with SMP

* Update src/transformers/trainer.py

Convert if check to one line.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Cavdar <dcavdar@a07817b12d7e.ant.amazon.com>
parent 99c8226b
src/transformers/trainer.py

@@ -843,16 +843,18 @@ class Trainer:
         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
         Trainer's init through `optimizers`, or subclass and override this method in a subclass.
         """
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+
         if self.optimizer is None:
-            decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
+            decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
             decay_parameters = [name for name in decay_parameters if "bias" not in name]
             optimizer_grouped_parameters = [
                 {
-                    "params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
+                    "params": [p for n, p in opt_model.named_parameters() if n in decay_parameters],
                     "weight_decay": self.args.weight_decay,
                 },
                 {
-                    "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
+                    "params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters],
                     "weight_decay": 0.0,
                 },
             ]
@@ -872,7 +874,7 @@ class Trainer:
                     manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

-                    for module in self.model.modules():
+                    for module in opt_model.modules():
                         if isinstance(module, nn.Embedding):
                             manager.register_module_override(module, "weight", {"optim_bits": 32})
                             logger.debug(f"bitsandbytes: will optimize {module} in fp32")
...
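
For context, below is a minimal standalone sketch of the pattern this diff applies: when SageMaker Model Parallelism (SMP) is enabled, the optimizer's parameter groups are built from the wrapped model (`self.model_wrapped`) rather than `self.model`, so the optimizer only sees the parameters SMP actually trains on this partition. The helper names `build_param_groups`, `create_optimizer`, and `smp_enabled` are illustrative only, not Transformers or SMP APIs.

import torch
from torch import nn


def build_param_groups(model: nn.Module, weight_decay: float):
    """Split parameters into decay / no-decay groups, mirroring the grouping in the diff."""
    # LayerNorm weights and all biases are excluded from weight decay.
    no_decay_names = set()
    for module_name, module in model.named_modules():
        for param_name, _ in module.named_parameters(recurse=False):
            full_name = f"{module_name}.{param_name}" if module_name else param_name
            if isinstance(module, nn.LayerNorm) or "bias" in param_name:
                no_decay_names.add(full_name)
    return [
        {
            "params": [p for n, p in model.named_parameters() if n not in no_decay_names],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n in no_decay_names],
            "weight_decay": 0.0,
        },
    ]


def create_optimizer(model: nn.Module, wrapped_model: nn.Module, smp_enabled: bool, lr: float = 1e-3):
    # The core of the change: pick the SMP-wrapped model when tensor parallelism is active,
    # so the parameter groups reference the partitioned parameters, not the original module's.
    opt_model = wrapped_model if smp_enabled else model
    return torch.optim.AdamW(build_param_groups(opt_model, weight_decay=0.01), lr=lr)

In the real Trainer, `is_sagemaker_mp_enabled()` plays the role of `smp_enabled` above, and `self.model_wrapped` is the SMP-wrapped module whose parameters are the ones actually updated on each rank under tensor parallelism.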