Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4e2d95e3
Unverified
Commit
4e2d95e3
authored
Oct 28, 2024
by
wangshuai09
Committed by
GitHub
Oct 28, 2024
Browse files
[Hardware][ROCM] using current_platform.is_rocm (#9642)
Signed-off-by:
wangshuai09
<
391746016@qq.com
>
parent
34a99416
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
31 additions
and
34 deletions
+31
-34
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+3
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+3
-3
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+1
-2
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+5
-5
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+3
-3
vllm/model_executor/models/exaone.py
vllm/model_executor/models/exaone.py
+2
-2
vllm/model_executor/models/granite.py
vllm/model_executor/models/granite.py
+2
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+2
-2
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-2
vllm/model_executor/models/solar.py
vllm/model_executor/models/solar.py
+2
-2
vllm/utils.py
vllm/utils.py
+1
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+5
-4
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
4e2d95e3
...
@@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -14,7 +14,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
class
GPTQMarlinState
(
Enum
):
class
GPTQMarlinState
(
Enum
):
...
@@ -150,7 +151,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -150,7 +151,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
# If rocm, normalize the weights and scales to e4m3fnuz
if
is_hip
():
if
current_platform
.
is_rocm
():
# Normalize the weights and scales
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
4e2d95e3
...
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
from
vllm.
util
s
import
is_hip
from
vllm.
platform
s
import
current_platform
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
...
@@ -40,7 +40,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -40,7 +40,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
logical_widths
=
layer
.
logical_widths
,
logical_widths
=
layer
.
logical_widths
,
)
)
if
is_hip
():
if
current_platform
.
is_rocm
():
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight
=
weight
,
weight_scale
=
max_w_scale
,
weight_scale
=
max_w_scale
,
...
@@ -56,7 +56,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -56,7 +56,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
=
layer
.
weight
weight
=
layer
.
weight
if
is_hip
():
if
current_platform
.
is_rocm
():
weight
,
weight_scale
,
input_scale
=
\
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight
=
weight
,
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
4e2d95e3
...
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
ModelWeightParameter
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -127,7 +126,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -127,7 +126,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
weight
=
layer
.
weight
if
is_hip
():
if
current_platform
.
is_rocm
():
weight
,
weight_scale
,
input_scale
=
\
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight
=
weight
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
4e2d95e3
...
@@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
...
@@ -26,7 +26,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm.utils
import
print_warning_once
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
@@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -123,7 +123,7 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
# Disable marlin for rocm
if
is_hip
():
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
self
.
use_marlin
=
False
def
create_weights
(
def
create_weights
(
...
@@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -226,7 +226,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale
=
layer
.
weight_scale
weight_scale
=
layer
.
weight_scale
# If rocm, use float8_e4m3fnuz.
# If rocm, use float8_e4m3fnuz.
if
is_hip
():
if
current_platform
.
is_rocm
():
weight
,
weight_scale
,
input_scale
=
\
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight
=
weight
,
...
@@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -372,7 +372,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# If rocm, use float8_e4m3fnuz as dtype
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype
=
torch
.
float8_e4m3fnuz
\
fp8_dtype
=
torch
.
float8_e4m3fnuz
\
if
is_hip
()
else
torch
.
float8_e4m3fn
if
current_platform
.
is_rocm
()
else
torch
.
float8_e4m3fn
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
fp8_dtype
)
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
fp8_dtype
)
...
@@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -420,7 +420,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
# If rocm, normalize the weights and scales to e4m3fnuz
if
is_hip
():
if
current_platform
.
is_rocm
():
# Normalize the weights and scales
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
normalize_e4m3fn_to_e4m3fnuz
(
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
4e2d95e3
...
@@ -4,16 +4,16 @@ import torch
...
@@ -4,16 +4,16 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
# Input scaling factors are no longer optional in _scaled_mm starting
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
).
cuda
()
if
is_hip
()
else
None
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
).
cuda
()
\
if
current_platform
.
is_rocm
()
else
None
def
cutlass_fp8_supported
()
->
bool
:
def
cutlass_fp8_supported
()
->
bool
:
# cutlass is not supported on Rocm
# cutlass is not supported on Rocm
if
is_hip
():
if
current_platform
.
is_rocm
():
return
False
return
False
capability_tuple
=
current_platform
.
get_device_capability
()
capability_tuple
=
current_platform
.
get_device_capability
()
...
...
vllm/model_executor/models/exaone.py
View file @
4e2d95e3
...
@@ -49,9 +49,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -49,9 +49,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.exaone
import
ExaoneConfig
from
vllm.transformers_utils.configs.exaone
import
ExaoneConfig
from
vllm.utils
import
is_hip
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
...
@@ -595,7 +595,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -595,7 +595,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
not
isinstance
(
self
.
transformer
.
h
[
layer_idx
],
nn
.
Identity
):
if
not
isinstance
(
self
.
transformer
.
h
[
layer_idx
],
nn
.
Identity
):
layer_self_attn
=
self
.
transformer
.
h
[
layer_idx
].
attn
layer_self_attn
=
self
.
transformer
.
h
[
layer_idx
].
attn
if
is_hip
():
if
current_platform
.
is_rocm
():
# The scaling factor convention we are assuming is
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# which is consistent with the practice of setting
...
...
vllm/model_executor/models/granite.py
View file @
4e2d95e3
...
@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_hip
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
...
@@ -534,7 +534,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -534,7 +534,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
not
isinstance
(
self
.
model
.
layers
[
layer_idx
],
nn
.
Identity
):
if
not
isinstance
(
self
.
model
.
layers
[
layer_idx
],
nn
.
Identity
):
layer_self_attn
=
self
.
model
.
layers
[
layer_idx
].
self_attn
layer_self_attn
=
self
.
model
.
layers
[
layer_idx
].
self_attn
if
is_hip
():
if
current_platform
.
is_rocm
():
# The scaling factor convention we are assuming is
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# which is consistent with the practice of setting
...
...
vllm/model_executor/models/llama.py
View file @
4e2d95e3
...
@@ -50,8 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -50,8 +50,8 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.utils
import
is_hip
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
...
@@ -423,7 +423,7 @@ class LlamaModel(nn.Module):
...
@@ -423,7 +423,7 @@ class LlamaModel(nn.Module):
if
not
isinstance
(
self
.
layers
[
layer_idx
],
nn
.
Identity
):
if
not
isinstance
(
self
.
layers
[
layer_idx
],
nn
.
Identity
):
layer_self_attn
=
self
.
layers
[
layer_idx
].
self_attn
layer_self_attn
=
self
.
layers
[
layer_idx
].
self_attn
if
is_hip
():
if
current_platform
.
is_rocm
():
# The scaling factor convention we are assuming is
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# which is consistent with the practice of setting
...
...
vllm/model_executor/models/registry.py
View file @
4e2d95e3
...
@@ -12,7 +12,7 @@ import cloudpickle
...
@@ -12,7 +12,7 @@ import cloudpickle
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.
util
s
import
is_hip
from
vllm.
platform
s
import
current_platform
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
supports_multimodal
,
supports_pp
)
supports_multimodal
,
supports_pp
)
...
@@ -247,7 +247,7 @@ def _try_load_model_cls(
...
@@ -247,7 +247,7 @@ def _try_load_model_cls(
model_arch
:
str
,
model_arch
:
str
,
model
:
_BaseRegisteredModel
,
model
:
_BaseRegisteredModel
,
)
->
Optional
[
Type
[
nn
.
Module
]]:
)
->
Optional
[
Type
[
nn
.
Module
]]:
if
is_hip
():
if
current_platform
.
is_rocm
():
if
model_arch
in
_ROCM_UNSUPPORTED_MODELS
:
if
model_arch
in
_ROCM_UNSUPPORTED_MODELS
:
raise
ValueError
(
f
"Model architecture '
{
model_arch
}
' is not "
raise
ValueError
(
f
"Model architecture '
{
model_arch
}
' is not "
"supported by ROCm for now."
)
"supported by ROCm for now."
)
...
...
vllm/model_executor/models/solar.py
View file @
4e2d95e3
...
@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -49,8 +49,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_hip
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
...
@@ -558,7 +558,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -558,7 +558,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if
not
isinstance
(
self
.
model
.
layers
[
layer_idx
],
nn
.
Identity
):
if
not
isinstance
(
self
.
model
.
layers
[
layer_idx
],
nn
.
Identity
):
layer_self_attn
=
self
.
model
.
layers
[
layer_idx
].
self_attn
layer_self_attn
=
self
.
model
.
layers
[
layer_idx
].
self_attn
if
is_hip
():
if
current_platform
.
is_rocm
():
# The scaling factor convention we are assuming is
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# which is consistent with the practice of setting
...
...
vllm/utils.py
View file @
4e2d95e3
...
@@ -314,10 +314,6 @@ class PyObjectCache:
...
@@ -314,10 +314,6 @@ class PyObjectCache:
self
.
_index
=
0
self
.
_index
=
0
def
is_hip
()
->
bool
:
return
torch
.
version
.
hip
is
not
None
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
"""Returns the maximum shared memory per thread block in bytes."""
"""Returns the maximum shared memory per thread block in bytes."""
...
@@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless(
...
@@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless(
if
not
torch
.
cuda
.
_is_compiled
():
if
not
torch
.
cuda
.
_is_compiled
():
return
0
return
0
if
is_hip
():
if
current_platform
.
is_rocm
():
# ROCm uses amdsmi instead of nvml for stateless device count
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
# This requires a sufficiently modern version of Torch 2.4.0
raw_count
=
torch
.
cuda
.
_device_count_amdsmi
()
if
(
hasattr
(
raw_count
=
torch
.
cuda
.
_device_count_amdsmi
()
if
(
hasattr
(
...
...
vllm/worker/model_runner.py
View file @
4e2d95e3
...
@@ -41,6 +41,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
...
@@ -41,6 +41,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
from
vllm.model_executor.models.utils
import
set_cpu_offload_max_bytes
from
vllm.model_executor.models.utils
import
set_cpu_offload_max_bytes
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
,
MultiModalRegistry
)
MultiModalInputs
,
MultiModalRegistry
)
from
vllm.platforms
import
current_platform
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.worker_manager
import
(
from
vllm.prompt_adapter.worker_manager
import
(
...
@@ -49,7 +50,7 @@ from vllm.sampling_params import SamplingParams
...
@@ -49,7 +50,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.utils
import
(
DeviceMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
from
vllm.utils
import
(
DeviceMemoryProfiler
,
PyObjectCache
,
async_tensor_h2d
,
flatten_2d_lists
,
is_hip
,
is_pin_memory_available
,
flatten_2d_lists
,
is_pin_memory_available
,
supports_dynamo
,
weak_ref_tensor
)
supports_dynamo
,
weak_ref_tensor
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
...
@@ -737,13 +738,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -737,13 +738,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
family of functions.
family of functions.
Args:
Args:
num_seqs (int): Number of sequences scheduled to run.
num_seqs (int): Number of sequences scheduled to run.
max_decode_seq_len (int): Greatest of all the decode sequence
max_decode_seq_len (int): Greatest of all the decode sequence
lengths. Used only in checking the viablility of using
lengths. Used only in checking the viablility of using
CUDA graphs.
CUDA graphs.
max_encoder_seq_len (int, optional): Greatest of all the encode
max_encoder_seq_len (int, optional): Greatest of all the encode
sequence lengths. Defaults to 0. Used only in checking the
sequence lengths. Defaults to 0. Used only in checking the
viability of using CUDA graphs.
viability of using CUDA graphs.
Returns:
Returns:
int: Returns the determined number of padding sequences. If
int: Returns the determined number of padding sequences. If
CUDA graphs is not viable, returns -1.
CUDA graphs is not viable, returns -1.
...
@@ -1103,7 +1104,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1103,7 +1104,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
prompt_adapter_manager
.
create_prompt_adapter_manager
(
self
.
prompt_adapter_manager
.
create_prompt_adapter_manager
(
self
.
model
))
self
.
model
))
if
self
.
kv_cache_dtype
==
"fp8"
and
is_hip
():
if
self
.
kv_cache_dtype
==
"fp8"
and
current_platform
.
is_rocm
():
# Currently only ROCm accepts kv-cache scaling factors
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# via quantization_param_path and this will be deprecated
# in the future.
# in the future.
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment