Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f97ca671
"tests/vscode:/vscode.git/clone" did not exist on "af295e9b010ff2f7886cde2e5a41a4ef84d82ac1"
Unverified
Commit
f97ca671
authored
Feb 08, 2026
by
Andrey Talman
Committed by
GitHub
Feb 08, 2026
Browse files
[Release 2.10] Update to Torch 2.10 - final release (#30525)
parent
084aa19f
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
130 additions
and
78 deletions
+130
-78
.buildkite/image_build/image_build.yaml
.buildkite/image_build/image_build.yaml
+2
-1
CMakeLists.txt
CMakeLists.txt
+5
-5
cmake/external_projects/triton_kernels.cmake
cmake/external_projects/triton_kernels.cmake
+4
-4
pyproject.toml
pyproject.toml
+1
-1
requirements/build.txt
requirements/build.txt
+1
-1
requirements/cuda.txt
requirements/cuda.txt
+3
-3
requirements/rocm-build.txt
requirements/rocm-build.txt
+5
-6
requirements/test.in
requirements/test.in
+4
-4
requirements/test.txt
requirements/test.txt
+9
-5
tests/compile/test_aot_compile.py
tests/compile/test_aot_compile.py
+10
-30
tests/compile/test_dynamic_shapes_compilation.py
tests/compile/test_dynamic_shapes_compilation.py
+1
-3
tests/kernels/moe/test_shared_fused_moe_routed_transform.py
tests/kernels/moe/test_shared_fused_moe_routed_transform.py
+5
-0
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+1
-1
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+4
-4
vllm/envs.py
vllm/envs.py
+1
-1
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+1
-1
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+73
-8
No files found.
.buildkite/image_build/image_build.yaml
View file @
f97ca671
...
...
@@ -3,6 +3,7 @@ steps:
-
label
:
"
:docker:
Build
image"
key
:
image-build
depends_on
:
[]
timeout_in_minutes
:
600
commands
:
-
if [[ "$BUILDKITE_BRANCH" != "main" ]]; then .buildkite/image_build/image_build.sh $REGISTRY $REPO $BUILDKITE_COMMIT $BRANCH $VLLM_USE_PRECOMPILED $VLLM_MERGE_BASE_COMMIT $IMAGE_TAG; fi
-
if [[ "$BUILDKITE_BRANCH" == "main" ]]; then .buildkite/image_build/image_build.sh $REGISTRY $REPO $BUILDKITE_COMMIT $BRANCH $VLLM_USE_PRECOMPILED $VLLM_MERGE_BASE_COMMIT $IMAGE_TAG $IMAGE_TAG_LATEST; fi
...
...
CMakeLists.txt
View file @
f97ca671
...
...
@@ -56,8 +56,8 @@ endif()
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from docker/Dockerfile.rocm
#
set
(
TORCH_SUPPORTED_VERSION_CUDA
"2.
9.1
"
)
set
(
TORCH_SUPPORTED_VERSION_ROCM
"2.
9.1
"
)
set
(
TORCH_SUPPORTED_VERSION_CUDA
"2.
10.0
"
)
set
(
TORCH_SUPPORTED_VERSION_ROCM
"2.
10.0
"
)
#
# Try to find python package with an executable that exactly matches
...
...
cmake/external_projects/triton_kernels.cmake
View file @
f97ca671
# Install OpenAI triton_kernels from https://github.com/triton-lang/triton/tree/main/python/triton_kernels
set
(
DEFAULT_TRITON_KERNELS_TAG
"v3.
5
.0"
)
set
(
DEFAULT_TRITON_KERNELS_TAG
"v3.
6
.0"
)
# Set TRITON_KERNELS_SRC_DIR for use with local development with vLLM. We expect TRITON_KERNELS_SRC_DIR to
# be directly set to the triton_kernels python directory.
...
...
pyproject.toml
View file @
f97ca671
...
...
@@ -6,7 +6,7 @@ requires = [
"packaging>=24.2"
,
"setuptools>=77.0.3,<81.0.0"
,
"setuptools-scm>=8.0"
,
"torch == 2.
9.1
"
,
"torch == 2.
10.0
"
,
"wheel"
,
"jinja2"
,
"grpcio-tools==1.78.0"
,
...
...
requirements/build.txt
View file @
f97ca671
...
...
@@ -4,7 +4,7 @@ ninja
packaging>=24.2
setuptools>=77.0.3,<81.0.0
setuptools-scm>=8
torch==2.
9.1
torch==2.
10.0
wheel
jinja2>=3.1.6
regex
...
...
requirements/cuda.txt
View file @
f97ca671
...
...
@@ -5,9 +5,9 @@ numba == 0.61.2 # Required for N-gram speculative decoding
# Dependencies for NVIDIA GPUs
ray[cgraph]>=2.48.0
torch==2.
9.1
torchaudio==2.
9.1
torch==2.
10.0
torchaudio==2.
10.0
# These must be updated alongside torch
torchvision==0.2
4.1
# Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
torchvision==0.2
5.0
# Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.3
requirements/rocm-build.txt
View file @
f97ca671
# Common dependencies
-r common.txt
--extra-index-url https://download.pytorch.org/whl/rocm6.4
torch==2.9.1
torchvision==0.24.1
torchaudio==2.9.1
triton==3.5.1
--extra-index-url https://download.pytorch.org/whl/test/rocm7.0
torch==2.10.0
torchvision==0.25.0
torchaudio==2.10.0
triton==3.6.0
cmake>=3.26.1,<4
packaging>=24.2
setuptools>=77.0.3,<80.0.0
...
...
requirements/test.in
View file @
f97ca671
...
...
@@ -24,10 +24,10 @@ sentence-transformers>=5.2.0 # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests
tblib # for pickling test exceptions
timm
=
=1.0.17 # required for internvl and gemma3n-mm test
torch==2.
9.1
torchaudio==2.
9.1
torchvision==0.2
4.1
timm
>
=1.0.17 # required for internvl and gemma3n-mm test
torch==2.
10.0
torchaudio==2.
10.0
torchvision==0.2
5.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.9.0 # required for voxtral test
...
...
requirements/test.txt
View file @
f97ca671
...
...
@@ -155,6 +155,10 @@ coverage==7.10.6
# via pytest-cov
cramjam==2.9.0
# via fastparquet
cuda-bindings==12.9.4
# via torch
cuda-pathfinder==1.3.3
# via cuda-bindings
cupy-cuda12x==13.6.0
# via ray
cycler==0.12.1
...
...
@@ -631,7 +635,7 @@ nvidia-nvjitlink-cu12==12.9.86
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvshmem-cu12==3.
3.20
nvidia-nvshmem-cu12==3.
4.5
# via torch
nvidia-nvtx-cu12==12.9.79
# via torch
...
...
@@ -1163,7 +1167,7 @@ tomli==2.2.1
# via schemathesis
tomli-w==1.2.0
# via schemathesis
torch==2.
9.1
+cu129
torch==2.
10.0
+cu129
# via
# -r requirements/test.in
# accelerate
...
...
@@ -1192,7 +1196,7 @@ torch==2.9.1+cu129
# torchvision
# vector-quantize-pytorch
# vocos
torchaudio==2.
9.1
+cu129
torchaudio==2.
10.0
+cu129
# via
# -r requirements/test.in
# encodec
...
...
@@ -1205,7 +1209,7 @@ torchmetrics==1.7.4
# pytorch-lightning
# terratorch
# torchgeo
torchvision==0.2
4.1
+cu129
torchvision==0.2
5.0
+cu129
# via
# -r requirements/test.in
# lightly
...
...
@@ -1247,7 +1251,7 @@ transformers==4.57.5
# transformers-stream-generator
transformers-stream-generator==0.0.5
# via -r requirements/test.in
triton==3.
5.1
triton==3.
6.0
# via torch
tritonclient==2.64.0
# via -r requirements/test.in
...
...
tests/compile/test_aot_compile.py
View file @
f97ca671
...
...
@@ -90,9 +90,7 @@ def use_vllm_config(vllm_config: VllmConfig):
yield
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_no_dynamo_cache_entry
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
:
vllm_config
=
make_vllm_config
()
...
...
@@ -116,9 +114,7 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
assert
torch
.
allclose
(
actual
,
expected
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_force_aot_load
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
,
monkeypatch
.
context
()
as
m
:
args
=
(
torch
.
randn
(
10
,
10
),)
...
...
@@ -132,9 +128,7 @@ def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
CompiledMod
(
vllm_config
=
vllm_config
)(
*
args
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_save_and_load
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
:
args
=
(
torch
.
randn
(
10
,
10
),)
...
...
@@ -162,9 +156,7 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
assert
torch
.
allclose
(
ret
,
expected
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_cache_load_returns_tuple_consistency
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Test that cache loading correctly handles the returns_tuple logic.
...
...
@@ -223,9 +215,7 @@ def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch):
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_cache_load_returns_tuple_consistency_tuple_output
(
monkeypatch
:
pytest
.
MonkeyPatch
,
):
...
...
@@ -294,9 +284,7 @@ def test_cache_load_returns_tuple_consistency_tuple_output(
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_shape_env
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Test that the shape environment is correctly serialized and preserved
...
...
@@ -333,9 +321,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
assert
guards_string
==
" - s77 <= 42
\n
- Eq(Mod(s77, 2), 0)"
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_partition_wrapper_applied_on_aot_load
(
monkeypatch
:
pytest
.
MonkeyPatch
,
vllm_tmp_cache
:
Path
,
mocker
):
...
...
@@ -426,9 +412,7 @@ def test_partition_wrapper_applied_on_aot_load(
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
@
create_new_process_for_each_test
(
"spawn"
)
def
test_gpt2_cache_hit
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
...
...
@@ -492,9 +476,7 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
symbolic_shapes_module
.
make_symbol
=
original_make_symbol
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
class
TestStandaloneCompiledArtifacts
:
def
test_init
(
self
):
cache
=
StandaloneCompiledArtifacts
()
...
...
@@ -668,9 +650,7 @@ class TestStandaloneCompiledArtifacts:
assert
len
(
restored_cache
.
loaded_submodule_store
)
==
0
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
class
TestStandaloneCompiledArtifactsIntegration
:
def
test_add_pickle_unpickle
(
self
):
cache
=
StandaloneCompiledArtifacts
()
...
...
tests/compile/test_dynamic_shapes_compilation.py
View file @
f97ca671
...
...
@@ -39,9 +39,7 @@ def get_test_models():
@
pytest
.
mark
.
parametrize
(
"use_aot_compile"
,
[
"0"
,
"1"
])
@
pytest
.
mark
.
parametrize
(
"use_bytecode_hook"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"evaluate_guards"
,
[
False
,
True
])
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"requires torch 2.10"
)
def
test_dynamic_shapes_compilation
(
monkeypatch
,
model_name
,
...
...
tests/kernels/moe/test_shared_fused_moe_routed_transform.py
View file @
f97ca671
...
...
@@ -14,6 +14,7 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.fused_moe.shared_fused_moe
import
SharedFusedMoE
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
class
SimpleLinear
(
nn
.
Module
):
...
...
@@ -60,6 +61,10 @@ def setup_cuda():
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"hidden_size,latent_size"
,
[(
256
,
128
),
(
128
,
64
)])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
skipif
(
is_torch_equal_or_newer
(
"2.10.0"
),
reason
=
"Test fails with PyTorch 2.10.0 see: https://github.com/vllm-project/vllm/issues/33995"
,
)
def
test_routed_input_transform_inside_vs_outside
(
num_tokens
:
int
,
hidden_size
:
int
,
...
...
vllm/compilation/compiler_interface.py
View file @
f97ca671
...
...
@@ -233,7 +233,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
from
torch._inductor
import
standalone_compile
supports_aot
=
is_torch_equal_or_newer
(
"2.10.0
.dev
"
)
supports_aot
=
is_torch_equal_or_newer
(
"2.10.0"
)
if
not
supports_aot
and
envs
.
VLLM_USE_MEGA_AOT_ARTIFACT
:
logger
.
error
(
...
...
vllm/compilation/decorators.py
View file @
f97ca671
...
...
@@ -333,7 +333,7 @@ def _support_torch_compile(
)
->
None
:
def
mark_dynamic
(
arg
:
torch
.
Tensor
,
dims
:
list
[
int
])
->
None
:
if
ds_type
==
DynamicShapesType
.
UNBACKED
:
if
is_torch_equal_or_newer
(
"2.10.0
.dev
"
):
if
is_torch_equal_or_newer
(
"2.10.0"
):
for
dim
in
dims
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dim
,
hint_override
=
arg
.
size
()[
dim
]
...
...
@@ -373,7 +373,7 @@ def _support_torch_compile(
if
isinstance
(
arg
,
torch
.
Tensor
):
# In case dims is specified with negative indexing
dims
=
[
arg
.
ndim
+
dim
if
dim
<
0
else
dim
for
dim
in
dims
]
if
is_torch_equal_or_newer
(
"2.10.0
.dev
"
):
if
is_torch_equal_or_newer
(
"2.10.0"
):
for
dim
in
dims
:
torch
.
_dynamo
.
decorators
.
mark_unbacked
(
arg
,
dim
,
hint_override
=
arg
.
size
()[
dim
]
...
...
@@ -525,9 +525,9 @@ def _support_torch_compile(
fx_config_patches
[
"backed_size_oblivious"
]
=
True
# Prepare inductor config patches
# assume_32bit_indexing is only available in torch 2.10.0
.dev
+
# assume_32bit_indexing is only available in torch 2.10.0+
inductor_config_patches
=
{}
if
is_torch_equal_or_newer
(
"2.10.0
.dev
"
):
if
is_torch_equal_or_newer
(
"2.10.0"
):
inductor_config_patches
[
"assume_32bit_indexing"
]
=
(
self
.
compilation_config
.
dynamic_shapes_config
.
assume_32_bit_indexing
)
...
...
vllm/envs.py
View file @
f97ca671
...
...
@@ -271,7 +271,7 @@ def use_aot_compile() -> bool:
default_value
=
(
"1"
if
is_torch_equal_or_newer
(
"2.1
0
.0.dev"
)
and
not
disable_compile_cache
()
if
is_torch_equal_or_newer
(
"2.1
1
.0.dev"
)
and
not
disable_compile_cache
()
else
"0"
)
...
...
vllm/model_executor/layers/batch_invariant.py
View file @
f97ca671
...
...
@@ -974,7 +974,7 @@ def enable_batch_invariant_mode():
)
reduced_precision_val
=
(
(
False
,
False
)
if
is_torch_equal_or_newer
(
"2.10.0
.dev
"
)
else
False
(
False
,
False
)
if
is_torch_equal_or_newer
(
"2.10.0"
)
else
False
)
torch
.
backends
.
cuda
.
matmul
.
allow_fp16_reduced_precision_reduction
=
(
reduced_precision_val
...
...
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
f97ca671
...
...
@@ -27,9 +27,21 @@ logger = init_logger(__name__)
if
has_triton_kernels
():
try
:
import
triton_kernels.swiglu
from
triton_kernels.matmul_ogs
import
FnSpecs
,
FusedActivation
,
matmul_ogs
from
triton_kernels.routing
import
RoutingData
,
routing
,
routing_from_bitmatrix
from
triton_kernels.tensor
import
Bitmatrix
from
triton_kernels.matmul_ogs
import
(
FnSpecs
,
FusedActivation
,
GatherIndx
,
RoutingData
,
ScatterIndx
,
matmul_ogs
,
)
from
triton_kernels.tensor
import
(
BIT
,
Bitmatrix
,
SparseMatrix
,
make_ragged_tensor_metadata
,
)
from
triton_kernels.topk
import
topk
except
(
AttributeError
,
ImportError
)
as
e
:
logger
.
error
(
"Failed to import Triton kernels. Please make sure your triton "
...
...
@@ -78,6 +90,58 @@ def pack_bitmatrix(
tl
.
store
(
bitmatrix_ptrs
,
y
,
mask
=
offsets_m
[:,
None
]
<
n_rows
)
def
legacy_routing_from_bitmatrix
(
bitmatrix
:
"Bitmatrix"
,
expt_scal
:
torch
.
Tensor
,
expt_indx
:
torch
.
Tensor
,
n_expts_tot
:
int
,
n_expts_act
:
int
,
)
->
tuple
[
"RoutingData"
,
"GatherIndx"
,
"ScatterIndx"
]:
"""
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
Creates routing data from a bitmatrix representation.
"""
sparse_logits
=
SparseMatrix
(
indx
=
expt_indx
,
vals
=
expt_scal
,
mask
=
bitmatrix
)
dispatch_indx
=
sparse_logits
.
mask_metadata
.
row_sorted_indx
combine_indx
=
sparse_logits
.
mask_metadata
.
col_sorted_indx
ragged_batch_metadata
=
make_ragged_tensor_metadata
(
sparse_logits
.
mask_metadata
.
col_sum
,
dispatch_indx
.
shape
[
0
],
)
gate_scal
=
sparse_logits
.
vals
.
flatten
()[
combine_indx
]
routing_data
=
RoutingData
(
gate_scal
,
ragged_batch_metadata
.
block_sizes
,
n_expts_tot
,
n_expts_act
,
ragged_batch_metadata
,
)
gather_idx
=
GatherIndx
(
combine_indx
,
dispatch_indx
)
scatter_idx
=
ScatterIndx
(
dispatch_indx
,
combine_indx
)
return
routing_data
,
gather_idx
,
scatter_idx
def
legacy_routing
(
logits
:
torch
.
Tensor
,
n_expts_act
:
int
,
sm_first
:
bool
=
False
,
)
->
tuple
[
"RoutingData"
,
"GatherIndx"
,
"ScatterIndx"
]:
"""
Replacement for the removed triton_kernels.routing.routing function.
Computes routing data from gating logits.
"""
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
sparse_logits
=
topk
(
logits
,
n_expts_act
,
apply_softmax
=
not
sm_first
)
return
legacy_routing_from_bitmatrix
(
sparse_logits
.
mask
,
sparse_logits
.
vals
,
sparse_logits
.
indx
,
logits
.
shape
[
-
1
],
n_expts_act
,
)
def
triton_kernel_moe_forward
(
hidden_states
:
torch
.
Tensor
,
w1
,
# Tensor or triton_kernels.Tensor
...
...
@@ -91,7 +155,7 @@ def triton_kernel_moe_forward(
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
routing_data
,
gather_idx
,
scatter_idx
=
routing
(
routing_data
,
gather_idx
,
scatter_idx
=
legacy_
routing
(
gating_output
,
topk
,
sm_first
=
not
renormalize
)
...
...
@@ -168,9 +232,10 @@ def triton_kernel_fused_experts(
output_tensor
=
_resize_cache
(
output_tensor
,
(
batch_dim
,
M
,
K
))
act
=
FusedActivation
(
FnSpecs
(
"swiglu"
,
triton_kernels
.
swiglu
.
swiglu_fn
,
(
"alpha"
,
"limit"
)),
FnSpecs
(
"swiglu"
,
triton_kernels
.
swiglu
.
swiglu_fn
,
(
"alpha"
,
"limit"
),
reduction_n
=
2
),
(
swiglu_alpha
,
swiglu_limit
),
2
,
)
gammas
=
routing_data
.
gate_scal
if
routing_data
else
None
...
...
@@ -232,12 +297,12 @@ def make_routing_data(
bitmatrix_shape
=
[
n_rows
,
bm_cols
*
32
]
bitmatrix_shape_max
=
[
n_rows
,
None
]
bitmatrix
=
Bitmatrix
(
bitmatrix
,
shape
=
bitmatrix_shape
,
shape_max
=
bitmatrix_shape_max
,
scratchpad
=
None
bitmatrix
,
dtype
=
BIT
,
shape
=
bitmatrix_shape
,
shape_max
=
bitmatrix_shape_max
)
# matmul_ogs expects invalid topk_weights to be -1s
topk_weights
=
torch
.
where
(
topk_ids
==
-
1
,
-
1.0
,
topk_weights
)
routing_data
,
gather_indx
,
scatter_indx
=
routing_from_bitmatrix
(
routing_data
,
gather_indx
,
scatter_indx
=
legacy_
routing_from_bitmatrix
(
bitmatrix
,
topk_weights
,
topk_ids
,
num_local_experts
,
num_topk
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment