Unverified Commit 3982bc2c authored by Chaitanya Sri Krishna Lolla, committed by GitHub
Browse files

[ROCm] Enable DeepEP ROCm as all2allbackend for AMD GPUs. (#34692)


Signed-off-by: Tej Kiran <vpolamre@amd.com>
Co-authored-by: Tej Kiran <vpolamre@amd.com>
parent 02eec7ec
...@@ -44,7 +44,7 @@ ENV DEBIAN_FRONTEND=noninteractive ...@@ -44,7 +44,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies # Install Python and other dependencies
RUN apt-get update -y \ RUN apt-get update -y \
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \ && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev liblzma-dev pkg-config \
&& for i in 1 2 3; do \ && for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \ add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
......
...@@ -10,6 +10,7 @@ import vllm.envs as envs ...@@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.flashinfer import ( from vllm.utils.flashinfer import (
has_flashinfer_nvlink_one_sided, has_flashinfer_nvlink_one_sided,
has_flashinfer_nvlink_two_sided, has_flashinfer_nvlink_two_sided,
...@@ -325,14 +326,20 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -325,14 +326,20 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
assert num_rdma_bytes is not None assert num_rdma_bytes is not None
assert num_qps_per_rank is not None assert num_qps_per_rank is not None
return dict( # TODO: remove platform-specific logic
# once ROCm DeepEP is updated with the latest APIs.
kwargs = dict(
group=self.cpu_group, group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes, num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes, num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False, low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank, num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
) )
if not current_platform.is_rocm():
kwargs.update(
explicitly_destroy=True,
)
return kwargs
def get_handle(self, kwargs): def get_handle(self, kwargs):
assert len(kwargs) == 0, ( assert len(kwargs) == 0, (
...@@ -397,16 +404,22 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -397,16 +404,22 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
) )
assert num_rdma_bytes is not None assert num_rdma_bytes is not None
return dict( # TODO: remove platform-specific logic
# once ROCm DeepEP is updated with the latest APIs.
kwargs = dict(
group=self.cpu_group, group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes, num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes, num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True, low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank, num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
) )
if not current_platform.is_rocm():
kwargs.update(
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
return kwargs
def get_handle(self, kwargs): def get_handle(self, kwargs):
""" """
......
...@@ -346,7 +346,7 @@ class FusedMoEQuantConfig: ...@@ -346,7 +346,7 @@ class FusedMoEQuantConfig:
@property @property
def use_fp8_w8a8(self) -> bool: def use_fp8_w8a8(self) -> bool:
return self.quant_dtype == torch.float8_e4m3fn return self.quant_dtype == current_platform.fp8_dtype()
@property @property
def use_int8_w8a8(self) -> bool: def use_int8_w8a8(self) -> bool:
...@@ -566,7 +566,7 @@ def fp8_w8a8_moe_quant_config( ...@@ -566,7 +566,7 @@ def fp8_w8a8_moe_quant_config(
Construct a quant config for fp8 activations and fp8 weights. Construct a quant config for fp8 activations and fp8 weights.
""" """
return FusedMoEQuantConfig.make( return FusedMoEQuantConfig.make(
torch.float8_e4m3fn, current_platform.fp8_dtype(),
w1_scale=w1_scale, w1_scale=w1_scale,
g1_alphas=g1_alphas, g1_alphas=g1_alphas,
w2_scale=w2_scale, w2_scale=w2_scale,
......
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, moe_kernel_quantize_input,
normalize_batched_scales_shape, normalize_batched_scales_shape,
) )
from vllm.platforms import current_platform
from vllm.v1.worker.ubatching import ( from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id, dbo_current_ubatch_id,
dbo_enabled, dbo_enabled,
...@@ -290,23 +291,46 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): ...@@ -290,23 +291,46 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
# Dispatch # Dispatch
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids) dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch( if current_platform.is_rocm():
a1, (
dispatch_topk_ids, expert_x,
self.max_tokens_per_rank, expert_num_tokens,
num_experts, handle,
use_fp8=self.use_fp8_dispatch, _,
round_scale=self.use_ue8m0_dispatch, hook,
use_ue8m0=self.use_ue8m0_dispatch, ) = self.buffer.low_latency_dispatch(
**(dict(use_nvfp4=True) if use_nvfp4 else dict()), a1,
**( dispatch_topk_ids,
dict(x_global_scale=qc_a1_gscale_or_scale) self.max_tokens_per_rank,
if qc_a1_gscale_or_scale is not None num_experts,
else dict() use_fp8=self.use_fp8_dispatch,
), async_finish=False,
async_finish=False, return_recv_hook=True,
return_recv_hook=True, )
) else:
(
expert_x,
expert_num_tokens,
handle,
_,
hook,
) = self.buffer.low_latency_dispatch(
a1,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
else dict()
),
async_finish=False,
return_recv_hook=True,
)
self.handles[a2a_idx] = handle self.handles[a2a_idx] = handle
return ( return (
......
...@@ -1017,6 +1017,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular): ...@@ -1017,6 +1017,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
torch.float8_e4m3fn, torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
] ]
assert expert_tokens_meta is not None assert expert_tokens_meta is not None
...@@ -1046,7 +1047,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular): ...@@ -1046,7 +1047,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
compute_type = tl.float16 compute_type = tl.float16
elif hidden_states.dtype == torch.float32: elif hidden_states.dtype == torch.float32:
compute_type = tl.float32 compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn: elif hidden_states.dtype == current_platform.fp8_dtype():
compute_type = tl.bfloat16 compute_type = tl.bfloat16
else: else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
......
...@@ -1616,7 +1616,7 @@ def _get_config_quant_dtype( ...@@ -1616,7 +1616,7 @@ def _get_config_quant_dtype(
fused_experts_impl. fused_experts_impl.
""" """
if use_fp8_w8a8: if use_fp8_w8a8:
return torch.float8_e4m3fn return current_platform.fp8_dtype()
elif use_int8_w8a8: elif use_int8_w8a8:
return torch.int8 return torch.int8
elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
......
...@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ...@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize, per_tensor_dequantize,
) )
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
...@@ -265,7 +266,7 @@ def moe_kernel_quantize_input( ...@@ -265,7 +266,7 @@ def moe_kernel_quantize_input(
# weights are already dequantized, and we proceed with normal # weights are already dequantized, and we proceed with normal
# activation quantization below. # activation quantization below.
if quant_dtype == torch.float8_e4m3fn: if quant_dtype == current_platform.fp8_dtype():
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8: elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment