Unverified Commit f9bc5a06 authored by Mengqing Cao's avatar Mengqing Cao Committed by GitHub
Browse files

[Bugfix] Fix triton import with local TritonPlaceholder (#17446)


Signed-off-by: default avatarMengqing Cao <cmq0113@163.com>
parent 05e1f964
...@@ -10,12 +10,12 @@ from typing import Any, TypedDict ...@@ -10,12 +10,12 @@ from typing import Any, TypedDict
import ray import ray
import torch import torch
import triton
from ray.experimental.tqdm_ray import tqdm from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
......
...@@ -4,11 +4,11 @@ import itertools ...@@ -4,11 +4,11 @@ import itertools
from typing import Optional, Union from typing import Optional, Union
import torch import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn from torch import nn
from vllm import _custom_ops as vllm_ops from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
class HuggingFaceRMSNorm(nn.Module): class HuggingFaceRMSNorm(nn.Module):
......
...@@ -6,13 +6,13 @@ import time ...@@ -6,13 +6,13 @@ import time
# Import DeepGEMM functions # Import DeepGEMM functions
import deep_gemm import deep_gemm
import torch import torch
import triton
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
# Import vLLM functions # Import vLLM functions
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul) per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.triton_utils import triton
# Copied from # Copied from
......
...@@ -5,11 +5,11 @@ import random ...@@ -5,11 +5,11 @@ import random
import pytest import pytest
import torch import torch
import triton
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
is_flashmla_supported) is_flashmla_supported)
from vllm.triton_utils import triton
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
......
# SPDX-License-Identifier: Apache-2.0
import sys
import types
from unittest import mock
from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
TritonPlaceholder)
def test_triton_placeholder_is_module():
triton = TritonPlaceholder()
assert isinstance(triton, types.ModuleType)
assert triton.__name__ == "triton"
def test_triton_language_placeholder_is_module():
triton_language = TritonLanguagePlaceholder()
assert isinstance(triton_language, types.ModuleType)
assert triton_language.__name__ == "triton.language"
def test_triton_placeholder_decorators():
triton = TritonPlaceholder()
@triton.jit
def foo(x):
return x
@triton.autotune
def bar(x):
return x
@triton.heuristics
def baz(x):
return x
assert foo(1) == 1
assert bar(2) == 2
assert baz(3) == 3
def test_triton_placeholder_decorators_with_args():
triton = TritonPlaceholder()
@triton.jit(debug=True)
def foo(x):
return x
@triton.autotune(configs=[], key="x")
def bar(x):
return x
@triton.heuristics(
{"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
def baz(x):
return x
assert foo(1) == 1
assert bar(2) == 2
assert baz(3) == 3
def test_triton_placeholder_language():
lang = TritonLanguagePlaceholder()
assert isinstance(lang, types.ModuleType)
assert lang.__name__ == "triton.language"
assert lang.constexpr is None
assert lang.dtype is None
assert lang.int64 is None
def test_triton_placeholder_language_from_parent():
triton = TritonPlaceholder()
lang = triton.language
assert isinstance(lang, TritonLanguagePlaceholder)
def test_no_triton_fallback():
# clear existing triton modules
sys.modules.pop("triton", None)
sys.modules.pop("triton.language", None)
sys.modules.pop("vllm.triton_utils", None)
sys.modules.pop("vllm.triton_utils.importing", None)
# mock triton not being installed
with mock.patch.dict(sys.modules, {"triton": None}):
from vllm.triton_utils import HAS_TRITON, tl, triton
assert HAS_TRITON is False
assert triton.__class__.__name__ == "TritonPlaceholder"
assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import torch import torch
import triton
import triton.language as tl from vllm.triton_utils import tl, triton
def blocksparse_flash_attn_varlen_fwd( def blocksparse_flash_attn_varlen_fwd(
......
...@@ -8,7 +8,8 @@ from functools import lru_cache ...@@ -8,7 +8,8 @@ from functools import lru_cache
import numpy as np import numpy as np
import torch import torch
import triton
from vllm.triton_utils import triton
class csr_matrix: class csr_matrix:
......
...@@ -7,11 +7,10 @@ ...@@ -7,11 +7,10 @@
# - Thomas Parnell <tpa@zurich.ibm.com> # - Thomas Parnell <tpa@zurich.ibm.com>
import torch import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms.rocm import use_rocm_custom_paged_attention from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.triton_utils import tl, triton
from .prefix_prefill import context_attention_fwd from .prefix_prefill import context_attention_fwd
......
...@@ -4,10 +4,9 @@ ...@@ -4,10 +4,9 @@
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
import torch import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
# Static kernels parameters # Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
......
...@@ -30,10 +30,8 @@ It supports page size >= 1. ...@@ -30,10 +30,8 @@ It supports page size >= 1.
import logging import logging
import triton
import triton.language as tl
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
is_hip_ = current_platform.is_rocm() is_hip_ = current_platform.is_rocm()
......
...@@ -25,11 +25,10 @@ Currently only the forward kernel is supported, and contains these features: ...@@ -25,11 +25,10 @@ Currently only the forward kernel is supported, and contains these features:
from typing import Optional from typing import Optional
import torch import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
from typing import Optional from typing import Optional
import torch import torch
import triton
import triton.language as tl from vllm.triton_utils import tl, triton
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
......
...@@ -2,8 +2,7 @@ ...@@ -2,8 +2,7 @@
""" """
Utilities for Punica kernel construction. Utilities for Punica kernel construction.
""" """
import triton from vllm.triton_utils import tl, triton
import triton.language as tl
@triton.jit @triton.jit
......
...@@ -6,8 +6,6 @@ import os ...@@ -6,8 +6,6 @@ import os
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
import triton
import triton.language as tl
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -21,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -21,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.model_executor.layers.quantization.utils.int8_utils import ( from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
......
...@@ -2,11 +2,10 @@ ...@@ -2,11 +2,10 @@
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import triton
import triton.language as tl
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.triton_utils import tl, triton
from vllm.utils import round_up from vllm.utils import round_up
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import torch import torch
import triton
import triton.language as tl
from einops import rearrange from einops import rearrange
from vllm.triton_utils import tl, triton
@triton.jit @triton.jit
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
......
...@@ -4,13 +4,11 @@ ...@@ -4,13 +4,11 @@
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
import torch import torch
import triton
import triton.language as tl
from packaging import version from packaging import version
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON, tl, triton
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
>= version.parse("3.0.0")) >= version.parse("3.0.0"))
......
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
import math import math
import torch import torch
import triton
import triton.language as tl from vllm.triton_utils import tl, triton
@triton.autotune( @triton.autotune(
......
...@@ -6,10 +6,10 @@ ...@@ -6,10 +6,10 @@
# ruff: noqa: E501,SIM102 # ruff: noqa: E501,SIM102
import torch import torch
import triton
import triton.language as tl
from packaging import version from packaging import version
from vllm.triton_utils import tl, triton
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
......
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
import math import math
import torch import torch
import triton
import triton.language as tl from vllm.triton_utils import tl, triton
from .mamba_ssm import softplus from .mamba_ssm import softplus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment