Unverified commit f9bc5a06, authored by Mengqing Cao, committed by GitHub
Browse files

[Bugfix] Fix triton import with local TritonPlaceholder (#17446)


Signed-off-by: Mengqing Cao <cmq0113@163.com>
parent 05e1f964
...@@ -6,10 +6,11 @@ ...@@ -6,10 +6,11 @@
# ruff: noqa: E501 # ruff: noqa: E501
import torch import torch
import triton
from einops import rearrange from einops import rearrange
from packaging import version from packaging import version
from vllm.triton_utils import triton
from .ssd_bmm import _bmm_chunk_fwd from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
# ruff: noqa: E501 # ruff: noqa: E501
import torch import torch
import triton
import triton.language as tl from vllm.triton_utils import tl, triton
@triton.autotune( @triton.autotune(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import torch import torch
import triton
import triton.language as tl from vllm.triton_utils import tl, triton
AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
from typing import Optional, Type from typing import Optional, Type
import torch import torch
import triton
import triton.language as tl from vllm.triton_utils import tl, triton
def is_weak_contiguous(x: torch.Tensor): def is_weak_contiguous(x: torch.Tensor):
......
...@@ -7,8 +7,6 @@ import os ...@@ -7,8 +7,6 @@ import os
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -17,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -17,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED) CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -8,10 +8,9 @@ import os ...@@ -8,10 +8,9 @@ import os
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from vllm.triton_utils.importing import HAS_TRITON from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
TritonPlaceholder)
__all__ = ["HAS_TRITON"] if HAS_TRITON:
import triton
import triton.language as tl
else:
triton = TritonPlaceholder()
tl = TritonLanguagePlaceholder()
__all__ = ["HAS_TRITON", "triton", "tl"]
...@@ -16,32 +16,34 @@ if not HAS_TRITON: ...@@ -16,32 +16,34 @@ if not HAS_TRITON:
logger.info("Triton not installed or not compatible; certain GPU-related" logger.info("Triton not installed or not compatible; certain GPU-related"
" functions will not be available.") " functions will not be available.")
class TritonPlaceholder(types.ModuleType):
class TritonPlaceholder(types.ModuleType):
def __init__(self):
super().__init__("triton") def __init__(self):
self.jit = self._dummy_decorator("jit") super().__init__("triton")
self.autotune = self._dummy_decorator("autotune") self.jit = self._dummy_decorator("jit")
self.heuristics = self._dummy_decorator("heuristics") self.autotune = self._dummy_decorator("autotune")
self.language = TritonLanguagePlaceholder() self.heuristics = self._dummy_decorator("heuristics")
logger.warning_once( self.language = TritonLanguagePlaceholder()
"Triton is not installed. Using dummy decorators. " logger.warning_once(
"Install it via `pip install triton` to enable kernel" "Triton is not installed. Using dummy decorators. "
"compilation.") "Install it via `pip install triton` to enable kernel"
" compilation.")
def _dummy_decorator(self, name):
def _dummy_decorator(self, name):
def decorator(func=None, **kwargs):
if func is None: def decorator(*args, **kwargs):
return lambda f: f if args and callable(args[0]):
return func return args[0]
return lambda f: f
return decorator
return decorator
class TritonLanguagePlaceholder(types.ModuleType):
def __init__(self): class TritonLanguagePlaceholder(types.ModuleType):
super().__init__("triton.language")
self.constexpr = None def __init__(self):
self.dtype = None super().__init__("triton.language")
self.int64 = None self.constexpr = None
self.dtype = None
self.int64 = None
...@@ -3,10 +3,9 @@ from typing import Optional ...@@ -3,10 +3,9 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import triton
import triton.language as tl
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import torch import torch
import torch.nn as nn import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
...@@ -11,6 +9,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader ...@@ -11,6 +9,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment