"docs/vscode:/vscode.git/clone" did not exist on "3f8146a7733171bb769e37ca453f2a7974973ef8"
Commit 8ed8eaac authored by Thor Johnsen

Use correct names for mt lamb cuda kernels

parent 45388d48
@@ -72,8 +72,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
e5m2_allgather=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
# FIXME: Import multi_tensor_lamb_* kernels instead
distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
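Both `fused_adam_cuda` and `distributed_lamb_cuda` are compiled apex extensions loaded by module name at construction time. Below is a minimal sketch of that loading pattern using only the standard library; the `_import_extension` helper is hypothetical and not part of this patch, and the module names are taken from the diff above.

```python
import importlib

def _import_extension(name):
    # Hypothetical helper: load a compiled apex CUDA extension by module name
    # and fail with a clearer message when apex was built without it.
    try:
        return importlib.import_module(name)
    except ImportError as err:
        raise RuntimeError(
            f"apex extension '{name}' is unavailable; "
            "rebuild apex with its CUDA extensions enabled") from err

# Module names as they appear in the diff above.
fused_adam_cuda = _import_extension("fused_adam_cuda")
distributed_lamb_cuda = _import_extension("distributed_lamb_cuda")
```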
@@ -90,13 +89,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._overflow_buf = torch.cuda.IntTensor([0])
self._has_overflow = False
self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
import amp_C
self.multi_tensor_lamb_compute_update_term = amp_C.multi_tensor_distopt_lamb_
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_lamb = amp_C.multi_tensor_lamb
self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
self._last_step = False
self._overlap_reductions = overlap_reductions
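This hunk binds the fused multi-tensor kernels (`multi_tensor_lamb_compute_update_term`, `multi_tensor_lamb_update_weights`, `multi_tensor_l2norm`, `multi_tensor_lamb`) and keeps a dummy overflow buffer that the kernels consult to skip work after an inf/nan is detected. The sketch below shows how `amp_C.multi_tensor_l2norm` is typically driven through `multi_tensor_applier`, mirroring the call in `__compute_contrib_param_norm` further down; it assumes a CUDA device and an apex build with the C++/CUDA extensions.

```python
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# Noop/overflow flag the fused kernels read; a nonzero value tells them to skip work.
dummy_overflow_buf = torch.cuda.IntTensor([0])

# Process a whole list of tensors with a few fused launches instead of one per tensor.
params = [torch.randn(1024, device="cuda") for _ in range(4)]
total_norm, per_tensor_norms = multi_tensor_applier(
    amp_C.multi_tensor_l2norm,   # kernel bound in the hunk above
    dummy_overflow_buf,
    [params],                    # list of tensor lists
    True)                        # also return per-tensor norms
```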
@@ -423,25 +419,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.reversible_adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def __compute_contrib_param_norm(self):
if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp16], True)
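The deleted `__launch_step_kernel` folded the dynamic loss scale and `max_grad_norm` clipping into a single `combined_scale` divisor before handing the tensors to `fused_adam_cuda.reversible_adam`. As a reference, here is a standalone restatement of just that scale arithmetic; the helper name and signature are mine, not apex's.

```python
import math

def combined_unscale(global_scale, l2_grad_norm, max_grad_norm, eps=1e-6):
    # Reproduces the scale arithmetic from the removed __launch_step_kernel:
    # gradients are stored scaled by global_scale, so the true norm is
    # l2_grad_norm / global_scale; clipping is folded into the divisor the
    # fused kernel applies when it unscales the gradients.
    combined_scale = global_scale
    if max_grad_norm > 0 and math.isfinite(l2_grad_norm):
        clip_coeff = max_grad_norm / (l2_grad_norm / global_scale + eps)
        # clip_coeff >= 1 means the unscaled norm is already within max_grad_norm.
        combined_scale = global_scale / min(1.0, clip_coeff)
    return combined_scale
```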