parallel.py 25.5 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
5
import os
6
from typing import TYPE_CHECKING, Any, Literal
7
8

import torch
9
from pydantic import Field, model_validator
10
11
12
13
14
15
16
17
from pydantic.dataclasses import dataclass
from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self

import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
18
from vllm.utils import cuda_device_count_stateless, get_open_ports_list
19
20
21
22
23
24
25
26
27
28
29
30
31

# Ray and the executor base class are only needed for static type checking.
# At runtime these names degrade to ``Any`` so importing this module does not
# require Ray to be installed (it is optional; see the Ray availability checks
# in ParallelConfig) and does not import vllm.executor eagerly (the config
# imports it lazily to avoid a circular import).
if TYPE_CHECKING:
    from ray.runtime_env import RuntimeEnv
    from ray.util.placement_group import PlacementGroup

    from vllm.executor.executor_base import ExecutorBase
else:
    RuntimeEnv = PlacementGroup = ExecutorBase = Any

logger = init_logger(__name__)

# Closed sets of string options accepted by the config fields below.
ExpertPlacementStrategy = Literal["linear", "round_robin"]
DataParallelBackend = Literal["ray", "mp"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
@config
@dataclass
class EPLBConfig:
    """Configuration for Expert Parallel Load Balancing (EPLB)."""

    window_size: int = 1000
    """Window size for expert load recording."""
    step_interval: int = 3000
    """
    Interval for rearranging experts in expert parallelism.

    Note that if this is greater than the EPLB window size, only the metrics
    of the last `window_size` steps will be used for rearranging experts.
    """

    num_redundant_experts: int = Field(default=0, ge=0)
    """Number of redundant experts to use for expert parallelism."""

    log_balancedness: bool = False
    """
    Log the balancedness each step of expert parallelism.
    This is turned off by default since it will cause communication overhead.
    """
@config
@dataclass
class ParallelConfig:
    """Configuration for the distributed execution."""

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel groups."""
    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""
    data_parallel_size: int = 1
    """Number of data parallel groups. MoE layers will be sharded according to
    the product of the tensor parallel size and data parallel size."""
    data_parallel_size_local: int = 1
    """Number of local data parallel groups."""
    data_parallel_rank: int = 0
    """Rank of the data parallel group."""
    data_parallel_rank_local: int | None = None
    """Local rank of the data parallel group,
    set only in SPMD mode."""
    data_parallel_master_ip: str = "127.0.0.1"
    """IP of the data parallel master."""
    data_parallel_rpc_port: int = 29550
    """Port for data parallel messaging."""
    data_parallel_master_port: int = 29500
    """Port of the data parallel master."""
    data_parallel_backend: DataParallelBackend = "mp"
    """Backend to use for data parallel, either "mp" or "ray"."""
    data_parallel_external_lb: bool = False
    """Whether to use "external" DP LB mode. Applies only to online serving
    and when data_parallel_size > 1. This is useful for a "one-pod-per-rank"
    wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
    is provided explicitly to vllm serve."""
    data_parallel_hybrid_lb: bool = False
    """Whether to use "hybrid" DP LB mode. Applies only to online serving
    and when data_parallel_size > 0. Enables running an AsyncLLM
    and API server on a "per-node" basis where vLLM load balances
    between local data parallel ranks, but an external LB balances
    between vLLM nodes/replicas. Set explicitly in conjunction with
    --data-parallel-start-rank."""
    enable_expert_parallel: bool = False
    """Use expert parallelism instead of tensor parallelism for MoE layers."""
    enable_eplb: bool = False
    """Enable expert parallelism load balancing for MoE layers."""
    eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)
    """Expert parallelism configuration."""
    expert_placement_strategy: ExpertPlacementStrategy = "linear"
    """The expert placement strategy for MoE layers:\n
    - "linear": Experts are placed in a contiguous manner. For example, with 4
      experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have
      experts [2, 3].\n
    - "round_robin": Experts are placed in a round-robin manner. For example,
      with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
      will have experts [1, 3]. This strategy can help improve load balancing
      for grouped expert models with no redundant experts."""
    all2all_backend: (
        Literal[
            "naive",
            "pplx",
            "deepep_high_throughput",
            "deepep_low_latency",
            "allgather_reducescatter",
            "flashinfer_all2allv",
        ]
        | None
    ) = None
    """All2All backend for MoE expert parallel communication. If not set, uses
    the value from VLLM_ALL2ALL_BACKEND environment variable. Available options:
    - "naive": Naive all2all implementation using broadcasts
    - "allgather_reducescatter": All2all based on allgather and reducescatter
    - "pplx": Use pplx kernels
    - "deepep_high_throughput": Use deepep high-throughput kernels
    - "deepep_low_latency": Use deepep low-latency kernels
    - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
    num_redundant_experts: int | None = None
    """`num_redundant_experts` is deprecated and has been replaced with
    `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
    Please use `eplb_config.num_redundant_experts` instead."""
    eplb_window_size: int | None = None
    """`eplb_window_size` is deprecated and has been replaced with
    `eplb_config.window_size`. This will be removed in v0.12.0.
    Please use `eplb_config.window_size` instead."""
    eplb_step_interval: int | None = None
    """`eplb_step_interval` is deprecated and has been replaced with
    `eplb_config.step_interval`. This will be removed in v0.12.0.
    Please use `eplb_config.step_interval` instead."""
    eplb_log_balancedness: bool | None = None
    """`eplb_log_balancedness` is deprecated and has been replaced with
    `eplb_config.log_balancedness`. This will be removed in v0.12.0.
    Please use `eplb_config.log_balancedness` instead."""

    max_parallel_loading_workers: int | None = None
    """Maximum number of parallel loading workers when loading model
    sequentially in multiple batches. To avoid RAM OOM when using tensor
    parallel and large models."""

    disable_custom_all_reduce: bool = False
    """Disable the custom all-reduce kernel and fall back to NCCL."""

    enable_dbo: bool = False
    """Enable dual batch overlap for the model executor."""

    dbo_decode_token_threshold: int = 32
    """The threshold for dual batch overlap for batches only containing decodes.
    If the number of tokens in the request is greater than this threshold,
    microbatching will be used. Otherwise, the request will be processed in a
    single batch."""
    dbo_prefill_token_threshold: int = 512  # TODO(lucas): tune
    """The threshold for dual batch overlap for batches that contain one or more
    prefills. If the number of tokens in the request is greater than this
    threshold, microbatching will be used. Otherwise, the request will be
    processed in a single batch."""

    disable_nccl_for_dp_synchronization: bool = False
    """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
    to use Gloo instead of NCCL for its all reduce"""

    ray_workers_use_nsight: bool = False
    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

    ray_runtime_env: RuntimeEnv | None = None
    """Ray runtime environment to pass to distributed workers."""

    placement_group: PlacementGroup | None = None
    """ray distributed model workers placement group."""

    distributed_executor_backend: (
        str | DistributedExecutorBackend | type[ExecutorBase] | None
    ) = None
    """Backend to use for distributed model workers: "ray", "mp"
    (multiprocessing), "uni", "external_launcher", a custom ExecutorBase
    subclass, or its import path. If unset and the product of
    pipeline_parallel_size and tensor_parallel_size is less than or equal to
    the number of GPUs available, "mp" will be used to keep processing on a
    single host. Otherwise, this will default to "ray" if Ray is installed and
    fail otherwise. "ray" is also chosen when data_parallel_backend is "ray" or
    when running inside a Ray placement group. On TPU with VLLM_XLA_USE_SPMD,
    "uni" is used. A world size of 1 defaults to "uni"."""

    worker_cls: str = "auto"
    """The full name of the worker class to use. If "auto", the worker class
    will be determined based on the platform."""
    sd_worker_cls: str = "auto"
    """The full name of the worker class to use for speculative decoding.
    If "auto", the worker class will be determined based on the platform."""
    worker_extension_cls: str = ""
    """The full name of the worker extension class to use. The worker extension
    class is dynamically inherited by the worker class. This is used to inject
    new attributes and methods to the worker class for use in collective_rpc
    calls."""

    world_size: int = Field(init=False)
    """world_size is TPxPP (TPxPPxDP when using the "external_launcher"
    backend); it affects the number of workers we create."""

    rank: int = 0
    """Global rank in distributed setup."""

    _data_parallel_master_port_list: list[int] = Field(default_factory=list)
    """List of open port auto-queried for data parallel messaging.
    Set to be private as it's not intended to be configured by users.
    """

    decode_context_parallel_size: int = 1
    """Number of decode context parallel groups. Because the world size does
    not change with dcp, it simply reuses the GPUs of the TP group, and
    tp_size needs to be divisible by dcp_size."""

    _api_process_count: int = Field(default=1, gt=0)
    """
    The number of API processes initialized.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

    _api_process_rank: int = Field(default=0, ge=-1)
    """
    The rank of this API process, or `-1` for engine core processes
    under API server scale-out.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

    @model_validator(mode="after")
    def _validate_parallel_config(self) -> Self:
        """Cross-field validation of API-process, DP and EPLB settings.

        Raises:
            ValueError: if any combination of fields is inconsistent.
        """
        if self._api_process_rank >= self._api_process_count:
            raise ValueError(
                "Invalid value of `_api_process_rank`. "
                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
                f"but found: {self._api_process_rank}"
            )

        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                f"data_parallel_size_local ({self.data_parallel_size_local}) "
                f"must be <= data_parallel_size ({self.data_parallel_size})"
            )

        if self.data_parallel_size <= 1 and self.data_parallel_external_lb:
            raise ValueError(
                "data_parallel_external_lb can only be set when data_parallel_size > 1"
            )

        if self.enable_eplb:
            if not current_platform.is_cuda():
                raise ValueError(
                    "Expert parallelism load balancing is only supported on "
                    "CUDA devices now."
                )
            if not self.enable_expert_parallel:
                raise ValueError("enable_expert_parallel must be True to use EPLB.")
            if self.tensor_parallel_size * self.data_parallel_size <= 1:
                raise ValueError(
                    "EPLB requires tensor_parallel_size or data_parallel_size "
                    f"to be greater than 1, but got "
                    f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
                )
        else:
            if self.eplb_config.num_redundant_experts != 0:
                raise ValueError(
                    "num_redundant_experts is set to "
                    f"{self.eplb_config.num_redundant_experts} but EPLB is not "
                    "enabled. Either enable EPLB or unset "
                    "num_redundant_experts."
                )

        return self

    @property
    def world_size_across_dp(self) -> int:
        """world_size_across_dp is TPxPPxDP, it is the size of the world
        including data parallelism."""
        return self.world_size * self.data_parallel_size

    def get_next_dp_init_port(self) -> int:
        """
        We might need to initialize process groups in multiple
        processes that are related to data parallelism,
        e.g. both in the worker and in the engine, which
        can live in different processes. To avoid port conflicts, we
        pop a new port from the prepared port list each time we need to
        initialize a new process group related to data parallelism.

        When the prepared list is exhausted, fall back to
        `data_parallel_master_port` and advance it by one.
        """
        if self._data_parallel_master_port_list:
            answer = self._data_parallel_master_port_list.pop()
        else:
            answer = self.data_parallel_master_port
            self.data_parallel_master_port += 1

        return answer

    def stateless_init_dp_group(self) -> ProcessGroup:
        """Create a Gloo process group spanning all data parallel ranks.

        Retries up to 5 times, each with a fresh port, when the rendezvous
        fails with EADDRINUSE; any other DistNetworkError is re-raised.
        """
        # NOTE: In high-concurrency scenarios multiple processes
        # can pick the same (currently free) port through a race
        # condition when calling `get_open_port()`. When the first
        # process binds the port the others will subsequently fail
        # with `torch.distributed.DistNetworkError: EADDRINUSE`.
        # To make the initialization more robust we retry a few times
        # with a fresh port whenever this specific error is observed.
        from torch.distributed import DistNetworkError

        from vllm.distributed.utils import (
            stateless_init_torch_distributed_process_group,
        )

        max_retries = 5
        last_exc: Exception | None = None
        for _ in range(max_retries):
            try:
                # use gloo since the engine process might not have cuda device
                return stateless_init_torch_distributed_process_group(
                    self.data_parallel_master_ip,
                    self.get_next_dp_init_port(),
                    self.data_parallel_rank,
                    self.data_parallel_size,
                    backend="gloo",
                )
            except DistNetworkError as e:
                # We only want to retry when the root cause is EADDRINUSE.
                if "EADDRINUSE" in str(e):
                    logger.warning("Address already in use. Retrying with a new port.")
                    last_exc = e
                    continue  # try again with a new port
                raise

        # If we get here all retries have failed.
        assert last_exc is not None
        raise last_exc

    # The all_reduce at the end of attention (during o_proj) means that
    # inputs are replicated across each rank of the tensor parallel group.
    # If using expert-parallelism with DeepEP All2All ops, replicated
    # tokens results in useless duplicate computation and communication.
    #
    # In this case, ensure the input to the experts is sequence parallel
    # to avoid the excess work.
    #
    # Not needed for pplx-kernels as it can handle duplicate input tokens.
    @property
    def use_sequence_parallel_moe(self) -> bool:
        return (
            self.all2all_backend
            in (
                "allgather_reducescatter",
                "naive",
                "deepep_high_throughput",
                "deepep_low_latency",
            )
            and self.enable_expert_parallel
            and self.tensor_parallel_size > 1
            and self.data_parallel_size > 1
        )

    @staticmethod
    def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
        """Return True if any rank in `dp_group` passed `has_unfinished=True`.

        This is a collective: every rank in the group must call it.
        """
        tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
        # dp rank 0: has_unfinished_seqs=True
        # dp rank 1: has_unfinished_seqs=False
        # aggregated: has_unfinished_seqs=True
        # so this is an OR operation, i.e. MAX in integers
        torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
        aggregated_has_unfinished = bool(tensor.item())
        return aggregated_has_unfinished

    @staticmethod
    def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int:
        """Return the minimum `kv_cache_memory` across all ranks in `dp_group`.

        A value of -1 is mapped to the int64 maximum so it never wins the MIN
        reduction; if every rank passes -1 the result is that maximum, not -1.
        This is a collective: every rank in the group must call it.
        """
        if kv_cache_memory == -1:
            kv_cache_memory = torch.iinfo(torch.int64).max
        tensor = torch.tensor([kv_cache_memory], dtype=torch.int64, device="cpu")
        # we cannot use broadcast for stateless dp group since it depends
        # on global rank
        torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
        return tensor.item()

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        This hash is also used for DP worker configuration validation
        to prevent hangs from mismatched collective communication patterns.
        """
        factors: list[Any] = []
        factors.append(self.pipeline_parallel_size)
        factors.append(self.tensor_parallel_size)
        factors.append(self.enable_expert_parallel)
        factors.append(self.data_parallel_size)
        factors.append(self.all2all_backend)
        factors.append(self.enable_eplb)
        if self.enable_eplb:
            factors.append(self.eplb_config.log_balancedness)
            factors.append(self.eplb_config.window_size)
            factors.append(self.eplb_config.step_interval)
            factors.append(self.eplb_config.num_redundant_experts)
        return hashlib.sha256(str(factors).encode()).hexdigest()

    def __post_init__(self) -> None:
        """Resolve deprecated/env-derived settings and derived fields.

        Fills in `all2all_backend`, forwards deprecated EPLB fields into
        `eplb_config`, computes `world_size`, resolves data parallel settings
        (from engine args or VLLM_DP_* env vars) and picks a default
        `distributed_executor_backend` when none was given.
        """
        # Set all2all_backend from env var if not specified, with deprecation warning
        if self.all2all_backend is None:
            self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND
            if envs.is_set("VLLM_ALL2ALL_BACKEND"):
                logger.warning_once(
                    "VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
                    "will be removed in a future release. Please use the "
                    "--all2all-backend command-line argument instead."
                )

        # Forward deprecated top-level fields to their new location in
        # eplb_config: (deprecated attribute on self, attribute on eplb_config).
        for old_name, new_name in (
            ("num_redundant_experts", "num_redundant_experts"),
            ("eplb_window_size", "window_size"),
            ("eplb_step_interval", "step_interval"),
            ("eplb_log_balancedness", "log_balancedness"),
        ):
            value = getattr(self, old_name)
            if value is None:
                continue
            setattr(self.eplb_config, new_name, value)
            logger.warning_once(
                f"{old_name} is deprecated and has been replaced "
                f"with eplb_config.{new_name}. This will be removed "
                "in v0.12.0. Changing this field after initialization will "
                "have no effect."
            )

        # Continue with the rest of the initialization
        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size

        if self.distributed_executor_backend == "external_launcher":
            logger.info("Using external launcher for distributed inference.")
            self.world_size *= self.data_parallel_size

        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
            # Data parallel was specified in the engine args.
            if self.distributed_executor_backend == "external_launcher":
                # For external launcher,
                # we need to set the data parallel rank automatically.
                # world_size already includes DP at this point, so the
                # divisor below is the per-DP-rank size (TP x PP).
                self.data_parallel_rank = int(os.environ["RANK"]) // (
                    self.world_size // self.data_parallel_size
                )
                logger.info(
                    "Set data_parallel_rank to %d automatically.",
                    self.data_parallel_rank,
                )
            if not self._data_parallel_master_port_list:
                self._data_parallel_master_port_list = get_open_ports_list(5)
            self.data_parallel_master_port = self._data_parallel_master_port_list.pop()

            if not (0 <= self.data_parallel_rank < self.data_parallel_size):
                raise ValueError(
                    f"data_parallel_rank ({self.data_parallel_rank})"
                    f" must be in the range [0, {self.data_parallel_size})"
                )
        else:
            # Otherwise fall back to env vars (e.g. for offline SPMD case).
            self.data_parallel_size = envs.VLLM_DP_SIZE
            self.data_parallel_rank = envs.VLLM_DP_RANK
            self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
            self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
            self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

        if self.distributed_executor_backend == "external_launcher":
            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
            logger.info("Disabling V1 multiprocessing for external launcher.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.

            from vllm.executor import ray_utils

            backend: DistributedExecutorBackend = "mp"
            ray_found = ray_utils.ray_is_available()
            if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
                backend = "uni"
            elif (
                current_platform.is_cuda()
                and cuda_device_count_stateless() < self.world_size
            ):
                if not ray_found:
                    raise ValueError(
                        "Unable to load Ray: "
                        f"{ray_utils.ray_import_err}. Ray is "
                        "required for multi-node inference, "
                        "please install Ray with `pip install "
                        "ray`."
                    )
                backend = "ray"
            elif self.data_parallel_backend == "ray":
                logger.info(
                    "Using ray distributed inference because "
                    "data_parallel_backend is ray"
                )
                backend = "ray"
            elif ray_found:
                if self.placement_group:
                    backend = "ray"
                else:
                    from ray import is_initialized as ray_is_initialized

                    if ray_is_initialized():
                        from ray.util import get_current_placement_group

                        if get_current_placement_group():
                            backend = "ray"
            self.distributed_executor_backend = backend
            logger.debug("Defaulting to use %s for distributed inference", backend)

        if self.distributed_executor_backend is None and self.world_size == 1:
            self.distributed_executor_backend = "uni"

    @property
    def use_ray(self) -> bool:
        """True if the executor backend is "ray" or a class flagged `uses_ray`."""
        return self.distributed_executor_backend == "ray" or (
            isinstance(self.distributed_executor_backend, type)
            and getattr(self.distributed_executor_backend, "uses_ray", False)
        )

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Validate the executor backend and platform-dependent options.

        Raises:
            ValueError: for an unrecognized executor backend object, or when
                nsight profiling is requested without Ray workers.
        """
        # Lazy import to avoid circular import
        from vllm.executor.executor_base import ExecutorBase
        from vllm.platforms import current_platform

        if (
            self.distributed_executor_backend is not None
            and not isinstance(self.distributed_executor_backend, str)
            and not (
                isinstance(self.distributed_executor_backend, type)
                and issubclass(self.distributed_executor_backend, ExecutorBase)
            )
        ):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp', 'uni', 'external_launcher', "
                "a custom ExecutorBase subclass, or its import path."
            )
        if self.use_ray:
            from vllm.executor import ray_utils

            ray_utils.assert_ray_available()

        if not current_platform.use_custom_allreduce():
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on current platform."
            )
        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError(
                "Unable to use nsight profiling unless workers run with Ray."
            )

        return self