Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2eb2b5a
Unverified
Commit
b2eb2b5a
authored
Jul 18, 2025
by
Richard Zou
Committed by
GitHub
Jul 18, 2025
Browse files
[Kernel] Apply torch.Tag.needs_fixed_stride_order only for torch==2.6.0 (#19346)
Signed-off-by:
rzou
<
zou3519@gmail.com
>
parent
21274ab4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
9 deletions
+19
-9
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+8
-4
vllm/attention/ops/rocm_aiter_mla.py
vllm/attention/ops/rocm_aiter_mla.py
+6
-2
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+5
-3
No files found.
csrc/torch_bindings.cpp
View file @
b2eb2b5a
...
...
@@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
//
// The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
// The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
// so we need
// to override this for many GEMMs with the following tag. Otherwise,
// torch.compile will force all input tensors to be contiguous(), which
// will break many custom ops that require column-major weight matrices.
// TODO: remove this for PyTorch 2.8, when the default is planned to switch
// to match exact eager-mode strides.
at
::
Tag
stride_tag
=
at
::
Tag
::
needs_fixed_stride_order
;
// This was a bug and PyTorch 2.7 has since fixed this.
#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
#define stride_tag at::Tag::needs_fixed_stride_order
#else
#define stride_tag
#endif
ops
.
def
(
"weak_ref_tensor(Tensor input) -> Tensor"
);
ops
.
impl
(
"weak_ref_tensor"
,
torch
::
kCUDA
,
&
weak_ref_tensor
);
...
...
vllm/attention/ops/rocm_aiter_mla.py
View file @
b2eb2b5a
...
...
@@ -6,7 +6,7 @@ from typing import Optional
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
def
get_aiter_mla_metadata
(
max_batch_size
:
int
,
block_size
:
int
,
...
...
@@ -93,8 +93,12 @@ def mla_decode_fwd_fake(
if
current_platform
.
is_rocm
():
if
is_torch_equal_or_newer
(
"2.7.0"
):
tags
=
()
else
:
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
direct_register_custom_op
(
op_name
=
"rocm_aiter_mla_decode_fwd"
,
op_func
=
mla_decode_fwd_impl
,
mutates_args
=
[
"o"
],
fake_impl
=
mla_decode_fwd_fake
,
tags
=
[
torch
.
Tag
.
needs_fixed_stride_order
]
)
tags
=
tags
)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
b2eb2b5a
...
...
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
from
vllm.utils.deep_gemm
import
is_blackwell_deep_gemm_used
from
.rocm_aiter_fused_moe
import
is_rocm_aiter_moe_enabled
...
...
@@ -1056,7 +1056,8 @@ direct_register_custom_op(
op_func
=
inplace_fused_experts
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
inplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
tags
=
(()
if
is_torch_equal_or_newer
(
"2.7.0"
)
else
(
torch
.
Tag
.
needs_fixed_stride_order
,
)),
)
...
...
@@ -1122,7 +1123,8 @@ direct_register_custom_op(
op_func
=
outplace_fused_experts
,
mutates_args
=
[],
fake_impl
=
outplace_fused_experts_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
tags
=
(()
if
is_torch_equal_or_newer
(
"2.7.0"
)
else
(
torch
.
Tag
.
needs_fixed_stride_order
,
)),
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment