Unverified Commit c6ca5159 authored by Li-Yongwen's avatar Li-Yongwen Committed by GitHub
Browse files

[Bugfix] fix device_name for routing replay (#34336)


Signed-off-by: default avatarliyongwen <1310439159@qq.com>
parent c0615a29
...@@ -20,6 +20,7 @@ import torch ...@@ -20,6 +20,7 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -132,7 +133,7 @@ class RoutedExpertsCapturer: ...@@ -132,7 +133,7 @@ class RoutedExpertsCapturer:
self._device_buffer = torch.zeros( self._device_buffer = torch.zeros(
(max_num_batched_tokens, num_layers, num_experts_per_tok), (max_num_batched_tokens, num_layers, num_experts_per_tok),
dtype=torch.int32, dtype=torch.int32,
device="cuda", device=current_platform.device_type,
) )
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment