Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1aa162e0
Unverified
Commit
1aa162e0
authored
Mar 26, 2025
by
cyyever
Committed by
GitHub
Mar 26, 2025
Browse files
Apply torchfix (#15532)
Signed-off-by:
cyy
<
cyyever@outlook.com
>
parent
cf5c8f16
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
15 additions
and
11 deletions
+15
-11
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+2
-3
vllm/lora/models.py
vllm/lora/models.py
+3
-1
vllm/model_executor/models/nemotron.py
vllm/model_executor/models/nemotron.py
+3
-3
vllm/model_executor/models/phi4mm_utils.py
vllm/model_executor/models/phi4mm_utils.py
+6
-3
vllm/multimodal/image.py
vllm/multimodal/image.py
+1
-1
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
1aa162e0
...
@@ -884,9 +884,8 @@ def _sdpa_attention(
...
@@ -884,9 +884,8 @@ def _sdpa_attention(
for
i
,
seq_len
in
enumerate
(
seq_lens
):
for
i
,
seq_len
in
enumerate
(
seq_lens
):
end
=
start
+
seq_len
end
=
start
+
seq_len
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_math
=
True
,
with
torch
.
nn
.
attention
.
sdpa_kernel
(
enable_flash
=
False
,
torch
.
nn
.
attention
.
SDPBackend
.
MATH
):
enable_mem_efficient
=
False
):
sub_out
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
sub_out
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
query
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
...
...
vllm/lora/models.py
View file @
1aa162e0
...
@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel):
...
@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel):
f
" target modules in
{
expected_lora_modules
}
"
f
" target modules in
{
expected_lora_modules
}
"
f
" but received
{
unexpected_modules
}
."
f
" but received
{
unexpected_modules
}
."
f
" Please verify that the loaded LoRA module is correct"
)
f
" Please verify that the loaded LoRA module is correct"
)
tensors
=
torch
.
load
(
lora_bin_file_path
,
map_location
=
device
)
tensors
=
torch
.
load
(
lora_bin_file_path
,
map_location
=
device
,
weights_only
=
True
)
else
:
else
:
raise
ValueError
(
f
"
{
lora_dir
}
doesn't contain tensors"
)
raise
ValueError
(
f
"
{
lora_dir
}
doesn't contain tensors"
)
...
...
vllm/model_executor/models/nemotron.py
View file @
1aa162e0
...
@@ -63,8 +63,8 @@ def _cast_if_autocast_enabled(*args):
...
@@ -63,8 +63,8 @@ def _cast_if_autocast_enabled(*args):
if
not
torch
.
is_autocast_enabled
():
if
not
torch
.
is_autocast_enabled
():
return
args
return
args
else
:
else
:
return
torch
.
cuda
.
amp
.
autocast_mode
.
_cast
(
return
torch
.
amp
.
autocast_mode
.
_cast
(
args
,
torch
.
get_autocast_gpu_dtype
())
args
,
device_type
=
"cuda"
,
dtype
=
torch
.
get_autocast_gpu_dtype
())
class
NemotronLayerNorm1P
(
nn
.
LayerNorm
):
class
NemotronLayerNorm1P
(
nn
.
LayerNorm
):
...
@@ -89,7 +89,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
...
@@ -89,7 +89,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
residual
=
x
residual
=
x
args
=
_cast_if_autocast_enabled
(
x
,
self
.
normalized_shape
,
args
=
_cast_if_autocast_enabled
(
x
,
self
.
normalized_shape
,
self
.
weight
+
1
,
self
.
bias
,
self
.
eps
)
self
.
weight
+
1
,
self
.
bias
,
self
.
eps
)
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
with
torch
.
amp
.
autocast
(
"cuda"
,
enabled
=
False
):
x
=
torch
.
nn
.
functional
.
layer_norm
(
*
args
)
x
=
torch
.
nn
.
functional
.
layer_norm
(
*
args
)
return
x
if
residual
is
None
else
(
x
,
residual
)
return
x
if
residual
is
None
else
(
x
,
residual
)
...
...
vllm/model_executor/models/phi4mm_utils.py
View file @
1aa162e0
...
@@ -1766,9 +1766,12 @@ class MultiHeadedAttention(nn.Module):
...
@@ -1766,9 +1766,12 @@ class MultiHeadedAttention(nn.Module):
if
mask
.
dtype
!=
q
.
dtype
:
if
mask
.
dtype
!=
q
.
dtype
:
attn_mask
=
attn_mask
.
to
(
q
.
dtype
)
attn_mask
=
attn_mask
.
to
(
q
.
dtype
)
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_flash
=
True
,
with
torch
.
nn
.
attention
.
sdpa_kernel
([
enable_math
=
True
,
torch
.
nn
.
attention
.
SDPBackend
.
FLASH_ATTENTION
,
enable_mem_efficient
=
True
):
torch
.
nn
.
attention
.
SDPBackend
.
EFFICIENT_ATTENTION
,
torch
.
nn
.
attention
.
SDPBackend
.
MATH
,
torch
.
nn
.
attention
.
SDPBackend
.
CUDNN_ATTENTION
,
]):
x
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
x
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q
,
q
,
k
,
k
,
...
...
vllm/multimodal/image.py
View file @
1aa162e0
...
@@ -149,7 +149,7 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
...
@@ -149,7 +149,7 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
return
self
.
load_bytes
(
base64
.
b64decode
(
data
))
return
self
.
load_bytes
(
base64
.
b64decode
(
data
))
def
load_file
(
self
,
filepath
:
Path
)
->
torch
.
Tensor
:
def
load_file
(
self
,
filepath
:
Path
)
->
torch
.
Tensor
:
return
torch
.
load
(
filepath
)
return
torch
.
load
(
filepath
,
weights_only
=
True
)
def
encode_base64
(
self
,
media
:
torch
.
Tensor
)
->
str
:
def
encode_base64
(
self
,
media
:
torch
.
Tensor
)
->
str
:
return
base64
.
b64encode
(
media
.
numpy
()).
decode
(
'utf-8'
)
return
base64
.
b64encode
(
media
.
numpy
()).
decode
(
'utf-8'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment