Unverified commit 4e2d95e3, authored by wangshuai09 and committed by GitHub
Browse files

[Hardware][ROCM] using current_platform.is_rocm (#9642)


Signed-off-by: wangshuai09 <391746016@qq.com>
parent 34a99416
...@@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip, print_warning_once from vllm.platforms import current_platform
from vllm.utils import print_warning_once
class GPTQMarlinState(Enum): class GPTQMarlinState(Enum):
...@@ -150,7 +151,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -150,7 +151,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale.max(), requires_grad=False) layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz # If rocm, normalize the weights and scales to e4m3fnuz
if is_hip(): if current_platform.is_rocm():
# Normalize the weights and scales # Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \ w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
......
...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.utils import is_hip from vllm.platforms import current_platform
__all__ = ["CompressedTensorsW8A8Fp8"] __all__ = ["CompressedTensorsW8A8Fp8"]
...@@ -40,7 +40,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -40,7 +40,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths=layer.logical_widths, logical_widths=layer.logical_widths,
) )
if is_hip(): if current_platform.is_rocm():
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
weight_scale=max_w_scale, weight_scale=max_w_scale,
...@@ -56,7 +56,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -56,7 +56,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.CHANNEL: elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight weight = layer.weight
if is_hip(): if current_platform.is_rocm():
weight, weight_scale, input_scale = \ weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
......
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter) ModelWeightParameter)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -127,7 +126,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -127,7 +126,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
weight = layer.weight weight = layer.weight
if is_hip(): if current_platform.is_rocm():
weight, weight_scale, input_scale = \ weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
......
...@@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter, ...@@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase):
self.use_marlin = (not current_platform.has_device_capability(89) self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN) or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm # Disable marlin for rocm
if is_hip(): if current_platform.is_rocm():
self.use_marlin = False self.use_marlin = False
def create_weights( def create_weights(
...@@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale = layer.weight_scale weight_scale = layer.weight_scale
# If rocm, use float8_e4m3fnuz. # If rocm, use float8_e4m3fnuz.
if is_hip(): if current_platform.is_rocm():
weight, weight_scale, input_scale = \ weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight=weight,
...@@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype # If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz \ fp8_dtype = torch.float8_e4m3fnuz \
if is_hip() else torch.float8_e4m3fn if current_platform.is_rocm() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype) dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
...@@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False) layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz # If rocm, normalize the weights and scales to e4m3fnuz
if is_hip(): if current_platform.is_rocm():
# Normalize the weights and scales # Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \ w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz( normalize_e4m3fn_to_e4m3fnuz(
......
...@@ -4,16 +4,16 @@ import torch ...@@ -4,16 +4,16 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
if current_platform.is_rocm() else None
def cutlass_fp8_supported() -> bool: def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm # cutlass is not supported on Rocm
if is_hip(): if current_platform.is_rocm():
return False return False
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
......
...@@ -49,9 +49,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,9 +49,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
...@@ -595,7 +595,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -595,7 +595,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.transformer.h[layer_idx], nn.Identity): if not isinstance(self.transformer.h[layer_idx], nn.Identity):
layer_self_attn = self.transformer.h[layer_idx].attn layer_self_attn = self.transformer.h[layer_idx].attn
if is_hip(): if current_platform.is_rocm():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value # quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting # which is consistent with the practice of setting
......
...@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
...@@ -534,7 +534,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -534,7 +534,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.model.layers[layer_idx], nn.Identity): if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip(): if current_platform.is_rocm():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value # quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting # which is consistent with the practice of setting
......
...@@ -50,8 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -50,8 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
...@@ -423,7 +423,7 @@ class LlamaModel(nn.Module): ...@@ -423,7 +423,7 @@ class LlamaModel(nn.Module):
if not isinstance(self.layers[layer_idx], nn.Identity): if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn layer_self_attn = self.layers[layer_idx].self_attn
if is_hip(): if current_platform.is_rocm():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value # quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting # which is consistent with the practice of setting
......
...@@ -12,7 +12,7 @@ import cloudpickle ...@@ -12,7 +12,7 @@ import cloudpickle
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.platforms import current_platform
from .interfaces import (has_inner_state, is_attention_free, from .interfaces import (has_inner_state, is_attention_free,
supports_multimodal, supports_pp) supports_multimodal, supports_pp)
...@@ -247,7 +247,7 @@ def _try_load_model_cls( ...@@ -247,7 +247,7 @@ def _try_load_model_cls(
model_arch: str, model_arch: str,
model: _BaseRegisteredModel, model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]: ) -> Optional[Type[nn.Module]]:
if is_hip(): if current_platform.is_rocm():
if model_arch in _ROCM_UNSUPPORTED_MODELS: if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(f"Model architecture '{model_arch}' is not " raise ValueError(f"Model architecture '{model_arch}' is not "
"supported by ROCm for now.") "supported by ROCm for now.")
......
...@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
...@@ -558,7 +558,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -558,7 +558,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if not isinstance(self.model.layers[layer_idx], nn.Identity): if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip(): if current_platform.is_rocm():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value # quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting # which is consistent with the practice of setting
......
...@@ -314,10 +314,6 @@ class PyObjectCache: ...@@ -314,10 +314,6 @@ class PyObjectCache:
self._index = 0 self._index = 0
def is_hip() -> bool:
return torch.version.hip is not None
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int: def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes.""" """Returns the maximum shared memory per thread block in bytes."""
...@@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless( ...@@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless(
if not torch.cuda._is_compiled(): if not torch.cuda._is_compiled():
return 0 return 0
if is_hip(): if current_platform.is_rocm():
# ROCm uses amdsmi instead of nvml for stateless device count # ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0 # This requires a sufficiently modern version of Torch 2.4.0
raw_count = torch.cuda._device_count_amdsmi() if (hasattr( raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
......
...@@ -41,6 +41,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal ...@@ -41,6 +41,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry) MultiModalInputs, MultiModalRegistry)
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import ( from vllm.prompt_adapter.worker_manager import (
...@@ -49,7 +50,7 @@ from vllm.sampling_params import SamplingParams ...@@ -49,7 +50,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available, flatten_2d_lists, is_pin_memory_available,
supports_dynamo, weak_ref_tensor) supports_dynamo, weak_ref_tensor)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
...@@ -1103,7 +1104,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1103,7 +1104,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.prompt_adapter_manager.create_prompt_adapter_manager( self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model)) self.model))
if self.kv_cache_dtype == "fp8" and is_hip(): if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
# Currently only ROCm accepts kv-cache scaling factors # Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated # via quantization_param_path and this will be deprecated
# in the future. # in the future.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment