# spatial.py: Python wrappers around the sgl_kernel spatial_ops extension.
1
2
3
import torch
from torch.cuda.streams import ExternalStream

4
5
6
7
8
9
10
11
12
13
14
# Load the compiled extension eagerly; importing it registers the
# torch.ops.sgl_kernel ops used below. A failure is recorded instead of
# raised so that importing this module still succeeds and the error is
# surfaced only when one of the wrappers is actually called.
try:
    from . import spatial_ops  # triggers TORCH extension registration
    _spatial_import_error = None
except Exception as _e:
    _spatial_import_error = _e

# Error reported to callers when the extension could not be loaded.
_IMPORT_ERROR = ImportError(
    "Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
)

15
16
17
18
19
20
21
22
23
24
25
26
27

def create_greenctx_stream_by_value(
    SM_a: int, SM_b: int, device_id: int = None
) -> tuple[ExternalStream, ExternalStream]:
    """
    Create two streams for greenctx.

    Calls the ``torch.ops.sgl_kernel.create_greenctx_stream_by_value`` op and
    wraps the two raw stream pointers it returns in ``ExternalStream`` objects
    bound to the target CUDA device.

    Args:
        SM_a (int): The number of SMs requested for stream A.
        SM_b (int): The number of SMs requested for stream B.
        device_id (int, optional): The CUDA device index. Defaults to
            ``torch.cuda.current_device()`` when ``None``.
    Returns:
        tuple[ExternalStream, ExternalStream]: Stream A and stream B.
    Raises:
        ImportError: If the ``sgl_kernel.spatial_ops`` extension failed to load.
    """
    if _spatial_import_error is not None:
        # Raise a fresh instance rather than the shared module-level one:
        # re-raising the same exception object on every call keeps appending
        # frames to its __traceback__ and keeps those frames alive.
        raise ImportError(*_IMPORT_ERROR.args) from _spatial_import_error

    if device_id is None:
        device_id = torch.cuda.current_device()

    res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)

    # Both streams belong to the same device; build the device object once.
    device = torch.device(f"cuda:{device_id}")
    stream_a = ExternalStream(stream_ptr=res[0], device=device)
    stream_b = ExternalStream(stream_ptr=res[1], device=device)

    return stream_a, stream_b


def get_sm_available(device_id: int = None) -> int:
    """
    Get the number of streaming multiprocessors (SMs) on the device.

    Args:
        device_id (int, optional): The CUDA device index. Defaults to
            ``torch.cuda.current_device()`` when ``None``.
    Returns:
        int: The device's total SM count, as reported by
        ``torch.cuda.get_device_properties(device_id).multi_processor_count``.
    Raises:
        ImportError: If the ``sgl_kernel.spatial_ops`` extension failed to load.
    """
    # The SM count itself comes from torch. The extension check is kept so this
    # helper reports a missing extension the same way the stream helper does.
    if _spatial_import_error is not None:
        # Fresh instance: re-raising one shared exception object on every call
        # keeps growing its __traceback__ and keeps old frames alive.
        raise ImportError(*_IMPORT_ERROR.args) from _spatial_import_error

    if device_id is None:
        device_id = torch.cuda.current_device()

    device_props = torch.cuda.get_device_properties(device_id)
    return device_props.multi_processor_count