rocm.py 22.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from datetime import timedelta
6
from functools import cache, lru_cache, wraps
7
from typing import TYPE_CHECKING, Optional
8
9

import torch
10
11
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available
12

13
import vllm.envs as envs
14
from vllm.logger import init_logger
15
from vllm.utils import cuda_device_count_stateless
16

17
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
18

19
if TYPE_CHECKING:
20
    from vllm.config import ModelConfig, VllmConfig
21

22
23
logger = init_logger(__name__)

# AMDSMI is optional at import time. If the import fails, only a warning is
# logged and the amdsmi_* names below stay undefined, so any later call into
# them (for example through `with_amdsmi_context`) raises NameError.
try:
    from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down, amdsmi_topo_get_link_type)
except ImportError as e:
    logger.warning("Failed to import from amdsmi with %r", e)

try:
    import vllm._C  # noqa: F401
except ImportError as e:
    logger.warning("Failed to import from vllm._C with %r", e)

# import custom ops, trigger op registration
# NOTE: importing the ROCm-specific ops in `vllm._rocm_C` is disabled in this
# build; the upstream import is kept below for reference.
# try:
#     import vllm._rocm_C  # noqa: F401
# except ImportError as e:
#     logger.warning("Failed to import from vllm._rocm_C with %r", e)

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: list[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
# Every upstream entry is commented out, so the mapping is empty in this build.
# _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
#                     "Triton flash attention. For half-precision SWA support, "
#                     "please use CK flash attention by setting "
#                     "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
    # "Qwen2ForCausalLM":
    # _ROCM_SWA_REASON,
    # "MistralForCausalLM":
    # _ROCM_SWA_REASON,
    # "MixtralForCausalLM":
    # _ROCM_SWA_REASON,
    # "PaliGemmaForConditionalGeneration":
    # ("ROCm flash attention does not yet "
    #  "fully support 32-bit precision on PaliGemma"),
    # "Phi3VForCausalLM":
    # ("ROCm Triton flash attention may run into compilation errors due to "
    #  "excessive use of shared memory. If this happens, disable Triton FA "
    #  "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}

# Hex device id -> AMD Instinct product name (presumably PCI device ids;
# entries tagged "VF" are the virtual-function variants).
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
    "0x74a5": "AMD_Instinct_MI325X",
    "0x74b9": "AMD_Instinct_MI325X",  # MI325X VF
    "0x74a9": "AMD_Instinct_MI300X_HF",
    "0x74bd": "AMD_Instinct_MI300X_HF",
}

# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`.
# NOTE: this check is disabled in this build, so HIP_VISIBLE_DEVICES is
# neither validated against nor copied into CUDA_VISIBLE_DEVICES at import.
# if "HIP_VISIBLE_DEVICES" in os.environ:
#     val = os.environ["HIP_VISIBLE_DEVICES"]
#     if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
#         assert val == cuda_val
#     else:
#         os.environ["CUDA_VISIBLE_DEVICES"] = val

# AMDSMI utils
# Like NVML, AMDSMI is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`: its
# functions address real physical device ids, so logical ids must be mapped
# first (see `device_id_to_physical_device_id`). The main reason to prefer
# AMDSMI for such queries is that it does not initialize CUDA/HIP.


def with_amdsmi_context(fn):
    """Wrap ``fn`` so each call is bracketed by AMDSMI init and shutdown.

    ``amdsmi_init()`` runs before ``fn``; ``amdsmi_shut_down()`` runs in a
    ``finally`` clause, so it also runs when ``fn`` raises. The wrapped
    function's return value is passed through unchanged.
    """

    @wraps(fn)
    def _run_inside_amdsmi(*call_args, **call_kwargs):
        amdsmi_init()
        try:
            outcome = fn(*call_args, **call_kwargs)
        finally:
            amdsmi_shut_down()
        return outcome

    return _run_inside_amdsmi
def device_id_to_physical_device_id(device_id: int) -> int:
    """Map a logical device index to its physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` indexes into its
    comma-separated list and the selected entry is returned as an int.
    When the variable is unset, ``device_id`` is returned unchanged.

    Raises:
        ValueError: if ``CUDA_VISIBLE_DEVICES`` is set but empty (no device
            is visible), or if the selected entry is not an integer.
        IndexError: if ``device_id`` is out of range for the listed devices.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    # An empty value hides every GPU; previously this surfaced as an opaque
    # `int('')` ValueError, so fail with an explanatory message instead.
    if not visible_devices.strip():
        raise ValueError(
            "CUDA_VISIBLE_DEVICES is set to an empty string, so no GPU is "
            f"visible to this process; cannot map device id {device_id} to "
            "a physical device. Unset the variable or list the devices to "
            "use.")

    device_ids = visible_devices.split(",")
    return int(device_ids[device_id])
@cache
def on_gfx1x() -> bool:
    """Return True if the current device's gcnArchName contains gfx11/gfx12.

    The answer is computed once and memoized by ``functools.cache``.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for family in ("gfx11", "gfx12"):
        if family in arch_name:
            return True
    return False
@cache
def on_mi3xx() -> bool:
    """Return True if the current device's gcnArchName contains gfx942/gfx950.

    The answer is computed once and memoized by ``functools.cache``.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx942" in arch_name or "gfx950" in arch_name
@cache
def on_gfx9() -> bool:
    """Return True if the current device is gfx90a, gfx942 or gfx950.

    Matching is a substring test on ``gcnArchName``; the result is memoized
    by ``functools.cache``.
    """
    gfx9_targets = ("gfx90a", "gfx942", "gfx950")
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    return any(target in arch_name for target in gfx9_targets)
@cache
def use_rocm_custom_paged_attention(
        qtype: torch.dtype,
        head_size: int,
        block_size: int,
        gqa_ratio: int,
        max_seq_len: int,
        sliding_window: int,
        kv_cache_dtype: str,
        alibi_slopes: Optional[torch.Tensor] = None,
        sinks: Optional[torch.Tensor] = None) -> bool:
    """Report whether the ROCm custom paged-attention kernel may be used.

    In this build the custom kernel is disabled unconditionally, so the
    function always returns ``False``. The parameters are kept so the
    signature stays compatible with callers and with the eligibility rules
    preserved below.

    Previously the function also queried ``torch.cuda`` device properties
    and derived architecture flags that were never used. That wasted a
    device query and failed on hosts without a visible GPU. The query has
    been removed.
    """
    # Reference eligibility rules (disabled). Re-enabling them also requires
    # restoring:
    #   GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    #   ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
    #   ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
    #
    # custom paged attn always supported on V0. On V1, requires sliding window
    # disabled due to observed numerical discrepancy.
    # if ON_GFX9:
    #     return ((not envs.VLLM_USE_V1 or sliding_window == 0
    #              or sliding_window == (-1, -1))
    #             and (qtype == torch.half or qtype == torch.bfloat16)
    #             and (head_size == 64 or head_size == 128)
    #             and (block_size == 16 or block_size == 32)
    #             and (gqa_ratio >= 1 and gqa_ratio <= 16)
    #             and max_seq_len <= 128 * 1024
    #             and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
    #             and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
    #                      and envs.VLLM_ROCM_USE_AITER) and sinks is None)
    # else:
    #     return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
    #                                 or sliding_window == (-1, -1))
    #             and (qtype == torch.half or qtype == torch.bfloat16)
    #             and head_size == 128 and block_size == 16
    #             and (gqa_ratio >= 3 and gqa_ratio <= 16)
    #             and max_seq_len <= 128 * 1024 and alibi_slopes is None
    #             and kv_cache_dtype == "auto"
    #             and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
    return False
class RocmPlatform(Platform):
    """vLLM platform implementation for AMD GPUs driven through ROCm.

    AMD devices are reached through PyTorch's ``torch.cuda`` API on ROCm,
    which is why the device type, dispatch key and visibility variable
    below use the CUDA names.
    """
    _enum = PlatformEnum.ROCM
    device_name: str = "rocm"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    dist_backend: str = "nccl"
    # rocm shares the same device control env var as CUDA
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
        "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao",
        "moe_wna16", "slimquant_w4a8", "w8a8_int8", "awq_marlin",
        "slimquant_w4a8_marlin"
    ]

    @classmethod
    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
        """Choose the attention backend for vision encoders.

        With ``support_fa``, AITER FA is used when AITER and AITER MHA are
        enabled on a gfx9 device, and plain FLASH_ATTN on other gfx9
        devices. Every other case uses TORCH_SDPA.
        """
        if support_fa:
            if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
                    and on_gfx9()):
                # Note: AITER FA is only supported for Qwen-VL models.
                # TODO: Add support for other VL models in their model class.
                return _Backend.ROCM_AITER_FA
            if on_gfx9():
                return _Backend.FLASH_ATTN
        return _Backend.TORCH_SDPA

    @classmethod
    def _get_triton_mla_backend_cls(cls, use_v1: bool) -> str:
        """Return the Triton MLA backend path for the V1 or V0 engine."""
        if use_v1:
            logger.info_once("Using Triton MLA backend on V1 engine.")
            return ("vllm.v1.attention.backends.mla."
                    "triton_mla.TritonMLABackend")
        logger.info("Using Triton MLA backend.")
        return "vllm.attention.backends.triton_mla.TritonMLABackend"

    @classmethod
    def _get_mla_backend_cls(cls, selected_backend, block_size: int,
                             use_v1: bool) -> str:
        """Select the MLA attention backend.

        FlashMLA is used only when all of these hold: Triton MLA was not
        explicitly selected, ``block_size == 64``, ``VLLM_USE_FLASH_MLA`` is
        set, and ``is_flashmla_supported()`` reports support. Every other
        case falls back to Triton MLA. In particular, an unsupported FlashMLA
        no longer falls through to the non-MLA backend selection.
        """
        wants_flash_mla = (selected_backend != _Backend.TRITON_MLA
                           and block_size == 64 and envs.VLLM_USE_FLASH_MLA)
        if wants_flash_mla:
            from vllm.attention.backends.flashmla import (
                is_flashmla_supported)

            # Query once; the result is indexed as (supported, reason).
            flashmla_status = is_flashmla_supported()
            if flashmla_status[0]:
                if use_v1:
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashmla.FlashMLABackend")
                logger.info("Using FlashMLA backend.")
                return "vllm.attention.backends.flashmla.FlashMLABackend"
            logger.warning("FlashMLA backend is not supported due to %s",
                           flashmla_status[1])

        # NOTE: the upstream AITER MLA selection is not enabled in this build.
        return cls._get_triton_mla_backend_cls(use_v1)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
        """Return the import path of the attention backend class to use.

        MLA models are routed to FlashMLA or Triton MLA (see
        ``_get_mla_backend_cls``). For non-MLA models on the V1 engine,
        FlashAttention is used when ``VLLM_USE_FLASH_ATTN_PA`` is set and
        ``block_size == 64``; otherwise Triton attention is used. On V0, the
        dual-chunk backend is honored if selected, and ROCmFlashAttention is
        used otherwise. ``head_size``, ``dtype``, ``kv_cache_dtype`` and
        ``has_sink`` are not consulted here.
        """
        if use_mla:
            return cls._get_mla_backend_cls(selected_backend, block_size,
                                            use_v1)

        if selected_backend is None or selected_backend == _Backend.FLASH_ATTN:
            selected_backend = _Backend.ROCM_FLASH

        if envs.VLLM_USE_V1:
            triton_attn_v1 = ("vllm.v1.attention.backends.triton_attn."
                              "TritonAttentionBackend")
            flash_attn_v1 = ("vllm.v1.attention.backends.flash_attn."
                             "FlashAttentionBackend")
            if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
                logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")  # noqa: E501
                return flash_attn_v1
            # Record the decision in the environment. This is presumably so
            # that later reads of VLLM_USE_FLASH_ATTN_PA (and child
            # processes, which inherit os.environ) see the feature disabled.
            os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
            logger.info_once("Using Triton backend on V1 engine.")
            return triton_attn_v1

        if selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
            logger.info("Using DualChunkFlashAttention backend.")
            return ("vllm.attention.backends.dual_chunk_flash_attn."
                    "DualChunkFlashAttentionBackend")

        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

    @classmethod
    def set_device(cls, device: torch.device) -> None:
        """
        Set the device for the current platform.
        """
        torch.cuda.set_device(device)

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the (major, minor) capability reported by torch.cuda."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @with_amdsmi_context
    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """Return True if every pair of the given GPUs is linked by 1-hop XGMI.

        ``physical_device_ids`` are physical indices into AMDSMI's processor
        handle list. Any AMDSMI error is logged and treated as "not fully
        connected".
        """
        # Enumerate the processor handles once instead of once per device id.
        all_handles = amdsmi_get_processor_handles()
        handles = [all_handles[i] for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for peer_handle in handles[i + 1:]:
                try:
                    link_type = amdsmi_topo_get_link_type(
                        handle, peer_handle)
                except AmdSmiException as error:
                    logger.error("AMD 1 hop XGMI detection failed.",
                                 exc_info=error)
                    return False
                # type is 2 for XGMI
                if link_type["hops"] != 1 or link_type["type"] != 2:
                    return False
        return True

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda``.

        This build uses the torch-reported name instead of the AMDSMI market
        name (``amdsmi_get_gpu_asic_info(handle)["market_name"]``). The
        AMDSMI init/shutdown and handle lookup that used to precede this
        call produced a value that was never used, so they have been removed.
        Unlike an AMDSMI query, this goes through torch.cuda and may
        initialize the HIP runtime.
        """
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device's torch properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is unavailable only for eager V0 runs."""
        if enforce_eager and not envs.VLLM_USE_V1:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Fill in ROCm defaults on ``vllm_config`` in place.

        - Default the KV-cache block size to 16.
        - Resolve ``worker_cls == "auto"`` to the V1 or V0 GPU worker.
          Speculative decoding is rejected on V0.
        - Enable the ``+rms_norm`` custom op for AITER RMSNorm on V1 when
          CUDA graphs are not disabled.
        """
        from vllm.config.compilation import CUDAGraphMode

        cache_config = vllm_config.cache_config
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        # Compare the configured graph mode, not the config object itself:
        # a CompilationConfig never compares equal to a CUDAGraphMode member,
        # so the previous check was always False.
        is_eager_execution = (
            compilation_config.cudagraph_mode == CUDAGraphMode.NONE)

        use_v1 = envs.VLLM_USE_V1
        use_aiter_rms_norm = (envs.VLLM_ROCM_USE_AITER
                              and envs.VLLM_ROCM_USE_AITER_RMSNORM)

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        if parallel_config.worker_cls == "auto":
            if vllm_config.speculative_config:
                if not use_v1:
                    raise NotImplementedError(
                        "Speculative decoding is not supported on vLLM V0.")
                parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
            elif use_v1:
                parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

        #  Aiter rms norm perform best when CUDA Graph capture is enabled.
        # Guard against appending a duplicate if called more than once.
        if (use_v1 and use_aiter_rms_norm and not is_eager_execution
                and "+rms_norm" not in compilation_config.custom_ops):
            compilation_config.custom_ops.append("+rms_norm")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """Reject unsupported architectures; warn on partially supported ones.

        Raises:
            ValueError: if ``model_arch`` is in ``_ROCM_UNSUPPORTED_MODELS``.
        """
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(f"Model architecture '{model_arch}' is not "
                             "supported by ROCm for now.")

        if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
            msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
            logger.warning(
                "Model architecture '%s' is partially "
                "supported by ROCm: %s", model_arch, msg)

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """Run the base check, then warn when AWQ runs without Triton AWQ."""
        super().verify_quantization(quant)
        if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            logger.warning(
                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                " is not set, disabling VLLM_USE_TRITON_AWQ.")
            # The flag is already falsy here; this normalizes it to False.
            envs.VLLM_USE_TRITON_AWQ = False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the LoRA Punica wrapper."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return memory currently allocated by torch on ``device``.

        Peak stats are reset first, so the "max" that is read back reflects
        the allocations live at the time of the call.
        """
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the import path of the device communicator class."""
        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

    @classmethod
    def supports_mx(cls) -> bool:
        """True when device 0's gcnArchName contains ``gfx95``."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ["gfx95"])

    @classmethod
    def supports_fp8(cls) -> bool:
        """True when device 0 is a gfx94x, gfx95x or gfx12x architecture."""
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])

    @classmethod
    def is_fp8_fnuz(cls) -> bool:
        """True when device 0 is gfx94x, which uses the FNUZ fp8 variant."""
        # only device 0 is checked, this assumes MI300 platforms are homogeneous
        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        """Return the fp8 e4m3 dtype variant matching the device."""
        if cls.is_fp8_fnuz():
            return torch.float8_e4m3fnuz
        else:
            return torch.float8_e4m3fn

    @classmethod
    def supports_v1(cls, model_config: "ModelConfig") -> bool:
        """V1 is reported as supported for every model."""
        # V1 support on AMD gpus is experimental
        return True

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """True when device 0 is a gfx94x or gfx95x architecture."""
        # We only enable custom allreduce for MI300 series
        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
        supported_archs = ['gfx94', 'gfx95']
        return any(gfx in gcn_arch for gfx in supported_archs)

    @classmethod
    def opaque_attention_op(cls) -> bool:
        return True

    @classmethod
    def get_cu_count(cls, device_id: int = 0) -> int:
        """Return the compute-unit count (torch's multi_processor_count)."""
        return torch.cuda.get_device_properties(
            device_id).multi_processor_count

    @classmethod
    def is_navi(cls) -> bool:
        """True when device 0's gcnArchName contains ``gfx1``."""
        return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ProcessGroup backed by NCCL without the global c10d state.

        The ``backend`` argument is not consulted; an NCCL backend with the
        given ``timeout`` is always registered for the ``cuda`` device.
        """
        assert is_nccl_available()
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def device_count(cls) -> int:
        """Return the visible device count without initializing CUDA state."""
        return cuda_device_count_stateless()

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: "ModelConfig") -> bool:
        return True

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
        """Raise if bfloat16 is requested on a device below capability 8.0.

        Raises:
            ValueError: for ``torch.bfloat16`` on a device without capability
                of at least 8.0.
        """
        if torch_dtype == torch.bfloat16:  # noqa: SIM102
            if not cls.has_device_capability(80):
                capability = cls.get_device_capability()
                gpu_name = cls.get_device_name()

                if capability is None:
                    compute_str = "does not have a compute capability"
                else:
                    version_str = capability.as_version_str()
                    compute_str = f"has compute capability {version_str}"

                raise ValueError(
                    "Bfloat16 is only supported on GPUs "
                    "with compute capability of at least 8.0. "
                    f"Your {gpu_name} GPU {compute_str}. "
                    "You can use float16 instead by explicitly setting the "
                    "`dtype` flag in CLI, for example: --dtype=half.")

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True