Unverified commit f9bc5a06, authored by Mengqing Cao and committed by GitHub
Browse files

[Bugfix] Fix triton import with local TritonPlaceholder (#17446)


Signed-off-by: Mengqing Cao <cmq0113@163.com>
parent 05e1f964
......@@ -6,10 +6,11 @@
# ruff: noqa: E501
import torch
import triton
from einops import rearrange
from packaging import version
from vllm.triton_utils import triton
from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
......
......@@ -6,8 +6,8 @@
# ruff: noqa: E501
import torch
import triton
import triton.language as tl
from vllm.triton_utils import tl, triton
@triton.autotune(
......
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
from vllm.triton_utils import tl, triton
AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
......
......@@ -3,8 +3,8 @@
from typing import Optional, Type
import torch
import triton
import triton.language as tl
from vllm.triton_utils import tl, triton
def is_weak_contiguous(x: torch.Tensor):
......
......@@ -7,8 +7,6 @@ import os
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
......@@ -17,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
......
......@@ -8,10 +8,9 @@ import os
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
logger = logging.getLogger(__name__)
......
# SPDX-License-Identifier: Apache-2.0
from vllm.triton_utils.importing import HAS_TRITON
from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
TritonPlaceholder)
__all__ = ["HAS_TRITON"]
# Bind the module-level names `triton` and `tl`: use the real packages when
# HAS_TRITON is true, otherwise fall back to placeholder instances.
if HAS_TRITON:
    import triton
    import triton.language as tl
else:
    triton = TritonPlaceholder()
    tl = TritonLanguagePlaceholder()
__all__ = ["HAS_TRITON", "triton", "tl"]
......@@ -16,32 +16,34 @@ if not HAS_TRITON:
logger.info("Triton not installed or not compatible; certain GPU-related"
" functions will not be available.")
class TritonPlaceholder(types.ModuleType):
    """Module-shaped stand-in named ``triton`` that exposes dummy decorators.

    The instance provides ``jit``, ``autotune`` and ``heuristics`` attributes
    that leave the decorated function untouched, plus a ``language``
    attribute holding a ``TritonLanguagePlaceholder``.
    """

    def __init__(self):
        super().__init__("triton")
        self.jit = self._dummy_decorator("jit")
        self.autotune = self._dummy_decorator("autotune")
        self.heuristics = self._dummy_decorator("heuristics")
        self.language = TritonLanguagePlaceholder()
        # The leading space on the last fragment keeps the concatenated
        # message from reading "kernelcompilation".
        logger.warning_once(
            "Triton is not installed. Using dummy decorators. "
            "Install it via `pip install triton` to enable kernel"
            " compilation.")

    def _dummy_decorator(self, name):
        """Return a no-op decorator usable both bare and with arguments.

        Args:
            name: Name of the decorator being replaced; currently unused.

        Returns:
            A callable that returns its first argument unchanged when that
            argument is callable (bare ``@decorator`` use, or ``func=...``),
            and otherwise returns an identity decorator.
        """

        def decorator(func=None, *args, **kwargs):
            # Only a callable first argument is the function being decorated.
            # Anything else (None, a config list, a dict of heuristics, ...)
            # is a decorator argument, so hand back an identity decorator.
            if callable(func):
                return func
            return lambda f: f

        return decorator
class TritonLanguagePlaceholder(types.ModuleType):
def __init__(self):
super().__init__("triton.language")
self.constexpr = None
self.dtype = None
self.int64 = None
class TritonPlaceholder(types.ModuleType):
    """Module-shaped stand-in named ``triton`` with pass-through decorators.

    ``jit``, ``autotune`` and ``heuristics`` are replaced by decorators that
    return the decorated function unchanged; ``language`` holds a
    ``TritonLanguagePlaceholder`` instance.
    """

    def __init__(self):
        super().__init__("triton")
        # Install one dummy decorator per supported attribute name.
        for decorator_name in ("jit", "autotune", "heuristics"):
            setattr(self, decorator_name,
                    self._dummy_decorator(decorator_name))
        self.language = TritonLanguagePlaceholder()
        logger.warning_once(
            "Triton is not installed. Using dummy decorators. "
            "Install it via `pip install triton` to enable kernel"
            " compilation.")

    def _dummy_decorator(self, name):
        """Build a no-op decorator that works bare or with arguments.

        ``name`` is accepted but not used by the returned decorator.
        """

        def decorator(*args, **kwargs):
            # Bare use (@decorator): the first positional argument is the
            # function itself. Parameterized use (@decorator(...)): return an
            # identity decorator instead.
            used_bare = bool(args) and callable(args[0])
            return args[0] if used_bare else (lambda f: f)

        return decorator
class TritonLanguagePlaceholder(types.ModuleType):
    """Placeholder module object registered under the name ``triton.language``.

    It defines ``constexpr``, ``dtype`` and ``int64``, all set to ``None``.
    """

    def __init__(self):
        super().__init__("triton.language")
        # Define each placeholder attribute in the original order.
        for attr_name in ("constexpr", "dtype", "int64"):
            setattr(self, attr_name, None)
......@@ -3,10 +3,9 @@ from typing import Optional
import torch
import torch.nn as nn
import triton
import triton.language as tl
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
......
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
......@@ -11,6 +9,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment