Commit 53931b8b authored by Vijay Korthikanti

address review comments

parent 8acbbe25
@@ -246,14 +246,14 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.fp16 or args.bf16, \
             'residual connection in fp32 only supported when using fp16 or bf16.'
 
-    if args.wd_incr_style == 'constant':
-        assert args.start_wd is None
-        assert args.end_wd is None
-        args.start_wd = args.weight_decay
-        args.end_wd = args.weight_decay
+    if args.weight_decay_incr_style == 'constant':
+        assert args.start_weight_decay is None
+        assert args.end_weight_decay is None
+        args.start_weight_decay = args.weight_decay
+        args.end_weight_decay = args.weight_decay
     else:
-        assert args.start_wd is not None
-        assert args.end_wd is not None
+        assert args.start_weight_decay is not None
+        assert args.end_weight_decay is not None
 
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
@@ -404,11 +404,11 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
-    group.add_argument('--start-wd', type=float,
+    group.add_argument('--start-weight-decay', type=float,
                        help='Initial weight decay coefficient for L2 regularization.')
-    group.add_argument('--end-wd', type=float,
+    group.add_argument('--end-weight-decay', type=float,
                        help='End of run weight decay coefficient for L2 regularization.')
-    group.add_argument('--wd-incr-style', type=str, default='constant',
+    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                        choices=['constant', 'linear', 'cosine'],
                        help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
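Note: the constant-style branch above simply copies --weight-decay into the renamed start/end fields; the other styles require both values explicitly. Below is a trimmed-down, hypothetical sketch of how the four flags from these hunks and that defaulting logic fit together (the real parser in parse_args()/_add_regularization_args() has many more options; this is not the repository's code):

# Hypothetical, trimmed-down reproduction of the renamed weight-decay flags
# and the defaulting logic from the hunks above; not the full Megatron parser.
import argparse

def build_parser():
    parser = argparse.ArgumentParser(description='weight-decay schedule flags (sketch)')
    parser.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
    parser.add_argument('--start-weight-decay', type=float,
                        help='Initial weight decay coefficient for L2 regularization.')
    parser.add_argument('--end-weight-decay', type=float,
                        help='End of run weight decay coefficient for L2 regularization.')
    parser.add_argument('--weight-decay-incr-style', type=str, default='constant',
                        choices=['constant', 'linear', 'cosine'],
                        help='Weight decay increment function.')
    return parser

def validate(args):
    # Mirrors the parse_args() change: constant style takes its value from
    # --weight-decay; the other styles require explicit start/end values.
    if args.weight_decay_incr_style == 'constant':
        assert args.start_weight_decay is None
        assert args.end_weight_decay is None
        args.start_weight_decay = args.weight_decay
        args.end_weight_decay = args.weight_decay
    else:
        assert args.start_weight_decay is not None
        assert args.end_weight_decay is not None
    return args

# Example: a linearly increasing weight decay from 0.01 to 0.1.
args = validate(build_parser().parse_args(
    ['--weight-decay-incr-style', 'linear',
     '--start-weight-decay', '0.01', '--end-weight-decay', '0.1']))
print(args.start_weight_decay, args.end_weight_decay)  # 0.01 0.1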
@@ -49,20 +49,20 @@ class DropPath(MegatronModule):
     (when applied in main path of residual blocks).
     """
 
-    def __init__(self, drop_prob=None):
+    def __init__(self, drop_prob=0.):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob
 
-    def forward(self, x):
+    def forward(self, hidden_state):
         if self.drop_prob == 0. or not self.training:
-            return x
+            return hidden_state
         keep_prob = 1 - self.drop_prob
         # work with diff dim tensors, not just 2D ConvNets
-        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
         random_tensor = keep_prob + \
-            torch.rand(shape, dtype=x.dtype, device=x.device)
+            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
         random_tensor.floor_()  # binarize
-        output = x.div(keep_prob) * random_tensor
+        output = hidden_state.div(keep_prob) * random_tensor
         return output
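Note: besides the drop_prob default changing from None to 0., this hunk only renames x to hidden_state. For readers unfamiliar with stochastic depth, here is a self-contained sketch of the same behaviour using a plain torch.nn.Module in place of MegatronModule (class name, shapes, and the usage below are illustrative assumptions): whole samples of a residual branch are zeroed during training and the survivors rescaled by 1/keep_prob.

# Standalone sketch of the DropPath behaviour shown in the hunk above,
# using torch.nn.Module instead of MegatronModule.
import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Drop entire samples of a residual branch with probability drop_prob."""

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per entry of the first dimension, broadcast over the rest.
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize: 1 with prob keep_prob, else 0
        return hidden_state.div(keep_prob) * random_tensor

# Typical use: applied to the branch of a residual connection during training.
drop_path = DropPathSketch(drop_prob=0.2).train()
x = torch.randn(4, 16, 8)      # example shape; the actual Megatron layout may differ
branch = torch.ones_like(x)    # stand-in for an attention/MLP output
out = x + drop_path(branch)    # some samples keep the branch, others drop it entirely
print(out.shape)               # torch.Size([4, 16, 8])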
@@ -437,7 +437,6 @@ class ParallelTransformerLayer(MegatronModule):
         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
         self.layer_type = layer_type
-        self.drop_path_rate = drop_path_rate
 
         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm
@@ -460,7 +459,7 @@ class ParallelTransformerLayer(MegatronModule):
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion
-        self.drop_path = DropPath(drop_path_rate)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
 
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
@@ -504,7 +503,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = hidden_states
 
-        if self.drop_path_rate == 0.0:
+        if self.drop_path is None:
             # jit scripting for a nn.module (with dropout) is not
             # triggering the fusion kernel. For now, we use two
             # different nn.functional routines to account for varying
@@ -564,7 +563,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input
 
-        if self.drop_path_rate == 0.0:
+        if self.drop_path is None:
             # re-enable torch grad to enable fused optimization.
             with torch.enable_grad():
                 output = bias_dropout_add_func(
@@ -608,7 +607,7 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
 
-        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, args.num_layers)]
+        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
 
         # Transformer layers.
         def build_layer(layer_number):
@@ -618,7 +617,7 @@ class ParallelTransformer(MegatronModule):
                 layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type,
-                drop_path_rate=self.dpr[layer_number - 1])
+                drop_path_rate=self.drop_path_rates[layer_number - 1])
 
         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
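Note: the rename from self.dpr to self.drop_path_rates does not change behaviour; torch.linspace still assigns each layer a linearly increasing rate from 0 (first layer) to drop_path_rate (last layer), and build_layer() indexes it with layer_number - 1. A quick illustration with example values:

# Illustration of the per-layer schedule built by the hunk above:
# torch.linspace spreads the rates evenly from 0 to drop_path_rate.
import torch

num_layers = 6         # example value
drop_path_rate = 0.3   # example value

drop_path_rates = [rate.item() for rate in
                   torch.linspace(0, drop_path_rate, num_layers)]
print([round(rate, 2) for rate in drop_path_rates])
# [0.0, 0.06, 0.12, 0.18, 0.24, 0.3]

# build_layer(layer_number) indexes with layer_number - 1, so layer 1 gets
# rate 0.0 and, per the earlier hunk, its self.drop_path is None (disabled).
layer_number = 1
print(drop_path_rates[layer_number - 1])  # 0.0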
@@ -341,9 +341,9 @@ def get_learning_rate_scheduler(optimizer):
         warmup_steps=warmup_steps,
         decay_steps=decay_steps,
         decay_style=args.lr_decay_style,
-        start_wd=args.start_wd,
-        end_wd=args.end_wd,
-        wd_incr_style=args.wd_incr_style,
+        start_wd=args.start_weight_decay,
+        end_wd=args.end_weight_decay,
+        wd_incr_style=args.weight_decay_incr_style,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)
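Note: only the argument names on the right-hand side change here; the scheduler keyword names (start_wd, end_wd, wd_incr_style) stay the same, and the scheduler class itself is outside this diff. Purely as an illustration of what the 'constant', 'linear', and 'cosine' increment styles could mean, here is a hypothetical interpolation between the start and end weight decay over decay_steps (the function name and formulas are assumptions, not the repository's implementation):

# Hypothetical sketch of a weight-decay increment schedule; the actual
# scheduler class that consumes start_wd / end_wd / wd_incr_style is not
# part of this diff, so the formulas below are illustrative only.
import math

def weight_decay_at(step, decay_steps, start_wd, end_wd, wd_incr_style):
    if wd_incr_style == 'constant':
        # parse_args() forces start_wd == end_wd in this case.
        return start_wd
    frac = min(max(step / decay_steps, 0.0), 1.0)
    if wd_incr_style == 'linear':
        coeff = frac
    elif wd_incr_style == 'cosine':
        # Smoothly ramp the coefficient from 0 to 1 over the decay window.
        coeff = 0.5 * (1.0 - math.cos(math.pi * frac))
    else:
        raise ValueError(f'unknown wd_incr_style {wd_incr_style}')
    return start_wd + coeff * (end_wd - start_wd)

# Example: ramp from 0.01 to 0.1 over 1000 steps.
print(weight_decay_at(0, 1000, 0.01, 0.1, 'linear'))     # 0.01
print(weight_decay_at(500, 1000, 0.01, 0.1, 'linear'))   # ~0.055
print(weight_decay_at(1000, 1000, 0.01, 0.1, 'cosine'))  # ~0.1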