Unverified Commit c6ca5159 authored by Li-Yongwen's avatar Li-Yongwen Committed by GitHub
Browse files

[Bugfix] fix device_name for routing replay (#34336)


Signed-off-by: default avatarliyongwen <1310439159@qq.com>
parent c0615a29
......@@ -20,6 +20,7 @@ import torch
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
......@@ -132,7 +133,7 @@ class RoutedExpertsCapturer:
self._device_buffer = torch.zeros(
(max_num_batched_tokens, num_layers, num_experts_per_tok),
dtype=torch.int32,
device="cuda",
device=current_platform.device_type,
)
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment