Commit 4e2d95e3 (unverified), authored by wangshuai09, committed by GitHub
Browse files

[Hardware][ROCM] using current_platform.is_rocm (#9642)


Signed-off-by: wangshuai09 <391746016@qq.com>
parent 34a99416
...@@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip, print_warning_once from vllm.platforms import current_platform
from vllm.utils import print_warning_once
class GPTQMarlinState(Enum): class GPTQMarlinState(Enum):
...@@ -150,7 +151,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -150,7 +151,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale.max(), requires_grad=False) layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz # If rocm, normalize the weights and scales to e4m3fnuz
if is_hip(): if current_platform.is_rocm():
# Normalize the weights and scales # Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \ w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
......
...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.utils import is_hip from vllm.platforms import current_platform
__all__ = ["CompressedTensorsW8A8Fp8"] __all__ = ["CompressedTensorsW8A8Fp8"]
...@@ -40,7 +40,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -40,7 +40,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths=layer.logical_widths, logical_widths=layer.logical_widths,
) )
if is_hip(): if current_platform.is_rocm():
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
weight_scale=max_w_scale, weight_scale=max_w_scale,
...@@ -56,7 +56,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -56,7 +56,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.CHANNEL: elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight weight = layer.weight
if is_hip(): if current_platform.is_rocm():
weight, weight_scale, input_scale = \ weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
......
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter) ModelWeightParameter)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -127,7 +126,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -127,7 +126,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
weight = layer.weight weight = layer.weight
if is_hip(): if current_platform.is_rocm():
weight, weight_scale, input_scale = \ weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
......
...@@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter, ...@@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase):
self.use_marlin = (not current_platform.has_device_capability(89) self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN) or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm # Disable marlin for rocm
if is_hip(): if current_platform.is_rocm():
self.use_marlin = False self.use_marlin = False
def create_weights( def create_weights(
...@@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale = layer.weight_scale weight_scale = layer.weight_scale
# If rocm, use float8_e4m3fnuz. # If rocm, use float8_e4m3fnuz.
if is_hip(): if current_platform.is_rocm():
weight, weight_scale, input_scale = \ weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
...@@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype # If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz \ fp8_dtype = torch.float8_e4m3fnuz \
if is_hip() else torch.float8_e4m3fn if current_platform.is_rocm() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype) dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
...@@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False) layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz # If rocm, normalize the weights and scales to e4m3fnuz
if is_hip(): if current_platform.is_rocm():
# Normalize the weights and scales # Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \ w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
......
...@@ -4,16 +4,16 @@ import torch ...@@ -4,16 +4,16 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
if current_platform.is_rocm() else None
def cutlass_fp8_supported() -> bool: def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm # cutlass is not supported on Rocm
if is_hip(): if current_platform.is_rocm():
return False return False
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
......
...@@ -49,9 +49,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,9 +49,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
...@@ -595,7 +595,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -595,7 +595,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.transformer.h[layer_idx], nn.Identity): if not isinstance(self.transformer.h[layer_idx], nn.Identity):
layer_self_attn = self.transformer.h[layer_idx].attn layer_self_attn = self.transformer.h[layer_idx].attn
if is_hip(): if current_platform.is_rocm():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value # quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting # which is consistent with the practice of setting
......
...@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
...@@ -534,7 +534,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -534,7 +534,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.model.layers[layer_idx], nn.Identity): if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip(): if current_platform.is_rocm():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value # quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting # which is consistent with the practice of setting
......
...@@ -50,8 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -50,8 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
...@@ -423,7 +423,7 @@ class LlamaModel(nn.Module): ...@@ -423,7 +423,7 @@ class LlamaModel(nn.Module):
if not isinstance(self.layers[layer_idx], nn.Identity): if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn layer_self_attn = self.layers[layer_idx].self_attn
if is_hip(): if current_platform.is_rocm():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value # quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting # which is consistent with the practice of setting
......
...@@ -12,7 +12,7 @@ import cloudpickle ...@@ -12,7 +12,7 @@ import cloudpickle
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.platforms import current_platform
from .interfaces import (has_inner_state, is_attention_free, from .interfaces import (has_inner_state, is_attention_free,
supports_multimodal, supports_pp) supports_multimodal, supports_pp)
...@@ -247,7 +247,7 @@ def _try_load_model_cls( ...@@ -247,7 +247,7 @@ def _try_load_model_cls(
model_arch: str, model_arch: str,
model: _BaseRegisteredModel, model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]: ) -> Optional[Type[nn.Module]]:
if is_hip(): if current_platform.is_rocm():
if model_arch in _ROCM_UNSUPPORTED_MODELS: if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(f"Model architecture '{model_arch}' is not " raise ValueError(f"Model architecture '{model_arch}' is not "
"supported by ROCm for now.") "supported by ROCm for now.")
......
...@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
...@@ -558,7 +558,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -558,7 +558,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.model.layers[layer_idx], nn.Identity): if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip(): if current_platform.is_rocm():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value # quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting # which is consistent with the practice of setting
......
...@@ -314,10 +314,6 @@ class PyObjectCache: ...@@ -314,10 +314,6 @@ class PyObjectCache:
self._index = 0 self._index = 0
def is_hip() -> bool:
return torch.version.hip is not None
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int: def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes.""" """Returns the maximum shared memory per thread block in bytes."""
...@@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless( ...@@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless(
if not torch.cuda._is_compiled(): if not torch.cuda._is_compiled():
return 0 return 0
if is_hip(): if current_platform.is_rocm():
# ROCm uses amdsmi instead of nvml for stateless device count # ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0 # This requires a sufficiently modern version of Torch 2.4.0
raw_count = torch.cuda._device_count_amdsmi() if (hasattr( raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
......
...@@ -41,6 +41,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal ...@@ -41,6 +41,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry) MultiModalInputs, MultiModalRegistry)
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import ( from vllm.prompt_adapter.worker_manager import (
...@@ -49,7 +50,7 @@ from vllm.sampling_params import SamplingParams ...@@ -49,7 +50,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available, flatten_2d_lists, is_pin_memory_available,
supports_dynamo, weak_ref_tensor) supports_dynamo, weak_ref_tensor)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
...@@ -737,13 +738,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -737,13 +738,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
family of functions. family of functions.
Args: Args:
num_seqs (int): Number of sequences scheduled to run. num_seqs (int): Number of sequences scheduled to run.
max_decode_seq_len (int): Greatest of all the decode sequence max_decode_seq_len (int): Greatest of all the decode sequence
lengths. Used only in checking the viablility of using lengths. Used only in checking the viablility of using
CUDA graphs. CUDA graphs.
max_encoder_seq_len (int, optional): Greatest of all the encode max_encoder_seq_len (int, optional): Greatest of all the encode
sequence lengths. Defaults to 0. Used only in checking the sequence lengths. Defaults to 0. Used only in checking the
viability of using CUDA graphs. viability of using CUDA graphs.
Returns: Returns:
int: Returns the determined number of padding sequences. If int: Returns the determined number of padding sequences. If
CUDA graphs is not viable, returns -1. CUDA graphs is not viable, returns -1.
...@@ -1103,7 +1104,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1103,7 +1104,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.prompt_adapter_manager.create_prompt_adapter_manager( self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model)) self.model))
if self.kv_cache_dtype == "fp8" and is_hip(): if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
# Currently only ROCm accepts kv-cache scaling factors # Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated # via quantization_param_path and this will be deprecated
# in the future. # in the future.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment