utils.py 26.1 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
#
# See LICENSE for license information.

"""Utility functions for Transformer Engine modules"""
6
from __future__ import annotations
7
import functools
8
import math
9
import os
10
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
11
import numpy as np
Przemek Tredak's avatar
Przemek Tredak committed
12
import torch
wenjh's avatar
wenjh committed
13
14
15
16
17
18
import warnings
try:
    import lightop
    enable_lightop = True
except ImportError:
    enable_lightop = False
19
import transformer_engine.pytorch.cpp_extensions as ext
20
from . import torch_version
yuguo's avatar
yuguo committed
21
from torch.utils.cpp_extension import IS_HIP_EXTENSION
Przemek Tredak's avatar
Przemek Tredak committed
22

23
24
25
26
27
28
29
30
def requires_grad(*tensors: Optional[torch.Tensor]) -> bool:
    """Check if any of the given tensors require gradient.

    Parameters
    ----------
    *tensors: torch.Tensor or None
        Tensors to inspect. ``None`` entries are ignored.

    Returns
    -------
    bool
        ``True`` if at least one non-``None`` tensor has ``requires_grad`` set.
    """
    return any(tensor is not None and tensor.requires_grad for tensor in tensors)


31
32
33
34
35
36
@functools.lru_cache(maxsize=None)
def _empty_tensor() -> torch.Tensor:
    """Return a shared placeholder tensor with no elements, moved via ``.cuda()``.

    The result is memoized, so every caller receives the same object; its
    device is fixed by whichever CUDA device was current on the first call.
    """
    placeholder = torch.Tensor()
    return placeholder.cuda()


37
def clear_tensor_data(*tensors: Optional[torch.Tensor]) -> None:
    """
    Trick to deallocate tensor memory when delete operation does not
    release the tensor due to PyTorch override.

    Must be used carefully.

    For each argument:

    * ``None`` is skipped.
    * Objects carrying a ``do_not_clear`` attribute are skipped.
    * Objects whose ``get_data_tensors()`` results carry ``do_not_clear`` are
      skipped.
    * Objects with a ``clear()`` method are cleared through it.
    * Anything else has its ``.data`` replaced by a shared empty tensor.
    """
    for t in tensors:
        if t is None:
            continue
        # Workaround for double buffering in cpu offload
        if hasattr(t, "do_not_clear"):
            continue
        if hasattr(t, "get_data_tensors") and any(
            hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()
        ):
            continue

        if hasattr(t, "clear"):
            t.clear()
        else:
            t.data = _empty_tensor()
59
60


61
62
63
64
65
66
@functools.lru_cache
def _get_device_compute_capability(device: torch.device) -> Tuple[int, int]:
    """Return the ``(major, minor)`` compute capability of ``device``, memoized per device."""
    properties = torch.cuda.get_device_properties(device)
    major, minor = properties.major, properties.minor
    return major, minor


Tim Moon's avatar
Tim Moon committed
67
68
def get_device_compute_capability() -> Tuple[int, int]:
    """CUDA compute capability of current GPU

    Returns
    -------
    Tuple[int, int]
        ``(major, minor)`` for the device index returned by
        ``torch.cuda.current_device()``; the lookup is memoized per device
        by ``_get_device_compute_capability``.
    """
    return _get_device_compute_capability(torch.cuda.current_device())
70
71


Przemek Tredak's avatar
Przemek Tredak committed
72
73
74
75
76
77
78
79
80
81
82
83
84
def attention_mask_func(
    attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Get attention mask

    Writes ``-10000.0`` into ``attention_scores`` (in place) at every position
    where ``attention_mask`` is ``True``, and returns that same tensor.
    """
    return attention_scores.masked_fill_(attention_mask, -10000.0)


def get_default_init_method() -> Callable:
    """Weight initialization method if not provided by user

    Returns the ``init_method_normal`` initializer with standard deviation 0.023.
    """
    default_sigma = 0.023
    return init_method_normal(default_sigma)


85
86
87
def init_method_constant(val: float) -> Callable:
    """Init method to set all tensor elements to a constant value.

    Parameters
    ----------
    val: float
        Value written into every element. ``1.0`` and ``0.0`` use
        ``torch.nn.init.ones_`` / ``zeros_``; other values use ``constant_``.

    Returns
    -------
    Callable
        Function that fills its tensor argument in place and returns it.
    """
    if val == 1.0:

        def init_(tensor: torch.Tensor) -> torch.Tensor:
            return torch.nn.init.ones_(tensor)

    elif val == 0.0:

        def init_(tensor: torch.Tensor) -> torch.Tensor:
            return torch.nn.init.zeros_(tensor)

    else:

        def init_(tensor: torch.Tensor) -> torch.Tensor:
            return torch.nn.init.constant_(tensor, val)

    return init_


Przemek Tredak's avatar
Przemek Tredak committed
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def init_method_normal(sigma: float) -> Callable:
    """Init method based on N(0, sigma).

    Returns a function that fills its tensor argument in place with samples
    from a normal distribution (mean 0, standard deviation ``sigma``) and
    returns that tensor.
    """

    def init_(tensor: torch.Tensor) -> torch.Tensor:
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma: float, num_layers: int) -> Callable:
    """Init method based on N(0, sigma/sqrt(2*num_layers)).

    Returns a function that fills its tensor argument in place with samples
    from a normal distribution with mean 0 and standard deviation
    ``sigma / sqrt(2 * num_layers)``, and returns that tensor.
    """
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor: torch.Tensor) -> torch.Tensor:
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def all_close(a: torch.Tensor, b: torch.Tensor) -> bool:
    """torch.allclose with cpu to not run into OOMs

    Both tensors are copied to host memory first, then compared with
    ``torch.allclose``'s default tolerances.
    """
    a_host, b_host = a.cpu(), b.cpu()
    return torch.allclose(a_host, b_host)


def print_rank_0(*args: Any) -> None:
    """print on rank 0

    NOTE: the gate is the *local* CUDA device index
    (``torch.cuda.current_device() == 0``), not a distributed rank.
    """
    if torch.cuda.current_device() != 0:
        return
    print(*args)


def compare_tensors(a: torch.Tensor, b: torch.Tensor) -> None:
    """util function to show some tensor stats

    Prints (through ``print_rank_0``) both tensors, the maximum absolute
    elementwise difference and each tensor's maximum. Tensors of different
    shape are only reported as such.
    """
    if a.shape == b.shape:
        print_rank_0(a)
        print_rank_0(b)
        largest_gap = torch.max(torch.abs(a - b))
        peak_a = torch.max(a)
        peak_b = torch.max(b)
        print_rank_0(f"max err={largest_gap}, max a={peak_a}, max_b={peak_b}")
    else:
        print_rank_0("Tensors have different shape")


def ensure_divisibility(numerator: int, denominator: int) -> None:
    """Ensure that numerator is divisible by the denominator.

    Raises
    ------
    AssertionError
        If ``numerator % denominator != 0`` (only while assertions are enabled).
    ZeroDivisionError
        If ``denominator`` is zero.
    """
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
Przemek Tredak's avatar
Przemek Tredak committed
151
152
153
154
155
156
157
158
159


def divide(numerator: int, denominator: int) -> int:
    """Return ``numerator // denominator`` after asserting exact divisibility.

    Divisibility is checked by ``ensure_divisibility``.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient


160
161
def split_tensor_along_dim(
    tensor: torch.Tensor, dim: int, num_partitions: int, contiguous_split_chunks: bool = False
) -> Tuple[torch.Tensor, ...]:
    """Split a tensor into equal chunks along dimension ``dim``.

    Arguments:
        tensor: input tensor.
        dim: dimension along which to split.
        num_partitions: number of partitions to split the tensor into; must
                        evenly divide ``tensor.size(dim)``.
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.

    Returns a tuple of ``num_partitions`` tensors. Unless
    ``contiguous_split_chunks`` is set, the chunks are views of ``tensor``.
    """
    # Size of each chunk; ``divide`` asserts even divisibility.
    split_size = divide(tensor.size()[dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, split_size, dim=dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
# @klakhani TODO: Consider combining with split_tensor_along_dim() and no_op_cat() and SplitAlongDim
def combine_tensors(
    tensors: List[torch.Tensor],
    dim: int,
) -> torch.Tensor:
    """Combine tensors along a particular dimension

    Builds a tensor of shape ``tensors[0].shape`` with a new axis of length
    ``len(tensors)`` inserted at ``dim``. No data is copied: the result is a
    view created with ``set_`` over the storage of ``tensors[0]`` (its
    ``_data`` storage for ``Float8Tensor`` inputs), starting at
    ``tensors[0]``'s storage offset. Only ``tensors[0]`` and ``len(tensors)``
    are read, so the inputs are presumably expected to be equally spaced
    slices of one shared buffer. NOTE(review): confirm with callers.

    NOTE(review): the stride of the new axis is ``stride[dim - 1] / num_tensors``.
    For ``dim == 0`` this indexes the *last* stride. Confirm ``dim == 0`` is
    never used.
    """

    num_tensors = len(tensors)
    new_shape = list(tensors[0].shape)
    # New axis of length num_tensors at position ``dim``
    new_shape.insert(dim, num_tensors)
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    if isinstance(tensors[0], Float8Tensor):
        # Build the view over the raw FP8 data, then re-wrap it below
        new_stride = list(tensors[0]._data.stride())
        # Stride of the new axis: stride of dim ``dim - 1`` split evenly across the inputs
        new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
        # Empty tensor whose storage is replaced by ``set_``
        combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype)
        combined_tensor.set_(
            tensors[0]._data.untyped_storage(),
            tensors[0]._data.storage_offset(),
            new_shape,
            new_stride,
        )
        # Wrap the data view as a Float8Tensor modeled on tensors[0]
        combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor, shape=new_shape)
    else:
        new_stride = list(tensors[0].stride())
        # Stride of the new axis: stride of dim ``dim - 1`` split evenly across the inputs
        new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
        # Empty tensor whose storage is replaced by ``set_``
        combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype)
        combined_tensor.set_(
            tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride
        )

    return combined_tensor


def _noop_cat_view(
    tensors: Sequence[torch.Tensor],
    dim: int,
    split_sizes: Sequence[int],
) -> Optional[torch.Tensor]:
    """Concatenate ``tensors`` along ``dim`` without copying, when possible.

    Returns a tensor equal to ``torch.cat(tensors, dim=dim)`` that aliases the
    storage of ``tensors[0]``. Returns ``None`` unless every tensor:

    * has the same dtype, strides and storage as ``tensors[0]``;
    * has the expected shape for its split size;
    * starts exactly where the previous one ends along ``dim``.

    On ``None`` the caller must fall back to a real concatenation.
    """
    if len(tensors) != len(split_sizes):
        return None
    first = tensors[0]
    strides = first.stride()
    storage_ptr = first.untyped_storage().data_ptr()
    base_offset = first.storage_offset()
    step = strides[dim]
    position = 0  # running start index along ``dim``
    for tensor, size in zip(tensors, split_sizes):
        expected_shape = list(first.shape)
        expected_shape[dim] = size
        if (
            tensor.dtype != first.dtype
            or tensor.stride() != strides
            or list(tensor.shape) != expected_shape
            or tensor.untyped_storage().data_ptr() != storage_ptr
            or tensor.storage_offset() != base_offset + position * step
        ):
            return None
        position += size
    combined_shape = list(first.shape)
    combined_shape[dim] = position
    combined = torch.Tensor().to(device=first.device, dtype=first.dtype)
    combined.set_(first.untyped_storage(), base_offset, combined_shape, strides)
    return combined


class SplitAlongDim(torch.autograd.Function):
    """
    Split tensor along given dimension

    Forward returns a tuple of chunks. With ``squeeze``, ``split_dim`` is
    squeezed out of each chunk, which only changes chunks of size 1 along it.
    Backward concatenates the incoming gradients. When the gradients are
    adjacent slices of one buffer, it reuses that storage as a zero-copy view.
    """

    @staticmethod
    def forward(
        ctx,
        mixed_x_layer: torch.Tensor,
        split_dim: int,
        split_size_or_sections: Union[int, List[int], Tuple[int]],
        squeeze=False,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
        ctx.squeeze = squeeze
        from transformer_engine.pytorch.float8_tensor import Float8Tensor
        from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase

        # Input rank, needed in backward to re-insert squeezed dims
        if isinstance(mixed_x_layer, Float8TensorBase):
            ctx.input_ndim = mixed_x_layer._data.dim()
        else:
            ctx.input_ndim = mixed_x_layer.dim()

        if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
            mixed_x_layer, Float8Tensor
        ):
            return tuple(
                Float8TensorBase(
                    fp8_scale_inv=mixed_x_layer._scale_inv,
                    fp8_dtype=mixed_x_layer._fp8_dtype,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                    quantizer=mixed_x_layer._quantizer,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        if isinstance(mixed_x_layer, Float8Tensor):
            return tuple(
                Float8Tensor.make_like(
                    mixed_x_layer,
                    data=x.squeeze(split_dim) if squeeze else x,
                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
                )
                for x in torch.split(
                    mixed_x_layer._data,
                    split_size_or_sections=split_size_or_sections,
                    dim=split_dim,
                )
            )
        out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)
        if squeeze:
            # Must stay a tuple: autograd only tracks tensors returned directly in
            # a tuple; a list would be treated as a single non-tensor output.
            out_list = tuple(x.squeeze(split_dim) for x in out_list)
        return out_list

    @staticmethod
    def backward(ctx, *grad_outputs):
        # pylint: disable=missing-function-docstring
        assert len(grad_outputs) > 0, "No gradients received for backprop!"

        if isinstance(ctx.split_size_or_sections, (list, tuple)):
            split_sizes = list(ctx.split_size_or_sections)
            assert len(grad_outputs) == len(
                split_sizes
            ), "Unequal number of gradients vs split sections for backprop!"
        else:
            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        ndim = ctx.input_ndim
        split_dim = (ctx.split_dim + ndim) % ndim
        from transformer_engine.pytorch.float8_tensor import Float8Tensor

        if ctx.squeeze:
            # Re-insert ``split_dim`` for gradients whose forward output was squeezed
            restored = []
            for grad in grad_outputs:
                if len(grad.shape) < ndim:
                    if isinstance(grad, Float8Tensor):
                        data = grad._data.unsqueeze(split_dim)
                        grad = Float8Tensor.make_like(grad, data=data, shape=data.shape)
                    else:
                        grad = grad.unsqueeze(split_dim)
                restored.append(grad)
            grad_outputs = tuple(restored)

        # One entry per forward input (tensor, split_dim, split sizes, squeeze).
        # Autograd accepts the extra trailing ``None`` when ``squeeze`` was not passed.
        if isinstance(grad_outputs[0], Float8Tensor):
            grad_data = [grad._data for grad in grad_outputs]
            combined = _noop_cat_view(grad_data, split_dim, split_sizes)
            if combined is None:
                combined = torch.cat(grad_data, dim=split_dim)
            grad_input = Float8Tensor.make_like(
                grad_outputs[0], data=combined, shape=combined.shape
            )
            return grad_input, None, None, None

        combined = _noop_cat_view(grad_outputs, split_dim, split_sizes)
        if combined is None:
            combined = torch.cat(grad_outputs, dim=split_dim)
        return combined, None, None, None


Przemek Tredak's avatar
Przemek Tredak committed
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
def validate_ctx_manager(ctx: Callable) -> None:
    """Checks if passed in object can be used as a context manager.

    ``ctx`` is called with no arguments and the result is entered and exited
    once. Any exception raised along the way is re-raised as ``ValueError``.
    """
    try:
        manager = ctx()
        with manager:
            pass
    except Exception as err:
        raise ValueError("Object must be a valid ctx manager") from err


def validate_rng_states_func(get_rng_tracker: Callable) -> None:
    """Checks if passed in param function has everything
    required for tensor/model and sequence parallel.

    ``get_rng_tracker`` must be callable with no arguments. It must return an
    object with callable ``get_states``, ``set_states`` and ``fork``
    attributes, and ``fork()`` must produce a usable context manager.
    """
    assert callable(get_rng_tracker), "get_rng_tracker is not a valid function"

    try:
        tracker = get_rng_tracker()
    except Exception as err:
        raise RuntimeError("Cannot call get_rng_tracker function") from err

    for method_name in ("get_states", "set_states", "fork"):
        assert callable(
            getattr(tracker, method_name, None)
        ), f"rng_tracker object does not have valid method {method_name}"
    validate_ctx_manager(tracker.fork)


392
def assert_viewless_tensor(tensor: Any, extra_msg: Optional[str] = None) -> Any:
    """Assert that a tensor is not a view (i.e., its '._base' field is
    not set).

    Lists are checked element-wise (recursively), and a new list of the
    checked elements is returned. Non-tensor objects are returned unchanged.

    Parameters
    ----------
    tensor: torch.Tensor, list, or any object
        Object to check.
    extra_msg: str, optional
        Text appended to the assertion message (also for list elements).

    Raises
    ------
    AssertionError
        If a tensor has ``_base`` set.
    """
    if isinstance(tensor, list):
        # Forward ``extra_msg`` so failures inside lists keep the caller's context
        return [assert_viewless_tensor(t, extra_msg=extra_msg) for t in tensor]
    if not isinstance(tensor, torch.Tensor):
        return tensor
    suffix = f" {extra_msg}" if extra_msg else ""
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        f"likely accumulate over iterations).{suffix}"
    )
    return tensor


407
def safely_set_viewless_tensor_data(tensor: torch.Tensor, new_data_tensor: torch.Tensor) -> None:
    """Safely set tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If not,
    raise an exception (``AssertionError`` from ``assert_viewless_tensor``).
    """
    extra_msg = (
        "FYI, tensor._base has shape "
        f"{'--' if tensor._base is None else tensor._base.shape}, "
        f"and new_data_tensor has shape {new_data_tensor.shape}."
    )
    assert_viewless_tensor(tensor, extra_msg=extra_msg)
    tensor.data = new_data_tensor


def cast_if_needed(tensor: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
    """Cast tensor to dtype

    Returns ``None`` for ``None`` input, and the input object itself when it
    already has ``dtype``. Otherwise the cast runs under
    ``torch.enable_grad()``, so that for tensors requiring grad it is tracked
    by autograd even when the caller is inside a no-grad region.
    """
    if tensor is None:
        return None
    if tensor.dtype == dtype:
        return tensor
    with torch.enable_grad():
        return tensor.to(dtype=dtype)
430
431


432
def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool:
    """Check if tensor dimensions are supported for FP8 TN GEMM

    Returns ``True`` only for a 2-D tensor whose first dimension is divisible
    by 8 and whose second dimension is divisible by 16.
    """
    return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0
435
436


437
438
439
440
def assert_dim_for_fp8_exec(*tensors: torch.Tensor) -> None:
    """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.

    For each tensor, the product of all dimensions except the last must be
    divisible by 8, and the last dimension must be divisible by 16.

    Raises
    ------
    AssertionError
        If any tensor violates the constraint.
    """
    for tensor in tensors:
        assert math.prod(tensor.shape[:-1]) % 8 == 0 and tensor.shape[-1] % 16 == 0, (
            "FP8 execution requires the product of all dimensions except the last to be divisible"
            " by 8 and the last dimension to be divisible by 16, but got tensor with"
            f" dims={list(tensor.size())}"
        )
446

yuguo's avatar
yuguo committed
447
448
449
450
451
452
453
454
455
456
457
if IS_HIP_EXTENSION:

    # lightop-based w8a8 is disabled for now ("lightop is not ready");
    # set to True to enable the capability check in ``use_lightop_w8a8``.
    _LIGHTOP_W8A8_READY = False

    def _device_name_matches(pattern: str) -> bool:
        """Whether the current device's name matches regex ``pattern``."""
        import re  # pylint: disable=import-outside-toplevel

        device_name = torch.cuda.get_device_name(torch.cuda.current_device())
        return re.search(pattern, device_name) is not None

    def is_mi200():
        """check whether this machine is mi200/210/250"""
        return _device_name_matches(r"AMD Instinct MI2.0")

    def is_K100_AI():
        """check whether this machine is K100_AI"""
        return _device_name_matches(r"K100_AI")

    def is_BW():
        """check whether this machine is BW"""
        return _device_name_matches(r"BW")

    def use_lightop_w8a8(block_size: List[int]) -> bool:
        """Check whether to use lightop for w8a8

        Currently always ``False`` (see ``_LIGHTOP_W8A8_READY``). Once enabled,
        it requires compute capability >= (9, 3), ``block_size[1] == 128`` and
        an importable ``lightop``. If lightop is missing on otherwise-eligible
        hardware, it warns and returns ``False``.
        """
        if not _LIGHTOP_W8A8_READY:
            # Just return False because lightop is not ready now.
            return False
        supported = get_device_compute_capability() >= (9, 3) and block_size[1] == 128
        if enable_lightop:
            return supported
        if supported:
            warnings.warn(
                "Lightop is not available. Using default implementation for w8a8."
            )
        return False
475

476
477
def is_bf16_compatible() -> bool:
    """Replaces torch.cuda.is_bf16_compatible() with an explicit
    check on device compute capability to enforce sm_80 or higher.

    On ROCm builds, only devices with compute capability (9, 4), or detected
    as MI200, K100_AI or BW by name, are reported as bf16-capable.
    """
    if IS_HIP_EXTENSION:
        # Only a fixed set of ROCm devices is treated as supporting bf16
        return (
            get_device_compute_capability() == (9, 4)
            or is_mi200()
            or is_K100_AI()
            or is_BW()
        )
    return torch.cuda.get_device_capability()[0] >= 8
488
489


490
@functools.lru_cache(maxsize=None)
def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
    """Checks whether the device supports
    non-TN layouts for FP8 GEMMs.

    Parameters
    ----------
    is_blockwise: bool, optional
        Only consulted on ROCm builds.

    Returns
    -------
    bool
        On ROCm builds: ``not is_blockwise``. Otherwise: ``True`` for compute
        capability in [10.0, 12.0) or >= 13.0. Memoized per argument value.
    """
    if IS_HIP_EXTENSION:
        return not is_blockwise
    device_capability = torch.cuda.get_device_capability()
    return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
502
503


504
@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Runtime cuDNN version (major, minor, patch)

    Decodes the integer from ``ext.get_cudnn_version()``:

    * below 90000: ``major * 1000 + minor * 100 + patch``;
    * otherwise: ``major * 10000 + minor * 100 + patch``.

    ROCm builds return ``(99, 0, 0)``. The result is memoized.
    """
    # ROCm fused attn does not use cudnn, return high numbers to avoid tests filtering out
    if IS_HIP_EXTENSION:
        return (99, 0, 0)
    encoded_version = ext.get_cudnn_version()
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563


def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
    """Canonicalize PyTorch device

    If `None`, then returns the default CUDA device.

    Strings are converted with ``torch.device``. A CUDA device without an
    index gets the index of the current CUDA device.
    """
    if device is None:
        # Prefer PyTorch's default device when it is a CUDA device.
        # ``torch.get_default_device`` is missing in older PyTorch releases;
        # fall back to the current CUDA device there.
        get_default_device = getattr(torch, "get_default_device", None)
        device = get_default_device() if get_default_device is not None else None
        if device is None or device.type != "cuda":
            device = torch.device("cuda", torch.cuda.current_device())
    elif not isinstance(device, torch.device):
        device = torch.device(device)
    if device.type == "cuda" and device.index is None:
        device = torch.device("cuda", torch.cuda.current_device())
    return device


def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
    """Canonicalize PyTorch datatype

    If `None`, then returns the default PyTorch datatype
    (``torch.get_default_dtype()``); any other value is returned unchanged.

    """
    return torch.get_default_dtype() if dtype is None else dtype


def devices_match(device1: torch.device, device2: torch.device) -> bool:
    """Whether two devices are the same

    Both arguments are passed through ``torch.device``. For CUDA devices, a
    missing index is treated as the current CUDA device.
    """
    first, second = torch.device(device1), torch.device(device2)
    if first.type != second.type:
        return False
    if first.type != "cuda":
        return first == second
    if first.index == second.index:
        return True

    def _resolve(index):
        # A CUDA device without an index refers to the current device
        return torch.cuda.current_device() if index is None else index

    return _resolve(first.index) == _resolve(second.index)
564
565
566
567
568
569
570
571
572
573
574
575
576


@functools.lru_cache
def get_sm_count() -> int:
    """Returns the number of streaming multiprocessors in the current device.

    Memoized: the count for the device that was current on the first call is
    returned on every later call.
    """
    device_index = torch.cuda.current_device()
    properties = torch.cuda.get_device_properties(device_index)
    return properties.multi_processor_count


def round_up_to_nearest_multiple(value, multiple):
    """Round up `value` to the next multiple of `multiple`

    For a positive integer `multiple`, returns the smallest multiple of it
    that is >= `value`. Raises ``ValueError`` if `multiple` is zero.
    """
    if multiple == 0:
        raise ValueError("multiple cannot be zero.")
    num_multiples = (value + multiple - 1) // multiple
    return num_multiples * multiple
577
578


579
580
def needs_quantized_gemm(obj, rowwise=True):
    """Used to check if obj will need quantized gemm or normal gemm.

    Returns ``False`` only when the inspected object's exact type is
    ``torch.Tensor`` or ``torch.nn.Parameter``, and ``True`` otherwise. For a
    ``DebugQuantizedTensor``, the object returned by
    ``obj.get_tensor(not rowwise)`` is inspected instead.
    """
    from ..debug.pytorch.debug_quantization import DebugQuantizedTensor

    if isinstance(obj, DebugQuantizedTensor):
        obj = obj.get_tensor(not rowwise)
    # Exact type check on purpose: any subclass of torch.Tensor reports True
    return type(obj) not in (  # pylint: disable=unidiomatic-typecheck
        torch.Tensor,
        torch.nn.Parameter,
    )


593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
@functools.lru_cache(maxsize=None)
def _nvtx_enabled() -> bool:
    """Check if NVTX range profiling is enabled

    Parses the integer environment variable ``NVTE_NVTX_ENABLED`` (default
    ``"0"``); any non-zero value enables profiling. The result is memoized.
    """
    raw_value = os.getenv("NVTE_NVTX_ENABLED", "0")
    return int(raw_value) != 0


# Messages associated with active NVTX ranges, used as a stack:
# ``nvtx_range_push`` appends and ``nvtx_range_pop`` pops and checks for a match.
# Module-level, so shared by every caller in the process.
_nvtx_range_messages: list[str] = []


def nvtx_range_push(msg: str) -> None:
    """Push NVTX range onto stack, if NVTX range profiling is enabled

    Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range
    profiling.

    Parameters
    ----------
    msg: str
        Message to associate with range

    """
    if _nvtx_enabled():
        # Remember the message so ``nvtx_range_pop`` can verify pairing
        _nvtx_range_messages.append(msg)
        torch.cuda.nvtx.range_push(msg)


def nvtx_range_pop(msg: Optional[str] = None) -> None:
    """Pop NVTX range from stack, if NVTX range profiling is enabled

    Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range
    profiling.

    Parameters
    ----------
    msg: str, optional
        Message associated with range

    Raises
    ------
    RuntimeError
        If no range is currently pushed.
    ValueError
        If ``msg`` is given and differs from the most recently pushed message
        (that message is still removed from the stack).

    """
    if not _nvtx_enabled():
        return

    if len(_nvtx_range_messages) == 0:
        raise RuntimeError("Attempted to pop NVTX range from empty stack")
    last_msg = _nvtx_range_messages.pop()
    if msg is None or msg == last_msg:
        torch.cuda.nvtx.range_pop()
        return
    raise ValueError(
        f"Attempted to pop NVTX range from stack with msg={msg}, "
        f"but last range has msg={last_msg}"
    )
650
651
652
653
654
655
656
657
658
659
660
661
662


def canonicalize_process_group(
    group: Optional[torch.distributed.ProcessGroup],
) -> torch.distributed.ProcessGroup:
    """Convert to PyTorch process group

    If `None`, returns default process group; otherwise returns `group`
    unchanged.

    """
    if group is not None:
        return group
    return torch.distributed.distributed_c10d._get_default_group()
663
664
665
666
667
668
669
670
671
672
673
674
675


def torch_get_autocast_gpu_dtype() -> torch.dtype:
    """Get PyTorch autocast GPU dtype.

    Uses ``torch.get_autocast_gpu_dtype()`` before PyTorch 2.4 and
    ``torch.get_autocast_dtype("cuda")`` from 2.4 on.
    """
    if torch_version() < (2, 4, 0):
        return torch.get_autocast_gpu_dtype()
    return torch.get_autocast_dtype("cuda")


# GPU autocast context-manager factory:
# - before PyTorch 2.4: ``torch.cuda.amp.autocast``;
# - from 2.4 on: ``torch.amp.autocast`` bound to ``device_type="cuda"``.
if torch_version() < (2, 4, 0):
    gpu_autocast_ctx = torch.cuda.amp.autocast
else:
    gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
wenjh's avatar
wenjh committed
676
677

from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784

_torch_dtype_to_np_typestr_dict = {
    torch.float16: "<f2",
    torch.float32: "<f4",
    torch.int64: "<i8",
    torch.int32: "<i4",
    torch.int8: "|i1",
    torch.float8_e4m3fn: "|i1",
    torch.qint8: "|u1",
    torch.bool: "|b1",
    torch.bfloat16: "<f2",
}


class _WeakRefTensor:
    """
    A wrapper wraps raw data pointer to a tensor-like object. Could be compatible with openai
    triton kernel and be converted to `torch.Tensor` with zero-copy overhead.

    Only the integer address, dtype, shape and (optionally) element strides
    are stored, so the wrapper holds no reference to the memory it describes.
    """

    def __init__(
        self,
        data_ptr: int,
        dtype: torch.dtype,
        shape: Sequence[int],
        strides: Optional[Sequence[int]] = None,
    ):
        self._data_ptr = data_ptr
        self.dtype = dtype
        self.shape = shape
        # Element strides; ``None`` means C-contiguous
        self._strides = None if strides is None else tuple(int(s) for s in strides)

    def data_ptr(self):
        """Data pointer of the tensor."""
        return self._data_ptr

    @property
    def dtype(self):
        """Dtype of the tensor."""
        return self._dtype

    @property
    def shape(self):
        """Shape of the tensor."""
        return getattr(self, "_shape", None)

    @dtype.setter
    def dtype(self, dtype: torch.dtype):
        self._dtype = dtype

    @shape.setter
    def shape(self, shape: Sequence[int]):
        self._shape = tuple(int(i) for i in shape)

    def numel(self):
        """Number of elements in the tensor (an ``int``; 1 for a 0-d shape)."""
        return math.prod(self.shape)

    @property
    def __cuda_array_interface__(self):
        typestr = self.torch_dtype_to_np_typestr()
        interface = {
            "shape": self.shape,
            "typestr": typestr,
            "data": (self.data_ptr() if self.numel() > 0 else 0, False),
            "version": 3,
        }
        strides = getattr(self, "_strides", None)
        if strides is not None:
            # The interface expects byte strides; the typestr encodes the item size
            itemsize = int(typestr[2:])
            interface["strides"] = tuple(s * itemsize for s in strides)
        return interface

    def torch_dtype_to_np_typestr(self):
        """Convert PyTorch dtype to numpy typestr."""
        ret = _torch_dtype_to_np_typestr_dict.get(self.dtype)
        assert ret is not None, f"Unsupported dtype: {self.dtype}"
        return ret


def make_weak_ref(x):
    """
    This function is to make a weak reference to the input so that the memory can be released.

    * CUDA tensors are rebuilt as tensors aliasing the same device memory
      (same address, dtype, shape and strides) without referencing the
      original storage.
    * CPU tensors and ``int``/``float``/``bool``/``None`` are returned as-is.
    * Tuples, lists and dict values are processed recursively.

    Raises
    ------
    TypeError
        For any other input type.
    """

    def convert_to_torch_tensor(tensor: Union[_WeakRefTensor, torch.Tensor]) -> torch.Tensor:
        """
        This function is to convert the `_WeakRefTensor` to torch.Tensor.
        """
        if isinstance(tensor, torch.Tensor):
            return tensor

        old_ptr = tensor.data_ptr()
        # Re-view as the real dtype: the typestr may be a same-width stand-in
        new_tensor = torch.as_tensor(tensor).view(tensor.dtype)
        new_ptr = new_tensor.data_ptr()
        if old_ptr != new_ptr:
            raise RuntimeError("Data pointer mismatch after converting to torch.Tensor")
        return new_tensor

    if isinstance(x, torch.Tensor):
        if not x.is_cuda:
            return x
        # Pass strides so non-contiguous tensors keep their memory layout
        return convert_to_torch_tensor(_WeakRefTensor(x.data_ptr(), x.dtype, x.shape, x.stride()))
    if isinstance(x, tuple):
        return tuple(make_weak_ref(i) for i in x)
    if isinstance(x, list):
        return [make_weak_ref(i) for i in x]
    if isinstance(x, dict):
        return {k: make_weak_ref(v) for k, v in x.items()}
    if isinstance(x, (int, float, bool)):
        return x
    if x is None:
        return None
    raise TypeError(f"Invalid type {type(x)} to make weak ref")