"docs/source/serving/usage_stats.md" did not exist on "7bc94a0fddcd62d20b40390a7efb69c7a221ae5b"
Unverified Commit 792cbd64 authored by nkm-meta's avatar nkm-meta Committed by GitHub
Browse files

Add platform method to enable custom collective ops registration (#34760)


Signed-off-by: default avatarNaina Kuruballi Mahesh <nainakm@meta.com>
parent 2ed4722e
...@@ -385,8 +385,10 @@ class GroupCoordinator: ...@@ -385,8 +385,10 @@ class GroupCoordinator:
self.cpu_group, 1 << 22, 6 self.cpu_group, 1 << 22, 6
) )
# TODO(#35915): Remove is_tpu() check once tpu_inference
# overrides use_custom_op_collectives() to return True.
self.use_custom_op_call = ( self.use_custom_op_call = (
current_platform.is_cuda_alike() or current_platform.is_tpu() current_platform.is_tpu() or current_platform.use_custom_op_collectives()
) )
self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr( self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
......
...@@ -574,9 +574,13 @@ class CudaPlatformBase(Platform): ...@@ -574,9 +574,13 @@ class CudaPlatformBase(Platform):
return True return True
@classmethod @classmethod
def num_compute_units(cls, device_id=0): def num_compute_units(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(device_id).multi_processor_count return torch.cuda.get_device_properties(device_id).multi_processor_count
@classmethod
def use_custom_op_collectives(cls) -> bool:
return True
# NVML utils # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
......
...@@ -654,6 +654,15 @@ class Platform: ...@@ -654,6 +654,15 @@ class Platform:
""" """
return False return False
@classmethod
def use_custom_op_collectives(cls) -> bool:
"""
Whether this platform should use torch.ops.vllm.* custom ops for collectives.
Returns False by default - platforms must explicitly opt-in.
"""
return False
@classmethod @classmethod
def use_sync_weight_loader(cls) -> bool: def use_sync_weight_loader(cls) -> bool:
""" """
......
...@@ -820,5 +820,9 @@ class RocmPlatform(Platform): ...@@ -820,5 +820,9 @@ class RocmPlatform(Platform):
return True return True
@classmethod @classmethod
def num_compute_units(cls, device_id=0): def num_compute_units(cls, device_id: int = 0) -> int:
return torch.cuda.get_device_properties(device_id).multi_processor_count return torch.cuda.get_device_properties(device_id).multi_processor_count
@classmethod
def use_custom_op_collectives(cls) -> bool:
return True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment