Unverified commit 1bd32bc8, authored by Mathis Felardos and committed by GitHub
Browse files

[Config][Disaggregated] Add timeout configuration for the torch.store and add...


[Config][Disaggregated] Add timeout configuration for the torch.store and add KVTransferConfig.kv_connector_extra_config (#14367)
Signed-off-by: Mathis Felardos <mathis@mistral.ai>
parent 128bf752
...@@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel): ...@@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection # The KV connector port, used to build distributed connection
kv_port: int = 14579 kv_port: int = 14579
# any extra config that the connector may need
kv_connector_extra_config: dict[str, Any] = {}
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
...@@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel): ...@@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \ return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"] self.kv_role in ["kv_consumer", "kv_both"]
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)
class CompilationLevel: class CompilationLevel:
# constants for the levels of the compilation process # constants for the levels of the compilation process
......
...@@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase): ...@@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase):
self.device = self._select_device(device) self.device = self._select_device(device)
# build distributed connection and send/recv implementation # build distributed connection and send/recv implementation
store_timeout = self.config.get_from_extra_config("store_timeout", 300)
self.group = StatelessProcessGroup.create( self.group = StatelessProcessGroup.create(
host=self.config.kv_ip, host=self.config.kv_ip,
port=self.config.kv_port + port_offset, port=self.config.kv_port + port_offset,
rank=self.kv_rank, rank=self.kv_rank,
world_size=self.kv_parallel_size, world_size=self.kv_parallel_size,
store_timeout=store_timeout,
) )
# add a barrier to make sure the connection is initiated properly # add a barrier to make sure the connection is initiated properly
self.group.barrier() self.group.barrier()
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses import dataclasses
import datetime
import pickle import pickle
import time import time
from collections import deque from collections import deque
...@@ -217,6 +218,7 @@ class StatelessProcessGroup: ...@@ -217,6 +218,7 @@ class StatelessProcessGroup:
rank: int, rank: int,
world_size: int, world_size: int,
data_expiration_seconds: int = 3600, data_expiration_seconds: int = 3600,
store_timeout: int = 300,
) -> "StatelessProcessGroup": ) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not """A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. pollute the global state.
...@@ -238,6 +240,7 @@ class StatelessProcessGroup: ...@@ -238,6 +240,7 @@ class StatelessProcessGroup:
port=port, port=port,
world_size=world_size, world_size=world_size,
is_master=(rank == 0), is_master=(rank == 0),
timeout=datetime.timedelta(seconds=store_timeout),
) )
return StatelessProcessGroup( return StatelessProcessGroup(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment