import ctypes
import logging
import os
import shutil
from pathlib import Path

import torch

logger = logging.getLogger(__name__)


def _get_compute_capability():
    """Return the compute capability of the current GPU as an integer
    (major * 10 + minor), or None if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None

    device = torch.cuda.current_device()
    properties = torch.cuda.get_device_properties(device)
    return properties.major * 10 + properties.minor


def _filter_compiled_extensions(file_list):
    """Reorder a list of paths so compiled extension modules come before
    Python source files."""
    # Common compiled-extension suffixes. endswith() also covers multi-part
    # names such as .abi3.so and .cpython-312.so, so no substring matching
    # is needed (substring matching would misclassify paths that merely
    # contain one of these suffixes).
    compiled_extensions = (".so", ".pyd", ".dll")

    compiled_files = []
    other_files = []

    for file_path in file_list:
        if str(file_path).endswith(compiled_extensions):
            compiled_files.append(file_path)
        else:
            other_files.append(file_path)

    # Compiled files first, then everything else.
    return compiled_files + other_files
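
# Illustrative behaviour of _filter_compiled_extensions (hypothetical paths):
# given ["pkg/common_ops.py", "pkg/common_ops.cpython-312-x86_64-linux-gnu.so"],
# the compiled artifact is moved to the front, so a glob that matches both a
# Python stub and the binary resolves to the binary:
#   ["pkg/common_ops.cpython-312-x86_64-linux-gnu.so", "pkg/common_ops.py"]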


def _load_architecture_specific_ops():
    """Load the appropriate common_ops library based on GPU architecture."""
    import glob
    import importlib.util

    compute_capability = _get_compute_capability()
    logger.debug(
        f"[sgl_kernel] GPU Detection: compute_capability = {compute_capability}"
    )

    # Directory where sgl_kernel is installed.
    sgl_kernel_dir = Path(__file__).parent
    logger.debug(f"[sgl_kernel] sgl_kernel directory: {sgl_kernel_dir}")

    # Determine which variant to load based on GPU architecture.
    if compute_capability == 90:
        ops_subdir = "sm90"
        variant_name = "SM90 (Hopper/H100 with fast math optimization)"
    elif compute_capability is not None:
        ops_subdir = "sm100"
        variant_name = f"SM{compute_capability} (precise math for compatibility)"
    else:
        ops_subdir = "sm100"
        variant_name = "CPU/No GPU detected (using precise math)"

    # Look for the compiled module with any valid extension.
    ops_pattern = str(sgl_kernel_dir / ops_subdir / "common_ops.*")
    raw_matching_files = glob.glob(ops_pattern)
    matching_files = _filter_compiled_extensions(raw_matching_files)

    logger.debug(f"[sgl_kernel] Attempting to load {variant_name}")
    logger.debug(f"[sgl_kernel] Looking for library matching pattern: {ops_pattern}")
    logger.debug(f"[sgl_kernel] Found files: {raw_matching_files}")
    logger.debug(f"[sgl_kernel] Prioritized files: {matching_files}")

    # Try to load from the architecture-specific directory.
    if matching_files:
        ops_path = Path(matching_files[0])  # Use the first prioritized file.
        logger.debug(f"[sgl_kernel] Found architecture-specific library: {ops_path}")
        try:
            # Load the module from an explicit path via importlib.
            spec = importlib.util.spec_from_file_location("common_ops", str(ops_path))
            if spec is None:
                raise ImportError(f"Could not create module spec for {ops_path}")
            common_ops = importlib.util.module_from_spec(spec)
            if spec.loader is None:
                raise ImportError(f"Module spec has no loader for {ops_path}")
            logger.debug(f"[sgl_kernel] Loading module from {ops_path}...")
            spec.loader.exec_module(common_ops)
            logger.debug(f"[sgl_kernel] ✓ Successfully loaded {variant_name}")
            logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
            return common_ops
        except Exception as e:
            logger.debug(
                f"[sgl_kernel] ✗ Failed to load from {ops_path}: {type(e).__name__}: {e}"
            )
            # Fall through to the fallback below.
    else:
        logger.debug(
            f"[sgl_kernel] ✗ Architecture-specific library not found matching pattern: {ops_pattern}"
        )

    # Try the package root (in case the installation layout is flat).
    alt_pattern = str(sgl_kernel_dir / "common_ops.*")
    raw_alt_files = glob.glob(alt_pattern)
    alt_matching_files = _filter_compiled_extensions(raw_alt_files)

    logger.debug(f"[sgl_kernel] Attempting fallback: looking for pattern {alt_pattern}")
    logger.debug(f"[sgl_kernel] Found fallback files: {raw_alt_files}")
    logger.debug(f"[sgl_kernel] Prioritized fallback files: {alt_matching_files}")

    if alt_matching_files:
        alt_path = Path(alt_matching_files[0])  # Use the first prioritized file.
        logger.debug(f"[sgl_kernel] Found fallback library: {alt_path}")
        try:
            spec = importlib.util.spec_from_file_location("common_ops", str(alt_path))
            if spec is None:
                raise ImportError(f"Could not create module spec for {alt_path}")
            common_ops = importlib.util.module_from_spec(spec)
            if spec.loader is None:
                raise ImportError(f"Module spec has no loader for {alt_path}")
            logger.debug(f"[sgl_kernel] Loading fallback module from {alt_path}...")
            spec.loader.exec_module(common_ops)
            logger.debug("[sgl_kernel] ✓ Successfully loaded fallback library")
            logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
            return common_ops
        except Exception as e:
            logger.debug(
                f"[sgl_kernel] ✗ Failed to load fallback from {alt_path}: {type(e).__name__}: {e}"
            )
    else:
        logger.debug(
            f"[sgl_kernel] ✗ Fallback library not found matching pattern: {alt_pattern}"
        )

    # Final attempt: a standard Python import (for backward compatibility).
    logger.debug(
        "[sgl_kernel] Final attempt: trying standard Python import 'common_ops'"
    )
    try:
        import common_ops

        logger.debug("[sgl_kernel] ✓ Successfully imported via standard Python import")
        logger.debug(f"[sgl_kernel] ✓ Module file: {common_ops.__file__}")
        return common_ops
    except ImportError as e:
        logger.debug(f"[sgl_kernel] ✗ Standard Python import failed: {e}")

    # All attempts failed.
    error_msg = f"""
[sgl_kernel] CRITICAL: Could not load any common_ops library!
Attempted locations:
1. Architecture-specific pattern: {ops_pattern} - found files: {matching_files}
2. Fallback pattern: {alt_pattern} - found files: {alt_matching_files}
3. Standard Python import: common_ops - failed
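
# Illustrative package layout this loader walks, in order (hypothetical
# filenames; the real suffix depends on the Python ABI, e.g. .abi3.so):
#
#   sgl_kernel/
#       sm90/common_ops.abi3.so    # tried first on compute capability 9.0
#       sm100/common_ops.abi3.so   # tried first on all other GPUs (and CPU)
#       common_ops.abi3.so         # flat-layout fallback
#
# followed by a plain `import common_ops` as the last resort.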
GPU Info:
- Compute capability: {compute_capability}
- Expected variant: {variant_name}
Please ensure sgl_kernel is properly installed with:
pip install --upgrade sgl_kernel
"""
    logger.debug(error_msg)
    raise ImportError(error_msg)


# Initialize the ops library for the current GPU.
logger.debug("[sgl_kernel] Initializing architecture-specific operator library...")
common_ops = _load_architecture_specific_ops()
logger.debug("[sgl_kernel] ✓ Operator library initialization complete")


# Copied and modified from torch/utils/cpp_extension.py.
def _find_cuda_home():
    """Find the CUDA install path."""
    # Guess #1: environment variables.
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home is None:
        # Guess #2: infer from the location of nvcc.
        nvcc_path = shutil.which("nvcc")
        if nvcc_path is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        else:
            # Guess #3: the conventional default install prefix.
            cuda_home = "/usr/local/cuda"
    return cuda_home
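
# Resolution order sketch: with no CUDA_HOME/CUDA_PATH set and nvcc found at
# /opt/cuda/bin/nvcc (hypothetical path), _find_cuda_home() strips two path
# components and returns "/opt/cuda"; with nvcc not on PATH either, it falls
# back to "/usr/local/cuda".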

if torch.version.cuda is not None:
    cuda_home = Path(_find_cuda_home())
    if (cuda_home / "lib").is_dir():
        cuda_path = cuda_home / "lib"
    elif (cuda_home / "lib64").is_dir():
        cuda_path = cuda_home / "lib64"
    else:
        # Search subdirectories for libcudart.so.12.
        for path in cuda_home.rglob("libcudart.so.12"):
            cuda_path = path.parent
            break
        else:
            raise RuntimeError("Could not find CUDA lib directory.")

    # Preload the CUDA runtime with RTLD_GLOBAL so its symbols are visible
    # to the compiled extension.
    cudart_path = (cuda_path / "libcudart.so.12").resolve()
    if cudart_path.exists():
        ctypes.CDLL(str(cudart_path), mode=ctypes.RTLD_GLOBAL)

from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
    cutlass_mla_decode,
    cutlass_mla_get_workspace_size,
    lightning_attention_decode,
    merge_state,
    merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
    FusedSetKVBufferArg,
    apply_rope_with_cos_sin_cache_inplace,
    concat_mla_absorb_q,
    concat_mla_k,
    copy_to_gpu_no_ce,
    downcast_fp8,
    fused_add_rmsnorm,
    gelu_and_mul,
    gelu_tanh_and_mul,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    rmsnorm,
    silu_and_mul,
)
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
    cutlass_scaled_fp4_mm,
    dsv3_fused_a_gemm,
    dsv3_router_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    gptq_gemm,
    gptq_marlin_gemm,
    gptq_shuffle,
    int8_scaled_mm,
    qserve_w4a8_per_chn_gemm,
    qserve_w4a8_per_group_gemm,
    scaled_fp4_experts_quant,
    scaled_fp4_grouped_quant,
    scaled_fp4_quant,
    sgl_per_tensor_quant_fp8,
    sgl_per_token_group_quant_fp8,
    sgl_per_token_group_quant_int8,
    sgl_per_token_quant_fp8,
    shuffle_rows,
    silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
from sgl_kernel.marlin import (
    awq_marlin_moe_repack,
    awq_marlin_repack,
    gptq_marlin_repack,
)
from sgl_kernel.memory import set_kv_buffer_kernel
from sgl_kernel.moe import (
    apply_shuffle_mul_sum,
    cutlass_fp4_group_mm,
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
    moe_fused_gate,
    moe_sum_reduce,
    prepare_moe_input,
    topk_softmax,
)
from sgl_kernel.sampling import (
    min_p_sampling_from_probs,
    top_k_mask_logits,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_logits,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
    top_p_sampling_from_probs,
)
from sgl_kernel.speculative import (
    build_tree_kernel_efficient,
    reconstruct_indices_from_tree_mask,
    segment_packbits,
    tree_speculative_sampling_target_only,
    verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__

if torch.version.hip is not None:
    from sgl_kernel.elementwise import gelu_quick


def create_greenctx_stream_by_value(*args, **kwargs):
    # Imported lazily so that merely importing sgl_kernel does not pull in
    # the spatial extension.
    from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl

    return _impl(*args, **kwargs)


def get_sm_available(*args, **kwargs):
    # Imported lazily; see create_greenctx_stream_by_value above.
    from sgl_kernel.spatial import get_sm_available as _impl

    return _impl(*args, **kwargs)
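
# Illustrative usage (hypothetical session; assumes a successful install):
#
#   >>> import sgl_kernel
#   >>> sgl_kernel.__version__   # version string re-exported above
#   >>> sgl_kernel.rmsnorm       # kernels are available at package level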