Commit 97913a1f authored by Jared Casper

Merge branch 'staging_optimization' into 'staging'

Staging optimization level 1

See merge request ADLR/megatron-lm!54
parents be5300b7 a4489ffb
@@ -22,6 +22,8 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from .loss_scaler import DynamicLossScaler, LossScaler
 from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
 from megatron.module import MegatronModule
@@ -320,10 +322,13 @@ class FP16_Optimizer(object):
     def _downscale_master(self):
         if self.loss_scale != 1.0:
             for group in self.optimizer.param_groups:
-                for param in group['params']:
-                    if param.grad is not None:
-                        param.grad.data.mul_(1. / self.loss_scale)
+                grads = [p.grad for p in group['params'] if p.grad is not None]
+                _overflow_buf = torch.cuda.IntTensor([0])
+                multi_tensor_applier(amp_C.multi_tensor_scale,
+                                     _overflow_buf,
+                                     [grads, grads],
+                                     1./self.loss_scale)
     def clip_master_grads(self, max_norm, norm_type=2):
         """
         Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.
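For context, a minimal standalone sketch (not part of this merge request) of the fused pattern the new `_downscale_master` uses: one `amp_C.multi_tensor_scale` launch rescales a whole list of gradients in place, instead of one `mul_` per parameter. It assumes apex is installed with the `amp_C` extension and a CUDA device is available; the tensors below are stand-ins.

```python
# Sketch only: fused in-place gradient downscaling with apex's multi_tensor_applier.
# Assumes apex (built with amp_C) and a CUDA device; `grads` is a stand-in list.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

loss_scale = 1024.0
grads = [torch.randn(1 << 16, device='cuda') for _ in range(8)]

# The kernel writes a non-zero value here if it encounters inf/nan.
_overflow_buf = torch.cuda.IntTensor([0])

# [grads, grads] means source list and destination list are the same tensors,
# so every gradient is scaled by 1/loss_scale in place in a single launch.
multi_tensor_applier(amp_C.multi_tensor_scale,
                     _overflow_buf,
                     [grads, grads],
                     1.0 / loss_scale)
```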
@@ -18,6 +18,9 @@ import torch.nn as nn
 from torch.autograd import Variable
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
 from megatron import mpu
@@ -166,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
-            if model.grad is not None:
-                if master.grad is None:
-                    master.grad = Variable(master.data.new(*master.data.size()))
-                master.grad.data.copy_(model.grad.data)
-            else:
-                master.grad = None
+        model_grads = [p.grad for p in model_params if p.grad is not None]
+        master_grads = [p.grad for p in master_params if p.grad is not None]
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [model_grads, master_grads],
+                             1.0)
 def master_params_to_model_params(model_params, master_params, flat_master=False):
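As a side note, the same `multi_tensor_scale` kernel with a scale factor of 1.0 doubles as a fused copy/cast from fp16 model gradients into fp32 master gradients. A small sketch under the same apex/CUDA assumptions; the parameter lists below are illustrative stand-ins, not the repo's own.

```python
# Sketch only: copying fp16 model grads into pre-allocated fp32 master grads
# with one fused kernel (scale factor 1.0 == plain copy with dtype conversion).
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

model_params = [torch.randn(256, device='cuda', dtype=torch.float16, requires_grad=True)
                for _ in range(4)]
master_params = [p.detach().clone().float().requires_grad_() for p in model_params]

for p, m in zip(model_params, master_params):
    p.grad = torch.randn_like(p)   # pretend backward() filled these fp16 grads
    m.grad = torch.empty_like(m)   # fp32 destination buffers

model_grads = [p.grad for p in model_params if p.grad is not None]
master_grads = [m.grad for m in master_params if m.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])

multi_tensor_applier(amp_C.multi_tensor_scale,
                     _overflow_buf,
                     [model_grads, master_grads],
                     1.0)
```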
@@ -14,6 +14,10 @@
 # limitations under the License.
 import torch
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
 from megatron import mpu
 # item() is a recent addition, so this helps with backward compatibility.
@@ -57,7 +61,12 @@ class LossScaler:
         return self.cur_scale
     def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in
     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
@@ -180,7 +189,12 @@ class DynamicLossScaler:
         return self.cur_scale
     def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in
     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
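The `(module, grad_in, grad_out)` signature matches PyTorch's module backward-hook convention, so `scale_gradient` is presumably meant to be wired in as such a hook; the change scales the incoming gradients in place with one fused kernel instead of building a new tuple. A rough sketch of the hook-based usage with the plain PyTorch API (the module and wiring here are illustrative, not the repo's):

```python
# Sketch only: gradient scaling through a module backward hook, mirroring the
# old tuple-building behaviour of scale_gradient. The fused variant in this
# diff would instead rescale grad_in in place via amp_C.multi_tensor_scale.
import torch

loss_scale = 128.0

def scale_gradient(module, grad_in, grad_out):
    # Return a new tuple of scaled gradients (None entries pass through).
    return tuple(None if g is None else loss_scale * g for g in grad_in)

layer = torch.nn.Linear(16, 16)
layer.register_full_backward_hook(scale_gradient)

x = torch.randn(4, 16, requires_grad=True)
layer(x).sum().backward()
print(x.grad.abs().mean())  # input gradients arrive pre-multiplied by loss_scale
```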
@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal
 def gpt2_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores = torch.mul(attention_scores, ltor_mask) - \
-                       10000.0 * (1.0 - ltor_mask)
+    attention_scores.masked_fill_(ltor_mask, -10000.0)
     return attention_scores
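The multiply-and-subtract formulation is replaced by a single in-place `masked_fill_`; this pairs with the mask becoming a boolean tensor (`attention_mask < 0.5`, see the `get_ltor_masks_and_position_ids` hunk below) where `True` marks positions to block. A small self-contained check of the equivalence; the shapes are illustrative.

```python
# Sketch only: old multiplicative masking vs. new in-place masked_fill_.
import torch

scores = torch.randn(1, 4, 4)
allowed = torch.tril(torch.ones(1, 4, 4))   # old float mask: 1 = keep, 0 = block

# Old: keep allowed scores, drive blocked ones to -10000 arithmetically.
old = torch.mul(scores, allowed) - 10000.0 * (1.0 - allowed)

# New: boolean mask that is True where attention must be blocked,
# filled in place -- no temporaries for (1 - mask) or the product.
blocked = allowed < 0.5
new = scores.clone().masked_fill_(blocked, -10000.0)

print(torch.allclose(old, new))  # True
```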
@@ -42,8 +42,7 @@ def get_batch(context_tokens):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)
     return tokens, attention_mask, position_ids
@@ -227,7 +227,7 @@ def backward_step(optimizer, model, loss):
     timers = get_timers()
     # Backward pass.
-    optimizer.zero_grad()
+    optimizer.zero_grad(set_grads_to_None=True)
     if args.fp16:
         optimizer.backward(loss, update_master_grads=False)
     else:
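Setting gradients to None instead of zeroing them skips a memset over every gradient buffer and lets the first accumulation of the next step write the values directly. The `set_grads_to_None` keyword appears to come from the FP16_Optimizer wrapper changed above; stock `torch.optim` spells the equivalent flag `set_to_none`, as in this small sketch:

```python
# Sketch only: dropping gradients rather than zeroing them, with plain torch.optim.
import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(2, 8)).sum().backward()
opt.step()

# No memset over the .grad buffers; the next backward() recreates them.
opt.zero_grad(set_to_none=True)
print(all(p.grad is None for p in model.parameters()))  # True
```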
@@ -119,8 +119,7 @@ def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,
                                     reset_attention_mask,
-                                    eod_mask_loss,
-                                    fp16):
+                                    eod_mask_loss):
     """Build masks and position id for left to right model."""
     # Extract batch size and sequence length.
@@ -170,8 +169,7 @@ def get_ltor_masks_and_position_ids(data,
                 position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                 prev_index = i + 1
-    # Convert
-    if fp16:
-        attention_mask = attention_mask.half()
+    # Convert attention mask to binary:
+    attention_mask = (attention_mask < 0.5)
     return attention_mask, loss_mask, position_ids
@@ -65,8 +65,7 @@ def get_batch(data_iterator):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)
     return tokens, labels, loss_mask, attention_mask, position_ids
@@ -71,8 +71,7 @@ def process_batch(batch):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)
     return tokens, labels, attention_mask, position_ids, loss_mask