Commit 99410264 authored by Mohammad

added multi-tensor-apply to fp16

parent 2467ae15
@@ -22,6 +22,8 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from .loss_scaler import DynamicLossScaler, LossScaler
 from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
 
 from megatron.module import MegatronModule
@@ -320,10 +322,13 @@ class FP16_Optimizer(object):
     def _downscale_master(self):
         if self.loss_scale != 1.0:
             for group in self.optimizer.param_groups:
-                for param in group['params']:
-                    if param.grad is not None:
-                        param.grad.data.mul_(1. / self.loss_scale)
+                grads = [p.grad for p in group['params'] if p.grad is not None]
+                _overflow_buf = torch.cuda.IntTensor([0])
+                multi_tensor_applier(amp_C.multi_tensor_scale,
+                                     _overflow_buf,
+                                     [grads, grads],
+                                     1./self.loss_scale)
 
     def clip_master_grads(self, max_norm, norm_type=2):
         """
         Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.
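Note on the hunk above: multi_tensor_applier fuses the per-parameter mul_ loop in _downscale_master into a single kernel launch over the whole gradient list, with the int buffer acting as a found-inf/NaN flag. A minimal sketch of the pattern, assuming a CUDA device and an apex build that ships the amp_C extension; the scale value, tensor shapes, and the final check are illustrative only and not part of this commit.

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

loss_scale = 1024.0
grads = [torch.randn(4, 4, device='cuda'),
         torch.randn(8, device='cuda')]
reference = [g * (1. / loss_scale) for g in grads]   # per-tensor loop version

# _overflow_buf is set non-zero if a scaled tensor contains inf/NaN.
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
                     _overflow_buf,
                     [grads, grads],        # source and destination are the same list: in-place scaling
                     1. / loss_scale)

for got, want in zip(grads, reference):
    assert torch.allclose(got, want)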
@@ -18,6 +18,9 @@ import torch.nn as nn
 from torch.autograd import Variable
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
 
 from megatron import mpu
@@ -169,6 +172,13 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
                 master.grad.data.copy_(model.grad.data)
             else:
                 master.grad = None
+    model_grads = [p.grad for p in model_params if p.grad is not None]
+    master_grads = [p.grad for p in master_params if p.grad is not None]
+    _overflow_buf = torch.cuda.IntTensor([0])
+    multi_tensor_applier(amp_C.multi_tensor_scale,
+                         _overflow_buf,
+                         [model_grads, master_grads],
+                         1.0)
 
 def master_params_to_model_params(model_params, master_params, flat_master=False):
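Note on the hunk above: with a scale of 1.0, amp_C.multi_tensor_scale acts as a fused copy, moving each fp16 model gradient into its fp32 master gradient in one launch instead of a Python-level loop of copy_ calls. A small sketch of that pattern, assuming CUDA and an apex build with amp_C; the parameter sizes and setup below are made up for illustration and are not Megatron's.

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

model_params = [torch.nn.Parameter(torch.randn(16, device='cuda', dtype=torch.half))
                for _ in range(3)]
master_params = [p.detach().clone().float().requires_grad_() for p in model_params]
for p, m in zip(model_params, master_params):
    p.grad = torch.randn_like(p)      # pretend backward has already run
    m.grad = torch.empty_like(m)      # fp32 buffer to be filled

model_grads = [p.grad for p in model_params if p.grad is not None]
master_grads = [m.grad for m in master_params if m.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
                     _overflow_buf,
                     [model_grads, master_grads],   # copy + upcast fp16 -> fp32 in one fused launch
                     1.0)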
@@ -57,7 +57,12 @@ class LossScaler:
         return self.cur_scale
 
     def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in
 
     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
@@ -180,7 +185,12 @@ class DynamicLossScaler:
         return self.cur_scale
 
     def scale_gradient(self, module, grad_in, grad_out):
-        return tuple(self.loss_scale * g for g in grad_in)
+        _overflow_buf = torch.cuda.IntTensor([0])
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             _overflow_buf,
+                             [grad_in, grad_in],
+                             self.loss_scale)
+        return grad_in
 
     def backward(self, loss, retain_graph=False):
         scaled_loss = loss * self.loss_scale
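Note on the two scale_gradient hunks above: the old body allocated a fresh tuple of scaled tensors, while the new body scales the incoming gradients in place with one fused launch and returns the same objects. A sketch of that behavior, assuming CUDA and apex with amp_C, and assuming every entry of grad_in is a CUDA tensor (None entries would need to be filtered first); the example gradients and scale are made up.

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

loss_scale = 8.0
grad_in = (torch.randn(10, device='cuda'), torch.randn(2, 3, device='cuda'))

_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
                     _overflow_buf,
                     [grad_in, grad_in],    # scale each gradient in place, mirroring the hunk above
                     loss_scale)
# grad_in now holds the scaled gradients; no new tensors were allocated.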