Unverified commit 977d7cd2, authored by Yineng Zhang and committed by GitHub

cleanup deps 1/n (#4400)


Co-authored-by: sleepcoo <sleepcoo@gmail.com>
parent 0e0ec702
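
This commit is the first step in removing sglang's hard runtime dependency on vllm. In each touched module, the module-level `from vllm import _custom_ops ...` import moves into the non-CUDA fallback branch, so CUDA builds never import vllm at all. `moe_align_block_size` additionally drops its `num_experts >= 224` special case together with the `vllm_ops.moe_align_block_size` fallback, dispatching only between the Triton kernel and sgl-kernel's CUDA op. Finally, the ROCm CI jobs now uninstall any preinstalled `sgl-kernel` before building it from source.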
@@ -45,6 +45,7 @@ jobs:
       - name: Install dependencies
         run: |
           docker exec ci_sglang pip install --upgrade pip
+          docker exec ci_sglang pip uninstall sgl-kernel -y || true
           docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
           docker exec ci_sglang pip install -e "python[dev_hip]"
@@ -83,6 +84,7 @@ jobs:
       - name: Install dependencies
         run: |
           docker exec ci_sglang pip install --upgrade pip
+          docker exec ci_sglang pip uninstall sgl-kernel -y || true
           docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
           docker exec ci_sglang pip install -e "python[dev_hip]"
...
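
Note the `|| true` on the new uninstall step: `pip uninstall` exits non-zero when the package is absent, so the guard keeps a fresh container from failing the job, while the uninstall itself ensures the subsequent `setup_rocm.py install` exercises a from-source build rather than a previously installed `sgl-kernel` wheel.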
@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
-from vllm import _custom_ops as vllm_ops
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -32,6 +31,8 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
 logger = logging.getLogger(__name__)
...
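
The same deferred-import pattern is applied in every file this commit touches. A minimal standalone sketch of the idea (the `is_cuda` import path is an assumption inferred from the hunk context, not shown in the diff):

from sglang.srt.utils import is_cuda  # assumed location of is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    # CUDA builds resolve the op from sglang's own kernels ...
    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
    # ... and vllm is imported only on the fallback path, so a CUDA
    # install no longer needs vllm present at import time.
    from vllm import _custom_ops as vllm_ops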
@@ -11,7 +11,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-from vllm import _custom_ops as vllm_ops
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
@@ -46,6 +45,8 @@ if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
     )
+else:
+    from vllm import _custom_ops as vllm_ops
 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -456,44 +457,34 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if num_experts >= 224:
-        if enable_moe_align_block_size_triton:
-            moe_align_block_size_triton(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids,
-                expert_ids,
-                num_tokens_post_pad,
-            )
-        else:
-            token_cnts_buffer = torch.zeros(
-                (num_experts + 1) * num_experts,
-                dtype=torch.int32,
-                device=topk_ids.device,
-            )
-            cumsum_buffer = torch.zeros(
-                num_experts + 1, dtype=torch.int32, device=topk_ids.device
-            )
-            sgl_moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids,
-                expert_ids,
-                num_tokens_post_pad,
-                token_cnts_buffer,
-                cumsum_buffer,
-            )
-    else:
-        vllm_ops.moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        )
+    if enable_moe_align_block_size_triton:
+        moe_align_block_size_triton(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
+    else:
+        token_cnts_buffer = torch.zeros(
+            (num_experts + 1) * num_experts,
+            dtype=torch.int32,
+            device=topk_ids.device,
+        )
+        cumsum_buffer = torch.zeros(
+            num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        )
+        sgl_moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+            token_cnts_buffer,
+            cumsum_buffer,
+        )
     return sorted_ids, expert_ids, num_tokens_post_pad
...
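
For readers unfamiliar with the op being dispatched: `moe_align_block_size` sorts the flattened token-to-expert assignments by expert and pads each expert's segment to a multiple of `block_size`, so the fused MoE kernel can process fixed-size blocks that each belong to a single expert. The pure-PyTorch sketch below describes that contract only; it is not the sglang implementation, and the padding sentinel (`topk_ids.numel()`) is an assumption borrowed from the vllm kernel's convention:

import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    # Reference semantics only (hedged sketch, not the sglang kernel).
    flat = topk_ids.flatten()
    pad_sentinel = flat.numel()  # assumed out-of-range filler value
    sorted_chunks, block_experts = [], []
    for e in range(num_experts):
        # Indices of the token-expert pairs routed to expert e.
        idx = (flat == e).nonzero(as_tuple=True)[0].to(torch.int32)
        pad = (-idx.numel()) % block_size  # pad up to a block_size multiple
        idx = torch.cat([idx, idx.new_full((pad,), pad_sentinel)])
        sorted_chunks.append(idx)
        block_experts += [e] * (idx.numel() // block_size)
    sorted_ids = torch.cat(sorted_chunks)
    expert_ids = torch.tensor(block_experts, dtype=torch.int32, device=flat.device)
    num_tokens_post_pad = torch.tensor([sorted_ids.numel()], dtype=torch.int32)
    return sorted_ids, expert_ids, num_tokens_post_pad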
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-from vllm import _custom_ops as ops
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
@@ -14,6 +13,8 @@ from sglang.srt.utils import is_cuda_available
 _is_cuda_available = is_cuda_available()
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+else:
+    from vllm import _custom_ops as ops
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
...
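
The trailing context line shows `_rotate_neox`, the helper both backends share. Its body is not part of this diff; the sketch below is the conventional GPT-NeoX-style rotation, included only for orientation:

import torch

def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Conventional NeoX rotation: (x1, x2) -> (-x2, x1) on the last dim.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)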