Commit 7e810e41 authored by Vijay Korthikanti

minor fixes

parent d8c85650
@@ -128,8 +128,8 @@ class AnnealingLR(object):
         new_lr = self.get_lr()
         new_wd = self.get_wd()
         for group in self.optimizer.param_groups:
-            group['lr'] = new_lr * group['lr_mult']
-            group['weight_decay'] = new_wd * group['wd_mult']
+            group['lr'] = new_lr * group.get('lr_mult', 1.0)
+            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)

     def state_dict(self):
...
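The first hunk makes the scheduler tolerant of parameter groups that were created without explicit lr_mult / wd_mult entries: dict.get with a default of 1.0 applies an identity multiplier instead of raising a KeyError. A minimal sketch of the behaviour (the param groups below are hypothetical, not Megatron's actual optimizer setup):

# Sketch: applying the scheduler update to groups with and without multiplier keys.
new_lr, new_wd = 1.0e-4, 0.01
param_groups = [
    {'params': [], 'lr_mult': 0.1, 'wd_mult': 0.0},  # group with explicit multipliers
    {'params': []},                                  # group built without them
]
for group in param_groups:
    # .get(..., 1.0) falls back to an identity multiplier instead of a KeyError
    group['lr'] = new_lr * group.get('lr_mult', 1.0)
    group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)
# group 0 -> lr 1e-5, weight_decay 0.0; group 1 -> lr 1e-4, weight_decay 0.01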
@@ -608,7 +608,7 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
-        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.num_layers)]
+        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, args.num_layers)]

         # Transformer layers.
         def build_layer(layer_number):
...
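The second hunk builds the drop-path (stochastic depth) schedule over the full model depth, args.num_layers, instead of self.num_layers, which mpu.get_num_layers sets to the layer count of the local pipeline stage. A rough, self-contained illustration of the difference, assuming a 24-layer model split over two pipeline stages (the numbers here are made up for the example):

import torch

drop_path_rate = 0.1
total_layers = 24   # args.num_layers: depth of the full model
local_layers = 12   # self.num_layers: layers hosted on one pipeline stage

# Schedule over the full depth: rates ramp linearly from 0 to drop_path_rate
# across all 24 layers, regardless of which stage a layer lives on.
full_schedule = [x.item() for x in torch.linspace(0, drop_path_rate, total_layers)]

# Schedule over only the local layer count: the same ramp is compressed into 12
# entries, so a layer's rate no longer matches its global depth.
local_schedule = [x.item() for x in torch.linspace(0, drop_path_rate, local_layers)]

print(full_schedule[:local_layers])  # rates the first stage would draw from the full schedule
print(local_schedule)                # what a per-stage schedule would give instead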