Commit 96757752 authored by Thor Johnsen

Bug fix

parent 1210d8fe
@@ -48,10 +48,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
                  dwu_num_ag_pg=0, dwu_num_blk_st=1):
-        global fused_adam_cuda, radix_decomp_cuda
+        global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
-        radix_decomp_cuda = importlib.import_module("radix_decomp_cuda")
-        # To-Do: Add radix decomp args to fairseq adam optimizer
         self._amp_scale_adjustment = amp_scale_adjustment
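For context, a minimal sketch of the lazy-import pattern that remains in the constructor after this change: the compiled fused_adam_cuda extension is resolved with importlib when the optimizer is constructed, rather than when the Python module is imported. The class name and the ImportError guard below are illustrative assumptions for the sketch, not part of the apex source.

import importlib

class _LazyExtensionSketch:
    # Illustrative only -- not the apex DistributedFusedAdam implementation.
    def __init__(self):
        # Resolve the compiled extension at construction time, so importing
        # the Python package does not itself require the extension to exist.
        global fused_adam_cuda
        try:
            fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        except ImportError as err:
            # Hypothetical guard; the real constructor assumes the extension
            # was built and installed alongside apex.
            raise RuntimeError("fused_adam_cuda extension is not available") from err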