Commit 7e810e41 authored by Vijay Korthikanti

minor fixes

parent d8c85650
@@ -128,8 +128,8 @@ class AnnealingLR(object):
         new_lr = self.get_lr()
         new_wd = self.get_wd()
         for group in self.optimizer.param_groups:
-            group['lr'] = new_lr * group['lr_mult']
-            group['weight_decay'] = new_wd * group['wd_mult']
+            group['lr'] = new_lr * group.get('lr_mult', 1.0)
+            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)

     def state_dict(self):
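The first fix swaps direct key lookups for dict.get() with a default, so param groups that never set 'lr_mult' or 'wd_mult' presumably fall back to a multiplier of 1.0 instead of raising KeyError. A minimal sketch of that behavior, using a toy param-group list and a hypothetical demo_step helper (not Megatron code):

    # demo_step is a hypothetical stand-in for AnnealingLR.step,
    # reduced to the lines touched by this commit.
    def demo_step(param_groups, new_lr, new_wd):
        for group in param_groups:
            # .get() falls back to 1.0 for groups without multiplier
            # keys, where group['lr_mult'] would raise KeyError.
            group['lr'] = new_lr * group.get('lr_mult', 1.0)
            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)

    groups = [{'lr_mult': 0.1, 'wd_mult': 0.0},  # group with explicit multipliers
              {}]                                # group without multiplier keys
    demo_step(groups, new_lr=1e-4, new_wd=0.01)
    print(groups[1])  # {'lr': 0.0001, 'weight_decay': 0.01}: defaults applied

With the old indexing, the second group would have crashed the scheduler on its first step.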
@@ -608,7 +608,7 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
-        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.num_layers)]
+        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, args.num_layers)]

         # Transformer layers.
         def build_layer(layer_number):
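The second fix sizes the drop-path (stochastic depth) schedule by args.num_layers, the full model depth, rather than self.num_layers, which mpu.get_num_layers() can reduce to a per-pipeline-stage count; presumably each layer then indexes the full-depth list by its global layer number. A hedged illustration with toy numbers (24 total layers over 4 stages is an assumption, not from the diff):

    import torch

    # Toy numbers (assumed): 24 total layers split across 4 pipeline
    # stages, so each stage builds only 6 of them.
    drop_path_rate = 0.2
    full_depth = 24   # args.num_layers
    per_stage = 6     # what self.num_layers becomes on each stage

    old = [x.item() for x in torch.linspace(0, drop_path_rate, per_stage)]
    new = [x.item() for x in torch.linspace(0, drop_path_rate, full_depth)]

    print(old)      # ramps to 0.2 within 6 layers, repeated on every stage
    print(new[:6])  # first stage's slice of a single ramp over the whole model

With the per-stage count, every stage would restart the rate ramp from 0, and later stages could index past the end of the list.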