"csrc/vscode:/vscode.git/clone" did not exist on "06d9038fa9466ede6ce6a6706beaa90d389e788a"
initialize.py 18.2 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron initialization."""
import logging
import os
import random
import time
import warnings
from datetime import timedelta

import numpy as np
import torch

from megatron.core import mpu, tensor_parallel
from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train
from megatron.core.fusions.fused_bias_gelu import bias_gelu
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu
from megatron.core.parallel_state import create_group
from megatron.core.rerun_state_machine import (
    RerunDiagnostic,
    RerunErrorInjector,
    RerunMode,
    initialize_rerun_state_machine,
)
from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version
from megatron.legacy import fused_kernels
from megatron.training import get_adlr_autoresume, get_args, get_tensorboard_writer
from megatron.training.arguments import parse_args, validate_args
from megatron.training.async_utils import init_persistent_async_worker
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_global_variables
from megatron.training.yaml_arguments import validate_yaml

logger = logging.getLogger(__name__)


def initialize_megatron(
    extra_args_provider=None,
    args_defaults={},
    ignore_unknown_args=False,
    allow_no_cuda=False,
    skip_mpu_initialization=False,
    get_embedding_ranks=None,
    get_position_embedding_ranks=None,
):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using Megatron for CPU-only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), "Megatron requires CUDA."

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args)

    # Prep for checkpoint conversion.
    if args.ckpt_convert_format is not None:
        assert args.ckpt_convert_save is not None
        assert args.load is not None
        args.exit_on_missing_checkpoint = True

    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
        assert args.load is not None, "--use-checkpoint-args requires --load argument"
        assert args.non_persistent_ckpt_type != "local", (
            "--use-checkpoint-args is not supported with --non_persistent_ckpt_type=local. "
            "Two-stage checkpoint loading is not implemented, and all arguments must be defined "
            "before initializing LocalCheckpointManager."
        )
        load_args_from_checkpoint(args)

    if args.async_save and args.use_persistent_ckpt_worker:
        init_persistent_async_worker()

    if args.yaml_cfg is not None:
        args = validate_yaml(args, args_defaults)
    else:
        validate_args(args, args_defaults)

    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # set logging level
    setup_logging()

    # init rerun state
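    # The rerun state machine snapshots the CUDA RNG tracker state so that an
    # iteration can be restored and replayed deterministically.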
    def state_save_func():
        return {'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}

    def state_restore_func(state_dict):
        if state_dict['rng_tracker_states']:
            tensor_parallel.get_cuda_rng_tracker().set_states(state_dict['rng_tracker_states'])

    args = get_args()
    initialize_rerun_state_machine(
        state_save_func=state_save_func,
        state_restore_func=state_restore_func,
        mode=RerunMode(args.rerun_mode),
        error_injector=RerunErrorInjector(
            error_injection_rate=args.error_injection_rate,
            error_injection_type=RerunDiagnostic(args.error_injection_type),
        ),
        result_rejected_tracker_filename=args.result_rejected_tracker_filename,
    )

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks)

        # Random seeds for reproducibility.
        if args.rank == 0:
            print("> setting random seeds to {} ...".format(args.seed))
        _set_random_seed(
            args.seed,
            args.data_parallel_random_init,
            args.te_rng_tracker,
            args.inference_rng_tracker,
        )

    if skip_mpu_initialization:
        return None

    args = get_args()
    if args.lazy_mpu_init:
        # TODO is this still a necessary option?
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        mpu.set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        if args.tp_comm_overlap:
            # TODO: Should this be activated with just decoder-tp-comm-overlap too?
            _initialize_tp_communicators()

        # No continuation function
        return None
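

# Illustrative usage sketch (an assumption for documentation, not code exercised by
# this module; the import path assumes this file is megatron/training/initialize.py):
# a training entry point typically calls initialize_megatron() once, before any model
# construction, and then reads the parsed arguments back via get_args(), e.g.
#
#     from megatron.training import get_args
#     from megatron.training.initialize import initialize_megatron
#
#     initialize_megatron(args_defaults={"tokenizer_type": "GPT2BPETokenizer"})
#     args = get_args()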


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling dataset index builder ...")
        from megatron.core.datasets.utils import compile_helpers

        compile_helpers()
        print(
            ">>> done with dataset index builder. Compilation time: {:.3f} "
            "seconds".format(time.time() - start_time),
            flush=True,
        )

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
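    # Effective attention batch per GPU: attention heads per tensor-parallel rank
    # times the micro-batch size; used below to gate the fused softmax kernel.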
    attn_batch_size = (
        args.num_attention_heads / args.tensor_model_parallel_size
    ) * args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = (
        seq_len > 16 and seq_len <= 16384 and seq_len % 4 == 0 and attn_batch_size % 4 == 0
    )
    # Print a warning.
    if not ((args.fp16 or args.bf16) and custom_kernel_constraint and args.masked_softmax_fusion):
        if args.rank == 0:
            print(
                "WARNING: constraints for invoking optimized"
                " fused softmax kernel are not met. We default"
                " back to unfused kernel invocations.",
                flush=True,
            )

    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling and loading fused kernels ...", flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            ">>> done with compiling and loading fused kernels. "
            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
            flush=True,
        )


def _initialize_tp_communicators():
    """initializing the communicators with user buffers for high-performance tensor-model-parallel
    communication overlap"""

    try:
        import transformer_engine
        import yaml
        from transformer_engine.pytorch import module as te_module

    except ImportError:
        raise RuntimeError(
            "Tensor Parallel Communication/GEMM Overlap optimization needs 'yaml' and "
            "'transformer_engine' packages"
        )

    args = get_args()

    if args.tp_comm_overlap_cfg is not None:
        with open(args.tp_comm_overlap_cfg, "r") as stream:
            ub_cfgs = yaml.safe_load(stream)
    else:
        ub_cfgs = {}

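    # Communication buffer shape: [tokens per context-parallel rank
    # (sequence length * micro batch size), hidden size].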
    if getattr(args, 'decoder_tp_comm_overlap', False):
        input_shape = [
            (args.decoder_seq_length * args.micro_batch_size) // args.context_parallel_size,
            args.hidden_size,
        ]
    else:
        input_shape = [
            (args.seq_length * args.micro_batch_size) // args.context_parallel_size,
            args.hidden_size,
        ]

    if is_te_min_version("1.9.0"):
        # The process group with the target bootstrap backend is created in Transformer Engine.
        te_module.base.initialize_ub(
            shape=input_shape,
            tp_size=args.tensor_model_parallel_size,
            use_fp8=(args.fp8 is not None),
            ub_cfgs=ub_cfgs,
            bootstrap_backend=args.tp_comm_bootstrap_backend,
        )
    else:
        if args.tp_comm_bootstrap_backend != 'mpi':
            warnings.warn(
                f"Transformer Engine v{get_te_version()} supports only MPI bootstrap backend."
            )
        # Create an MPI process group to help with TP communication overlap bootstrap.
        create_group(backend='mpi', group_desc='TP_BOOTSTRAP_GROUP_MPI')

        te_module.base.initialize_ub(
            shape=input_shape,
            tp_size=args.tensor_model_parallel_size,
            use_fp8=(args.fp8 is not None),
            ub_cfgs=ub_cfgs,
        )


def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print(
                "torch distributed is already initialized, " "skipping initialization ...",
                flush=True,
            )
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            torch.cuda.set_device(args.local_rank)
            device_id = torch.device(f'cuda:{args.local_rank}')
        else:
            device_id = None

        # Call the init process
        init_process_group_kwargs = {
            'backend': args.distributed_backend,
            'world_size': args.world_size,
            'rank': args.rank,
            'init_method': args.dist_url,
            'timeout': timedelta(minutes=args.distributed_timeout_minutes),
        }

        torch.distributed.init_process_group(**init_process_group_kwargs)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print("model parallel is already initialized")
        else:
            mpu.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_split_rank,
                pipeline_model_parallel_comm_backend=args.pipeline_model_parallel_comm_backend,
                context_parallel_size=args.context_parallel_size,
                hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
                expert_model_parallel_size=args.expert_model_parallel_size,
                num_distributed_optimizer_instances=args.num_distributed_optimizer_instances,
                expert_tensor_parallel_size=args.expert_tensor_parallel_size,
                distributed_timeout_minutes=args.distributed_timeout_minutes,
                nccl_communicator_config_path=args.nccl_communicator_config_path,
                order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-cp-ep-pp-dp',
                encoder_tensor_model_parallel_size=args.encoder_tensor_model_parallel_size,
                encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size,
                get_embedding_ranks=get_embedding_ranks,
                get_position_embedding_ranks=get_position_embedding_ranks,
                create_gloo_process_groups=args.enable_gloo_process_groups,
            )
            if args.rank == 0:
                print(
                    f"> initialized tensor model parallel with size "
                    f"{mpu.get_tensor_model_parallel_world_size()}"
                )
                print(
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(
    seed_, data_parallel_random_init=False, te_rng_tracker=False, inference_rng_tracker=False
):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
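            # Also seed the model-parallel CUDA RNG trackers (and, if requested,
            # the TE and inference trackers) used by tensor_parallel layers.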
            tensor_parallel.model_parallel_cuda_manual_seed(
                seed, te_rng_tracker, inference_rng_tracker
            )
    else:
        raise ValueError("Seed ({}) should be a positive integer.".format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)


def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    if is_torch_min_version("2.2.0a0"):
        pass  # we're using torch.compile for jit fusion
    elif is_torch_min_version("1.10.0a0"):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

    _warmup_jit_function()


def _warmup_jit_function():
    """Compilie JIT functions before the main training steps"""
    args = get_args()
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Warmup fused bias+gelu
    bias = torch.rand(
        args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda"
    )
    input = torch.rand(
        (
            args.seq_length // args.context_parallel_size,
            args.micro_batch_size,
            args.ffn_hidden_size // args.tensor_model_parallel_size,
        ),
        dtype=dtype,
        device="cuda",
    )
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for bias_grad, input_grad in zip([True, True], [False, True]):
        bias.requires_grad, input.requires_grad = bias_grad, input_grad
        for _ in range(5):
            if args.swiglu:
                output = bias_swiglu(input, bias)
            else:
                output = bias_gelu(bias, input)
    del bias, input, output

    # Warmup fused bias+dropout+add
    if args.sequence_parallel:
        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
    else:
        seq_length = args.seq_length
    input = torch.rand(
        (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
        dtype=dtype,
        device="cuda",
    )
    residual = torch.rand(
        (seq_length // args.context_parallel_size, args.micro_batch_size, args.hidden_size),
        dtype=dtype,
        device="cuda",
    )
    bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(residual)
    dropout_rate = 0.1
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
        input.requires_grad = input_grad
        bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(5):
            output = bias_dropout_add_fused_train([input, bias], residual, dropout_rate)
    del bias, input, residual, output
    torch.cuda.empty_cache()


def setup_logging() -> None:
    """Sets the default logging level based on cmdline args and env vars.

    Precedence:
    1. Command line argument `--logging-level`
    2. Env var `MEGATRON_LOGGING_LEVEL`
    3. Default logging level (INFO)

    Returns: None
    """
    args = get_args()
    logging_level = None
    env_logging_level = os.getenv('MEGATRON_LOGGING_LEVEL', None)
    if env_logging_level is not None:
        logging_level = int(env_logging_level)
    if args.logging_level is not None:
        logging_level = args.logging_level

    if logging_level is not None:
        logger.info(f'Setting logging level to {logging_level}')
        logging.getLogger().setLevel(logging_level)