Unverified commit 23eca9cf, authored by Mengqing Cao and committed by GitHub
Browse files

[model][refactor] remove cuda hard code in models and layers (#13658)

parent 437b76ff
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config) fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -238,7 +239,7 @@ def fused_marlin_moe( ...@@ -238,7 +239,7 @@ def fused_marlin_moe(
max_workspace_size = (max(2 * N, K) // 64) * 16 max_workspace_size = (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size, workspace = torch.zeros(max_workspace_size,
dtype=torch.int, dtype=torch.int,
device="cuda", device=current_platform.device_type,
requires_grad=False) requires_grad=False)
if has_no_zp: if has_no_zp:
......
...@@ -30,6 +30,7 @@ import torch.nn as nn ...@@ -30,6 +30,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
...@@ -650,8 +651,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -650,8 +651,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style, dtype) is_neox_style, dtype)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange( pos_freqs = self.base**(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / torch.arange(0,
self.rotary_dim,
2,
dtype=torch.float,
device=current_platform.device_type) /
self.rotary_dim) self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
...@@ -670,7 +675,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -670,7 +675,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor) inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor, t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda", device=current_platform.device_type,
dtype=torch.float32) dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale) cos = (freqs.cos() * self.mscale)
......
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
import torch.jit import torch.jit
import torch.nn as nn import torch.nn as nn
from vllm.platforms import current_platform
class SpecDecodeBaseSampler(nn.Module): class SpecDecodeBaseSampler(nn.Module):
"""Base class for samplers used for Speculative Decoding verification """Base class for samplers used for Speculative Decoding verification
...@@ -35,7 +37,7 @@ class SpecDecodeBaseSampler(nn.Module): ...@@ -35,7 +37,7 @@ class SpecDecodeBaseSampler(nn.Module):
def init_gpu_tensors(self, device: Union[int, str]) -> None: def init_gpu_tensors(self, device: Union[int, str]) -> None:
assert self.num_accepted_tokens is None assert self.num_accepted_tokens is None
if isinstance(device, int): if isinstance(device, int):
device = f"cuda:{device}" device = f"{current_platform.device_type}:{device}"
elif not isinstance(device, str): elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}") raise ValueError(f"Device must be int or str, get {type(device)}")
self.num_accepted_tokens = torch.tensor(0, self.num_accepted_tokens = torch.tensor(0,
......
...@@ -914,7 +914,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -914,7 +914,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if param_name + "." in k: if param_name + "." in k:
quant_state[k] = temp_state_dict[k] quant_state[k] = temp_state_dict[k]
return QuantState.from_dict(quant_state, device="cuda") return QuantState.from_dict(quant_state,
device=current_platform.device_type)
# Second iterate over all prequant and normal weights # Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state # pre quantized weights would have a quant_state
......
...@@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig from vllm.transformers_utils.configs.arctic import ArcticConfig
...@@ -138,13 +139,13 @@ class ArcticMoE(nn.Module): ...@@ -138,13 +139,13 @@ class ArcticMoE(nn.Module):
torch.empty(self.num_experts, torch.empty(self.num_experts,
2 * self.intermediate_size, 2 * self.intermediate_size,
self.hidden_size, self.hidden_size,
device="cuda", device=current_platform.device_type,
dtype=self.params_dtype)) dtype=self.params_dtype))
self.w2s = nn.Parameter( self.w2s = nn.Parameter(
torch.empty(self.num_experts, torch.empty(self.num_experts,
self.hidden_size, self.hidden_size,
self.intermediate_size, self.intermediate_size,
device="cuda", device=current_platform.device_type,
dtype=self.params_dtype)) dtype=self.params_dtype))
set_weight_attrs(self.ws, { set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
......
...@@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -98,13 +99,13 @@ class MiniCPMMoE(nn.Module): ...@@ -98,13 +99,13 @@ class MiniCPMMoE(nn.Module):
torch.empty(self.num_total_experts, torch.empty(self.num_total_experts,
2 * self.intermediate_size, 2 * self.intermediate_size,
self.hidden_size, self.hidden_size,
device="cuda", device=current_platform.device_type,
dtype=self.params_dtype)) dtype=self.params_dtype))
self.w2s = nn.Parameter( self.w2s = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(self.num_total_experts,
self.hidden_size, self.hidden_size,
self.intermediate_size, self.intermediate_size,
device="cuda", device=current_platform.device_type,
dtype=self.params_dtype)) dtype=self.params_dtype))
set_weight_attrs(self.ws, { set_weight_attrs(self.ws, {
......
...@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize, ...@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
...@@ -1184,7 +1185,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -1184,7 +1185,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix) prefix=prefix)
return resampler.to(device="cuda", dtype=torch.get_default_dtype()) return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding( def get_vision_embedding(
self, self,
...@@ -1266,7 +1268,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1266,7 +1268,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix) prefix=prefix)
return resampler.to(device="cuda", dtype=torch.get_default_dtype()) return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding( def get_vision_embedding(
self, self,
...@@ -1360,7 +1363,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1360,7 +1363,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix) prefix=prefix)
return resampler.to(device="cuda", dtype=torch.get_default_dtype()) return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding( def get_vision_embedding(
self, self,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment