Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8f4b313c
Unverified
Commit
8f4b313c
authored
Oct 15, 2025
by
wangxiyuan
Committed by
GitHub
Oct 15, 2025
Browse files
[Misc] rename torch_dtype to dtype (#26695)
Signed-off-by:
wangxiyuan
<
wangxiyuan1007@gmail.com
>
parent
f93e3480
Changes
30
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
22 additions
and
24 deletions
+22
-24
vllm/model_executor/models/longcat_flash.py
vllm/model_executor/models/longcat_flash.py
+2
-2
vllm/model_executor/models/nano_nemotron_vl.py
vllm/model_executor/models/nano_nemotron_vl.py
+2
-2
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+3
-3
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+1
-1
vllm/model_executor/models/transformers_pooling.py
vllm/model_executor/models/transformers_pooling.py
+1
-1
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+2
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+1
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+2
-2
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+2
-2
vllm/utils/__init__.py
vllm/utils/__init__.py
+6
-8
No files found.
vllm/model_executor/models/longcat_flash.py
View file @
8f4b313c
...
@@ -114,7 +114,7 @@ class FlashConfig(PretrainedConfig):
...
@@ -114,7 +114,7 @@ class FlashConfig(PretrainedConfig):
attention_dropout
=
0.0
,
attention_dropout
=
0.0
,
mla_scale_q_lora
=
False
,
mla_scale_q_lora
=
False
,
mla_scale_kv_lora
=
False
,
mla_scale_kv_lora
=
False
,
torch_
dtype
=
"bfloat16"
,
dtype
=
"bfloat16"
,
params_dtype
=
"bfloat16"
,
params_dtype
=
"bfloat16"
,
router_dtype
=
"float32"
,
router_dtype
=
"float32"
,
router_bias
=
False
,
router_bias
=
False
,
...
@@ -130,7 +130,7 @@ class FlashConfig(PretrainedConfig):
...
@@ -130,7 +130,7 @@ class FlashConfig(PretrainedConfig):
bos_token_id
=
bos_token_id
,
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
eos_token_id
=
eos_token_id
,
tie_word_embeddings
=
tie_word_embeddings
,
tie_word_embeddings
=
tie_word_embeddings
,
torch_dtype
=
torch_
dtype
,
dtype
=
dtype
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
router_dtype
=
router_dtype
,
router_dtype
=
router_dtype
,
topk_method
=
topk_method
,
topk_method
=
topk_method
,
...
...
vllm/model_executor/models/nano_nemotron_vl.py
View file @
8f4b313c
...
@@ -987,7 +987,7 @@ class NemotronH_Nano_VL_V2(
...
@@ -987,7 +987,7 @@ class NemotronH_Nano_VL_V2(
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
)
self
.
vision_model
=
self
.
get_vit_model_from_radio_config
(
config
).
to
(
self
.
vision_model
=
self
.
get_vit_model_from_radio_config
(
config
).
to
(
self
.
language_model
.
config
.
torch_
dtype
self
.
language_model
.
config
.
dtype
)
)
# Construct the vision projection.
# Construct the vision projection.
...
@@ -1008,7 +1008,7 @@ class NemotronH_Nano_VL_V2(
...
@@ -1008,7 +1008,7 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation
(),
ReLUSquaredActivation
(),
nn
.
Linear
(
vision_projection_hidden_size
,
llm_hidden_size
,
bias
=
False
),
nn
.
Linear
(
vision_projection_hidden_size
,
llm_hidden_size
,
bias
=
False
),
)
)
self
.
mlp1
=
self
.
mlp1
.
to
(
self
.
language_model
.
config
.
torch_
dtype
)
self
.
mlp1
=
self
.
mlp1
.
to
(
self
.
language_model
.
config
.
dtype
)
self
.
config
=
config
self
.
config
=
config
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
...
...
vllm/model_executor/models/qwen3_next.py
View file @
8f4b313c
...
@@ -338,7 +338,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -338,7 +338,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
group_size
=
None
,
group_size
=
None
,
norm_before_gate
=
True
,
norm_before_gate
=
True
,
device
=
current_platform
.
current_device
(),
device
=
current_platform
.
current_device
(),
dtype
=
config
.
torch_
dtype
,
dtype
=
config
.
dtype
,
)
)
self
.
out_proj
=
RowParallelLinear
(
self
.
out_proj
=
RowParallelLinear
(
...
@@ -847,7 +847,7 @@ class Qwen3NextDecoderLayer(nn.Module):
...
@@ -847,7 +847,7 @@ class Qwen3NextDecoderLayer(nn.Module):
1
,
1
,
1
,
1
,
config
.
hidden_size
,
config
.
hidden_size
,
dtype
=
config
.
torch_
dtype
,
dtype
=
config
.
dtype
,
),
),
)
)
self
.
ffn_layer_scale
=
torch
.
nn
.
Parameter
(
self
.
ffn_layer_scale
=
torch
.
nn
.
Parameter
(
...
@@ -855,7 +855,7 @@ class Qwen3NextDecoderLayer(nn.Module):
...
@@ -855,7 +855,7 @@ class Qwen3NextDecoderLayer(nn.Module):
1
,
1
,
1
,
1
,
config
.
hidden_size
,
config
.
hidden_size
,
dtype
=
config
.
torch_
dtype
,
dtype
=
config
.
dtype
,
),
),
)
)
...
...
vllm/model_executor/models/transformers.py
View file @
8f4b313c
...
@@ -530,7 +530,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
...
@@ -530,7 +530,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
with
init_on_device_without_buffers
(
"meta"
):
with
init_on_device_without_buffers
(
"meta"
):
self
.
model
:
PreTrainedModel
=
AutoModel
.
from_config
(
self
.
model
:
PreTrainedModel
=
AutoModel
.
from_config
(
self
.
config
,
self
.
config
,
torch_
dtype
=
self
.
model_config
.
dtype
,
dtype
=
self
.
model_config
.
dtype
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
)
...
...
vllm/model_executor/models/transformers_pooling.py
View file @
8f4b313c
...
@@ -157,7 +157,7 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
...
@@ -157,7 +157,7 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
with
torch
.
device
(
"meta"
):
with
torch
.
device
(
"meta"
):
seq_cls_model
=
AutoModelForSequenceClassification
.
from_config
(
seq_cls_model
=
AutoModelForSequenceClassification
.
from_config
(
self
.
config
,
self
.
config
,
torch_
dtype
=
self
.
model_config
.
dtype
,
dtype
=
self
.
model_config
.
dtype
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
)
...
...
vllm/platforms/cuda.py
View file @
8f4b313c
...
@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
...
@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
return
supported
return
supported
@
classmethod
@
classmethod
def
check_if_supports_dtype
(
cls
,
torch_
dtype
:
torch
.
dtype
):
def
check_if_supports_dtype
(
cls
,
dtype
:
torch
.
dtype
):
if
torch_
dtype
==
torch
.
bfloat16
:
# noqa: SIM102
if
dtype
==
torch
.
bfloat16
:
# noqa: SIM102
if
not
cls
.
has_device_capability
(
80
):
if
not
cls
.
has_device_capability
(
80
):
capability
=
cls
.
get_device_capability
()
capability
=
cls
.
get_device_capability
()
gpu_name
=
cls
.
get_device_name
()
gpu_name
=
cls
.
get_device_name
()
...
...
vllm/platforms/interface.py
View file @
8f4b313c
...
@@ -563,7 +563,7 @@ class Platform:
...
@@ -563,7 +563,7 @@ class Platform:
return
False
return
False
@
classmethod
@
classmethod
def
check_if_supports_dtype
(
cls
,
torch_
dtype
:
torch
.
dtype
):
def
check_if_supports_dtype
(
cls
,
dtype
:
torch
.
dtype
):
"""
"""
Check if the dtype is supported by the current platform.
Check if the dtype is supported by the current platform.
"""
"""
...
...
vllm/platforms/rocm.py
View file @
8f4b313c
...
@@ -484,8 +484,8 @@ class RocmPlatform(Platform):
...
@@ -484,8 +484,8 @@ class RocmPlatform(Platform):
return
True
return
True
@
classmethod
@
classmethod
def
check_if_supports_dtype
(
cls
,
torch_
dtype
:
torch
.
dtype
):
def
check_if_supports_dtype
(
cls
,
dtype
:
torch
.
dtype
):
if
torch_
dtype
==
torch
.
bfloat16
:
# noqa: SIM102
if
dtype
==
torch
.
bfloat16
:
# noqa: SIM102
if
not
cls
.
has_device_capability
(
80
):
if
not
cls
.
has_device_capability
(
80
):
capability
=
cls
.
get_device_capability
()
capability
=
cls
.
get_device_capability
()
gpu_name
=
cls
.
get_device_name
()
gpu_name
=
cls
.
get_device_name
()
...
...
vllm/platforms/xpu.py
View file @
8f4b313c
...
@@ -236,8 +236,8 @@ class XPUPlatform(Platform):
...
@@ -236,8 +236,8 @@ class XPUPlatform(Platform):
return
torch
.
xpu
.
device_count
()
return
torch
.
xpu
.
device_count
()
@
classmethod
@
classmethod
def
check_if_supports_dtype
(
cls
,
torch_
dtype
:
torch
.
dtype
):
def
check_if_supports_dtype
(
cls
,
dtype
:
torch
.
dtype
):
if
torch_
dtype
==
torch
.
bfloat16
:
# noqa: SIM102
if
dtype
==
torch
.
bfloat16
:
# noqa: SIM102
device_name
=
cls
.
get_device_name
().
lower
()
device_name
=
cls
.
get_device_name
().
lower
()
# client gpu a770
# client gpu a770
if
device_name
.
count
(
"a770"
)
>
0
:
if
device_name
.
count
(
"a770"
)
>
0
:
...
...
vllm/utils/__init__.py
View file @
8f4b313c
...
@@ -806,7 +806,7 @@ def create_kv_caches_with_random_flash(
...
@@ -806,7 +806,7 @@ def create_kv_caches_with_random_flash(
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch_
dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
generic_kv_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
generic_kv_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
assert
cache_layout
in
(
"NHD"
,
"HND"
)
assert
cache_layout
in
(
"NHD"
,
"HND"
)
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
...
@@ -819,7 +819,7 @@ def create_kv_caches_with_random_flash(
...
@@ -819,7 +819,7 @@ def create_kv_caches_with_random_flash(
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
key_value_cache
=
torch
.
empty
(
key_value_cache
=
torch
.
empty
(
size
=
kv_cache_allocation_shape
,
dtype
=
torch_
dtype
,
device
=
device
size
=
kv_cache_allocation_shape
,
dtype
=
dtype
,
device
=
device
).
permute
(
*
stride_order
)
).
permute
(
*
stride_order
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
key_value_cache
.
uniform_
(
-
scale
,
scale
)
key_value_cache
.
uniform_
(
-
scale
,
scale
)
...
@@ -851,14 +851,14 @@ def create_kv_caches_with_random(
...
@@ -851,14 +851,14 @@ def create_kv_caches_with_random(
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch_
dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
x
=
16
//
torch
.
tensor
([],
dtype
=
torch_
dtype
).
element_size
()
x
=
16
//
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
key_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
//
x
,
block_size
,
x
)
key_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
//
x
,
block_size
,
x
)
key_caches
:
list
[
torch
.
Tensor
]
=
[]
key_caches
:
list
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
torch_
dtype
,
device
=
device
)
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
dtype
,
device
=
device
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
key_cache
.
uniform_
(
-
scale
,
scale
)
key_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
"fp8"
:
elif
cache_dtype
==
"fp8"
:
...
@@ -870,9 +870,7 @@ def create_kv_caches_with_random(
...
@@ -870,9 +870,7 @@ def create_kv_caches_with_random(
value_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
value_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
value_caches
:
list
[
torch
.
Tensor
]
=
[]
value_caches
:
list
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
value_cache
=
torch
.
empty
(
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dtype
=
dtype
,
device
=
device
)
size
=
value_cache_shape
,
dtype
=
torch_dtype
,
device
=
device
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
value_cache
.
uniform_
(
-
scale
,
scale
)
value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
"fp8"
:
elif
cache_dtype
==
"fp8"
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment