megatron_v4.patch

diff --git a/.gitignore b/.gitignore
index 5955b349..ade0cd51 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,5 @@ build
 slurm*
 logs
 .vscode
+tests/*
+examples/*
diff --git a/build.sh b/build.sh
new file mode 100644
index 00000000..49d5361f
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+export PYTHONPATH=$PYTHONPATH:$(pwd)
+pip3 install regex ninja
diff --git a/megatron/__init__.py b/megatron/__init__.py
index c35de282..60896b47 100644
--- a/megatron/__init__.py
+++ b/megatron/__init__.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from .global_vars import get_args, get_retro_args
+from .global_vars import get_args, update_args, fork_args_namespace, get_retro_args
 from .global_vars import get_current_global_batch_size
 from .global_vars import get_num_microbatches
 from .global_vars import get_signal_handler
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 0ca8776e..9ef67624 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -59,6 +59,16 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     return args
 
 def validate_args(args, defaults={}):
+    # Set input defaults.
+    for key in defaults:
+        if getattr(args, key, None) is not None:
+            if args.rank == 0 and defaults[key] != getattr(args, key):
+                print('WARNING: overriding argument {key}:{v2} '
+                      'with default {key}:{v}'.format(key=key, v=defaults[key],
+                                                      v2=getattr(args, key)),
+                      flush=True)
+
+        setattr(args, key, defaults[key])
     # Tensor model parallel size.
     args.tensor_model_parallel_size = min(
         args.tensor_model_parallel_size, args.world_size)
@@ -125,19 +135,19 @@ def validate_args(args, defaults={}):
         args.recompute_granularity = 'selective'
     del args.recompute_activations
 
-    # Set input defaults.
-    for key in defaults:
-        # For default to be valid, it should not be provided in the
-        # arguments that are passed to the program. We check this by
-        # ensuring the arg is set to None.
-        if getattr(args, key, None) is not None:
-            if args.rank == 0:
-                print('WARNING: overriding default arguments for {key}:{v} \
-                       with {key}:{v2}'.format(key=key, v=defaults[key],
-                                               v2=getattr(args, key)),
-                                               flush=True)
-        else:
-            setattr(args, key, defaults[key])
+    # # Set input defaults.
+    # for key in defaults:
+    #     # For default to be valid, it should not be provided in the
+    #     # arguments that are passed to the program. We check this by
+    #     # ensuring the arg is set to None.
+    #     if getattr(args, key, None) is not None:
+    #         if args.rank == 0:
+    #             print('WARNING: overriding default arguments for {key}:{v} \
+    #                    with {key}:{v2}'.format(key=key, v=defaults[key],
+    #                                            v2=getattr(args, key)),
+    #                                            flush=True)
+    #     else:
+    #         setattr(args, key, defaults[key])
 
     # Batch size.
     assert args.micro_batch_size is not None
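
The hunk above inverts the upstream precedence for caller-supplied defaults: they are applied before any validation and now always override values already present on args (upstream only filled arguments still set to None), warning whenever a differing value is replaced. A standalone sketch of the new semantics; apply_defaults is an illustrative helper, not part of the patch:

    from argparse import Namespace

    def apply_defaults(args, defaults):
        # Patched semantics: the caller-supplied default always wins;
        # warn when it replaces a different, already-set value.
        for key in defaults:
            if getattr(args, key, None) is not None:
                if args.rank == 0 and defaults[key] != getattr(args, key):
                    print('WARNING: overriding argument {}:{} with default {}:{}'
                          .format(key, getattr(args, key), key, defaults[key]))
            setattr(args, key, defaults[key])

    args = Namespace(rank=0, seq_length=2048)    # parsed from the command line
    apply_defaults(args, {'seq_length': 4096})   # caller-supplied default
    assert args.seq_length == 4096               # upstream would have kept 2048
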
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
index 29ee34df..fa590b16 100644
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ b/megatron/core/pipeline_parallel/p2p_communication.py
@@ -130,32 +130,28 @@ def _batched_p2p_ops(
         send_prev_op = torch.distributed.P2POp(
             torch.distributed.isend,
             tensor_send_prev,
-            get_pipeline_model_parallel_prev_rank(),
-            group,
+            get_pipeline_model_parallel_prev_rank()
         )
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
         recv_prev_op = torch.distributed.P2POp(
             torch.distributed.irecv,
             tensor_recv_prev,
-            get_pipeline_model_parallel_prev_rank(),
-            group,
+            get_pipeline_model_parallel_prev_rank()
         )
         ops.append(recv_prev_op)
     if tensor_send_next is not None:
         send_next_op = torch.distributed.P2POp(
             torch.distributed.isend,
             tensor_send_next,
-            get_pipeline_model_parallel_next_rank(),
-            group,
+            get_pipeline_model_parallel_next_rank()
         )
         ops.append(send_next_op)
     if tensor_recv_next is not None:
         recv_next_op = torch.distributed.P2POp(
             torch.distributed.irecv,
             tensor_recv_next,
-            get_pipeline_model_parallel_next_rank(),
-            group,
+            get_pipeline_model_parallel_next_rank()
         )
         ops.append(recv_next_op)
     if len(ops) > 0:
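
Dropping the explicit group from each P2POp leaves the ops on the default (world) process group, where the peer argument is interpreted as a global rank; Megatron's get_pipeline_model_parallel_prev_rank()/next_rank() helpers already return global ranks, so the two conventions agree. This is most likely a compatibility fix for torch builds where batched P2P ops resolve the peer against the supplied group instead. A minimal sketch of the convention (process-group init and real tensors elided):

    import torch
    import torch.distributed as dist

    def make_send_op(tensor: torch.Tensor, peer_global_rank: int) -> dist.P2POp:
        # No group argument: the default group is used and the peer
        # is a global rank, matching Megatron's pipeline rank helpers.
        return dist.P2POp(dist.isend, tensor, peer_global_rank)
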
diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py
index 992da781..2eb78d52 100644
--- a/megatron/core/pipeline_parallel/schedules.py
+++ b/megatron/core/pipeline_parallel/schedules.py
@@ -78,6 +78,8 @@ def get_forward_backward_func():
         transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
         in the config is True. Otherwise, each microbatch in the current global batch size must use
         this sequence length.
+
+    hidden_size (int, required): Hidden size of the model.
 
     micro_batch_size (int, required): The number of sequences in a microbatch.
 
@@ -287,6 +289,7 @@ def forward_backward_no_pipelining(
     model: Union[torch.nn.Module, List[torch.nn.Module]],
     num_microbatches: int,
     seq_length: int,  # unused
+    hidden_size: int, # unused
     micro_batch_size: int,  # unused
     decoder_seq_length: int = None,  # unused
     forward_only: bool = False,
@@ -370,8 +373,10 @@ def forward_backward_pipelining_with_interleaving(
     data_iterator: Union[Iterator, List[Iterator]],
     model: Union[torch.nn.Module, List[torch.nn.Module]],
     num_microbatches: int,
-    seq_length: int,
-    micro_batch_size: int,
+    seq_length: int = None,
+    hidden_size: int = None,
+    micro_batch_size: int = None,
+    input_shapes: list = None,
     decoder_seq_length: int = None,
     forward_only: bool = False,
     collect_non_loss_data: bool = False,
@@ -457,7 +462,7 @@ def forward_backward_pipelining_with_interleaving(
             "Interleaving is not supported with a different decoder sequence length."
         )
 
-    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
+    tensor_shape = [seq_length, micro_batch_size, hidden_size]
     if config.sequence_parallel:
         tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
 
@@ -944,6 +949,7 @@ def get_tensor_shapes(
     rank: int,
     model_type: ModelType,
     seq_length: int,
+    hidden_size: int,
     micro_batch_size: int,
     decoder_seq_length: int,
     config,
@@ -967,12 +973,12 @@ def get_tensor_shapes(
 
     if model_type == ModelType.encoder_and_decoder:
         if parallel_state.is_pipeline_stage_before_split(rank):
-            tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+            tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
         else:
-            tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
-            tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+            tensor_shapes.append((decoder_seq_length, micro_batch_size, hidden_size))
+            tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
     else:
-        tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+        tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
     return tensor_shapes
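
Threading hidden_size through get_tensor_shapes (instead of reading config.hidden_size) lets callers whose config object does not carry a hidden size, such as externally defined model configs, drive the pipeline schedules directly. The shape convention itself is unchanged; a worked example, assuming tensor-parallel size 8 with sequence parallelism enabled:

    seq_length, micro_batch_size, hidden_size = 4096, 2, 8192
    tp_world_size = 8

    # [sequence, microbatch, hidden]; sequence parallelism shards dim 0.
    shape = (seq_length // tp_world_size, micro_batch_size, hidden_size)
    assert shape == (512, 2, 8192)
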
 
 
@@ -1050,8 +1056,10 @@ def forward_backward_pipelining_without_interleaving(
     data_iterator: Union[Iterator, List[Iterator]],
     model: Union[torch.nn.Module, List[torch.nn.Module]],
     num_microbatches: int,
-    seq_length: int,
-    micro_batch_size: int,
+    seq_length: int = None,
+    hidden_size: int = None,
+    micro_batch_size: int = None,
+    input_shapes: list = None,
     decoder_seq_length: int = None,
     forward_only: bool = False,
     collect_non_loss_data: bool = False,
@@ -1127,22 +1135,34 @@ def forward_backward_pipelining_without_interleaving(
     model_type = get_model_type(model)
 
     rank = parallel_state.get_pipeline_model_parallel_rank()
-    recv_tensor_shapes = get_tensor_shapes(
-        rank=rank - 1,
-        model_type=model_type,
-        seq_length=seq_length,
-        micro_batch_size=micro_batch_size,
-        decoder_seq_length=decoder_seq_length,
-        config=config,
-    )
-    send_tensor_shapes = get_tensor_shapes(
-        rank=rank,
-        model_type=model_type,
-        seq_length=seq_length,
-        micro_batch_size=micro_batch_size,
-        decoder_seq_length=decoder_seq_length,
-        config=config,
-    )
+
+    def get_recv_tensor_shapes(microbatch_id):
+        if input_shapes:
+            return [input_shapes[microbatch_id]]
+        recv_tensor_shapes = get_tensor_shapes(
+            rank=rank - 1,
+            model_type=model_type,
+            seq_length=seq_length,
+            hidden_size=hidden_size,
+            micro_batch_size=micro_batch_size,
+            decoder_seq_length=decoder_seq_length,
+            config=config,
+        )
+        return recv_tensor_shapes
+
+    def get_send_tensor_shapes(microbatch_id):
+        if input_shapes:
+            return [input_shapes[microbatch_id]]
+        send_tensor_shapes = get_tensor_shapes(
+            rank=rank,
+            model_type=model_type,
+            seq_length=seq_length,
+            hidden_size=hidden_size,
+            micro_batch_size=micro_batch_size,
+            decoder_seq_length=decoder_seq_length,
+            config=config,
+        )
+        return send_tensor_shapes
 
     # Input, output tensors only need to be saved when doing backward passes
     input_tensors = None
@@ -1163,7 +1183,12 @@ def forward_backward_pipelining_without_interleaving(
         else:
             checkpoint_activations_microbatch = None
 
+        # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+        #     print(f'rank {torch.cuda.current_device()}: micro batch {i}: warmup recv_forward begin...')
+        recv_tensor_shapes = get_recv_tensor_shapes(i)  # fwd recv shape
         input_tensor = recv_forward(recv_tensor_shapes, config)
+        # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+        #     print(f'rank {torch.cuda.current_device()}: micro batch {i}: warmup recv_forward end & forward begin...')
         output_tensor = forward_step(
             forward_step_func,
             data_iterator,
@@ -1175,7 +1200,13 @@ def forward_backward_pipelining_without_interleaving(
             collect_non_loss_data,
             checkpoint_activations_microbatch,
         )
+        # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+        #     print(f'rank {torch.cuda.current_device()}: output tensor shape = {output_tensor[0].shape}, send_tensor_shapes={send_tensor_shapes}')
+        #     print(f'rank {torch.cuda.current_device()}: micro batch {i}: warmup forward end & send_forward begin...')
+        send_tensor_shapes = get_send_tensor_shapes(i)  # fwd send shape
         send_forward(output_tensor, send_tensor_shapes, config)
+        # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+        #     print(f'rank {torch.cuda.current_device()}: micro batch {i}: warmup send_forward end...')        
 
         if not forward_only:
             input_tensors.append(input_tensor)
@@ -1186,11 +1217,16 @@ def forward_backward_pipelining_without_interleaving(
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
     if num_microbatches_remaining > 0:
-        input_tensor = recv_forward(recv_tensor_shapes, config)
+        # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+        #     print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches}: 1f1b recv_forward begin...')
+        recv_tensor_shapes = get_recv_tensor_shapes(num_warmup_microbatches)  # fwd recv shape
+        input_tensor = recv_forward(recv_tensor_shapes, config)
 
     # Run 1F1B in steady state.
     for i in range(num_microbatches_remaining):
         last_iteration = i == (num_microbatches_remaining - 1)
+        next_forward_k = num_warmup_microbatches + i + 1
+        backward_k = i
 
         # Decide to checkpoint all layers' activations of the current micro-batch
         if max_outstanding_backprops is not None:
@@ -1199,7 +1235,8 @@ def forward_backward_pipelining_without_interleaving(
             ) >= config.num_microbatches_with_partial_activation_checkpoints
         else:
             checkpoint_activations_microbatch = None
-
+        # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+        #     print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches + i}: 1f1b recv_forward end & forward begin...') 
         output_tensor = forward_step(
             forward_step_func,
             data_iterator,
@@ -1213,12 +1250,23 @@ def forward_backward_pipelining_without_interleaving(
         )
 
         if forward_only:
+            # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+            #     print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches + i}: 1f1b forward end & send forward begin...') 
+            send_tensor_shapes = get_send_tensor_shapes(next_forward_k - 1)  # fwd send shape
             send_forward(output_tensor, send_tensor_shapes, config)
 
             if not last_iteration:
+                # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+                #     print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches + i}: 1f1b send forward end & recv forward begin...')
+                recv_tensor_shapes = get_recv_tensor_shapes(next_forward_k)  # fwd recv shape
                 input_tensor = recv_forward(recv_tensor_shapes, config)
+            else:
+                pass
+                # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+                #     print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches + i}: 1f1b send forward end...')                
 
         else:
+            send_tensor_shapes = get_send_tensor_shapes(backward_k)  # bwd recv shape
             output_tensor_grad = send_forward_recv_backward(
                 output_tensor, send_tensor_shapes, config
             )
@@ -1245,8 +1293,10 @@ def forward_backward_pipelining_without_interleaving(
 
             if last_iteration:
                 input_tensor = None
+                recv_tensor_shapes = get_recv_tensor_shapes(backward_k)  # bwd send shape
                 send_backward(input_tensor_grad, recv_tensor_shapes, config)
             else:
+                recv_tensor_shapes = get_recv_tensor_shapes(next_forward_k)  # fwd recv shape
                 input_tensor = send_backward_recv_forward(
                     input_tensor_grad, recv_tensor_shapes, config
                 )
@@ -1254,7 +1304,7 @@ def forward_backward_pipelining_without_interleaving(
     # Run cooldown backward passes.
     if not forward_only:
         for i in range(num_warmup_microbatches):
-
+            backward_k = num_microbatches_remaining + i
             # Enable async grad reduction in the last backward pass
             # Note: If grad sync function is provided, only enable
             # async grad reduction in first pipeline stage. Other
@@ -1267,12 +1317,14 @@ def forward_backward_pipelining_without_interleaving(
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
 
+            send_tensor_shapes = get_send_tensor_shapes(backward_k)  # bwd recv shape
             output_tensor_grad = recv_backward(send_tensor_shapes, config)
 
             input_tensor_grad = backward_step(
                 input_tensor, output_tensor, output_tensor_grad, model_type, config
             )
 
+            recv_tensor_shapes = get_recv_tensor_shapes(backward_k)  # bwd send shape
             send_backward(input_tensor_grad, recv_tensor_shapes, config)
 
         # Launch any remaining grad reductions.
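
Taken together, these schedule changes make seq_length/hidden_size/micro_batch_size optional and add input_shapes, a per-microbatch list of (sequence, microbatch, hidden) tuples that, when given, overrides the uniform tensor shape for every send/recv in the non-interleaved 1F1B schedule. This is what allows microbatches of different sequence lengths (e.g. after padding removal) to flow through the pipeline. A hedged sketch of a call site, assuming pipeline parallelism is enabled (the no-pipelining path does not take input_shapes); forward_step, data_iterator, and model are user-supplied placeholders:

    from megatron.core.pipeline_parallel import get_forward_backward_func

    forward_backward_func = get_forward_backward_func()

    # One shape tuple per microbatch; shapes may differ across microbatches.
    input_shapes = [(512, 1, 8192), (768, 1, 8192), (640, 1, 8192)]

    losses_reduced = forward_backward_func(
        forward_step_func=forward_step,   # placeholder: user-defined step fn
        data_iterator=data_iterator,      # placeholder: user-defined iterator
        model=model,
        num_microbatches=len(input_shapes),
        input_shapes=input_shapes,        # added by this patch
        forward_only=False,
    )
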
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
index d4e042b2..c480d14e 100644
--- a/megatron/core/utils.py
+++ b/megatron/core/utils.py
@@ -55,8 +55,9 @@ def get_model_type(model):
     return get_attr_wrapped_model(model, 'model_type')
 
 
+# Workaround: get_model_config reads the Megatron config (ModelParallelConfig) from the 'megatron_config' attribute.
 def get_model_config(model):
-    return get_attr_wrapped_model(model, 'config', allow_none=False)
+    return get_attr_wrapped_model(model, 'megatron_config', allow_none=False)
 
 
 class GlobalMemoryBuffer:
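
With this workaround, any model run through the patched schedules must expose the Megatron ModelParallelConfig under a megatron_config attribute rather than config (which may now hold a framework-specific config instead). A minimal sketch of a conforming wrapper; HFModelWrapper is a hypothetical name:

    import torch

    class HFModelWrapper(torch.nn.Module):
        def __init__(self, module: torch.nn.Module, megatron_config):
            super().__init__()
            self.module = module
            # Found by get_attr_wrapped_model(model, 'megatron_config').
            self.megatron_config = megatron_config

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)
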
diff --git a/megatron/global_vars.py b/megatron/global_vars.py
index b1b4b043..9e23dea5 100644
--- a/megatron/global_vars.py
+++ b/megatron/global_vars.py
@@ -21,11 +21,48 @@ _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 _GLOBAL_SIGNAL_HANDLER = None
 
-def get_args():
+DEFAULT_NAMESPACE = 'default'
+import contextlib
+
+@contextlib.contextmanager
+def fork_args_namespace(namespace):
+    """
+    Usage example:
+        update_args('vit', vit_config)
+        with fork_args_namespace('vit'):
+            do vit stuff here
+    """
+    # Check if we have added the args namespace
+    if namespace not in _GLOBAL_ARGS:
+        raise Exception('args namespace {} has not been added'.format(namespace))
+    # Store current args namespace.
+    tmp = _GLOBAL_ARGS[DEFAULT_NAMESPACE]
+    # Set args namespace to the desired one
+    _GLOBAL_ARGS[DEFAULT_NAMESPACE] = _GLOBAL_ARGS[namespace]
+    # Do the stuff we wanted to do.
+    try:
+        yield
+    finally:
+        _GLOBAL_ARGS[DEFAULT_NAMESPACE] = tmp
+
+def get_args(namespace=DEFAULT_NAMESPACE):
     """Return arguments."""
     _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
-    return _GLOBAL_ARGS
+    return _GLOBAL_ARGS[namespace]
 
+def set_args(args):
+    global _GLOBAL_ARGS
+    if _GLOBAL_ARGS is None:
+        _GLOBAL_ARGS = {}
+    _GLOBAL_ARGS[DEFAULT_NAMESPACE] = args
+
+def update_args(namespace, args):
+    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
+    if namespace not in _GLOBAL_ARGS:
+        import copy
+        _GLOBAL_ARGS[namespace] = copy.deepcopy(_GLOBAL_ARGS[DEFAULT_NAMESPACE])
+    for k, v in args.items():
+        setattr(_GLOBAL_ARGS[namespace], k, v)
 
 def get_retro_args():
     """Return retro arguments."""
@@ -87,7 +124,7 @@ def _set_signal_handler():
 
 
 
-def set_global_variables(args, build_tokenizer=True):
+def set_global_variables(args):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
 
     assert args is not None
@@ -96,7 +133,7 @@ def set_global_variables(args, build_tokenizer=True):
     set_args(args)
 
     _build_num_microbatches_calculator(args)
-    if build_tokenizer:
+    if args.vocab_file:
         _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_wandb_writer(args)
@@ -107,11 +144,6 @@ def set_global_variables(args, build_tokenizer=True):
         _set_signal_handler()
 
 
-def set_args(args):
-    global _GLOBAL_ARGS
-    _GLOBAL_ARGS = args
-
-
 def set_retro_args(retro_args):
     global _GLOBAL_RETRO_ARGS
     _GLOBAL_RETRO_ARGS = retro_args
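
The namespace mechanism keeps several argument sets alive at once: update_args deep-copies the default namespace on first use and applies the given overrides, and fork_args_namespace temporarily makes that copy the one get_args() returns. A usage sketch mirroring the docstring above, assuming set_global_variables has already installed the default args:

    from megatron import get_args, update_args, fork_args_namespace

    # Clone the default args under 'vit', overriding a few fields.
    update_args('vit', {'hidden_size': 1024, 'num_layers': 24})

    with fork_args_namespace('vit'):
        assert get_args().hidden_size == 1024   # 'vit' copy inside the block
    args = get_args()                           # default namespace again
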
diff --git a/megatron/initialize.py b/megatron/initialize.py
index fb7866ab..01999622 100644
--- a/megatron/initialize.py
+++ b/megatron/initialize.py
@@ -39,7 +39,7 @@ def initialize_megatron(
     if not allow_no_cuda:
         # Make sure cuda is available.
         assert torch.cuda.is_available(), "Megatron requires CUDA."
-
+    print('using open-source Megatron initialization...')
     # Parse arguments
     args = parse_args(extra_args_provider, ignore_unknown_args)
 
diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index c91a674e..bcb7bd7e 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -81,7 +81,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
     if self.no_persist_layer_norm:
         assert FusedLayerNormAffineFunction is not None, \
             "FusedLayerNormAffineFunction is not available, please install apex from https://github.com/NVIDIA/apex"
-        return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
+        return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps, False)
     else:
         output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
 
diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py
index a04ae478..b64d22a5 100644
--- a/megatron/optimizer/distrib_optimizer.py
+++ b/megatron/optimizer/distrib_optimizer.py
@@ -366,7 +366,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
 
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  check_for_nan_in_grad, params_have_main_grad, fp16,
-                 bf16, params_dtype, grad_scaler, models):
+                 bf16, params_dtype, grad_scaler, models, overlap_param_gather=False):
         """
         See top of class definition for argument descriptions.
 
@@ -382,8 +382,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             check_for_nan_in_grad, params_have_main_grad,
             fp16, bf16, params_dtype, grad_scaler, models)
 
-        assert isinstance(optimizer, Adam), \
-            "Only Adam currently supported, due to checkpointing requirements."
+        # assert isinstance(optimizer, Adam), \
+        #     "Only Adam currently supported, due to checkpointing requirements."
+
+        if not isinstance(optimizer, Adam):
+            print("WARNING: the optimizer type is not Adam, and now Only Adam currently support checkpointing requirements!")
 
         # Model grad buffer ranges.
         self.model_gbuf_ranges = []
@@ -476,7 +479,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             self.param_buffer_copied.append(False)
         self.num_all_gather_handles = len(self.all_gather_handle_index_to_bucket_index_map)
 
-        self.overlap_param_gather = get_args().overlap_param_gather
+        self.overlap_param_gather = overlap_param_gather
         if self.overlap_param_gather:
             self.remove_pre_hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(
                 self._make_forward_pre_hook())
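
Passing overlap_param_gather through the constructor removes the optimizer's read of the global get_args(), so the distributed optimizer can be built by code that never initializes Megatron's globals. A sketch of the call site; every argument mirrors the signature above and is a placeholder:

    optimizer = DistributedOptimizer(
        base_optimizer, clip_grad, log_num_zeros_in_grad,
        check_for_nan_in_grad, params_have_main_grad, fp16,
        bf16, params_dtype, grad_scaler, models,
        overlap_param_gather=args.overlap_param_gather,  # was read via get_args()
    )
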
diff --git a/megatron/training.py b/megatron/training.py
index 36f6c52e..73664509 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -430,6 +430,7 @@ def train_step(forward_step_func, data_iterator,
         model=model,
         num_microbatches=get_num_microbatches(),
         seq_length=args.seq_length,
+        hidden_size=args.hidden_size,
         micro_batch_size=args.micro_batch_size,
         decoder_seq_length=args.decoder_seq_length,
         forward_only=False)
diff --git a/tools/prebuild_kernels.py b/tools/prebuild_kernels.py
new file mode 100644
index 00000000..6f891b9e
--- /dev/null
+++ b/tools/prebuild_kernels.py
@@ -0,0 +1,13 @@
+import os
+from megatron.fused_kernels import load
+
+
+class FakeArgs:
+    rank = 0
+
+
+# 7.0 for V100
+# 8.0 for A100/A800
+os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX;8.0+PTX"
+
+load(FakeArgs)
\ No newline at end of file
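
The new tools/prebuild_kernels.py compiles Megatron's fused CUDA kernels ahead of time for the listed architectures (7.0 for V100, 8.0 for A100/A800), so the first training run does not pay the JIT build cost. Run it once from the repository root, e.g. python3 tools/prebuild_kernels.py, after setting PYTHONPATH as in build.sh.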