Commit e6925e6c authored by Thor Johnsen

Bug fix: declare distributed_lamb_cuda (not fused_adam_cuda) in the global statement, and correct the syntax of the empty-list guards for the fp16/fp32 norm parameter lists.

parent 8ed8eaac
@@ -71,7 +71,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
e5m2_allgather=False):
- global fused_adam_cuda
+ global distributed_lamb_cuda
distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
@@ -286,8 +286,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i
self._contrib_max_param_i = param_i
self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)
- if len(self._contrib_model_param_for_norm_fp16) == 0: len(self._contrib_model_param_for_norm_fp16 = None
- if len(self._contrib_model_param_for_norm_fp32) == 0: len(self._contrib_model_param_for_norm_fp32 = None
+ if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None
+ if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
self._contrib_model_param_for_norm_is_fp16 = None
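For reference, a minimal standalone sketch of the guard corrected in the second hunk (the names and function below are illustrative, not the optimizer's real attributes): an empty per-dtype parameter list is replaced with None before the boolean mask is built, instead of the malformed "len(... = None" form on the removed lines.

import torch

def build_norm_masks(fp16_params, fp32_params):
    # Replace empty lists with None so later steps can skip that dtype,
    # mirroring the fixed guards in the diff above.
    if len(fp16_params) == 0: fp16_params = None
    if len(fp32_params) == 0: fp32_params = None
    # Build a bool mask marking which collected params are fp16.
    is_fp16 = [True] * (len(fp16_params) if fp16_params else 0) + \
              [False] * (len(fp32_params) if fp32_params else 0)
    is_fp16 = torch.tensor(is_fp16, dtype=torch.bool)
    return fp16_params, fp32_params, is_fp16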