Unverified Commit ad1ae7f7 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

use topk_softmax with sgl-kernel (#4439)

parent e73167ad
...@@ -20,7 +20,7 @@ jobs: ...@@ -20,7 +20,7 @@ jobs:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v4
......
...@@ -17,7 +17,7 @@ jobs: ...@@ -17,7 +17,7 @@ jobs:
runs-on: 1-gpu-runner runs-on: 1-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
run: | run: |
......
...@@ -6,7 +6,7 @@ jobs: ...@@ -6,7 +6,7 @@ jobs:
lint: lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v4
......
...@@ -20,7 +20,7 @@ jobs: ...@@ -20,7 +20,7 @@ jobs:
runs-on: 2-gpu-runner runs-on: 2-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
run: | run: |
......
...@@ -25,7 +25,7 @@ jobs: ...@@ -25,7 +25,7 @@ jobs:
runs-on: linux-mi300-gpu-1 runs-on: linux-mi300-gpu-1
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Setup docker - name: Setup docker
run: | run: |
...@@ -64,7 +64,7 @@ jobs: ...@@ -64,7 +64,7 @@ jobs:
runs-on: linux-mi300-gpu-1 runs-on: linux-mi300-gpu-1
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Setup docker - name: Setup docker
run: | run: |
......
...@@ -21,7 +21,7 @@ jobs: ...@@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
run: | run: |
...@@ -45,7 +45,7 @@ jobs: ...@@ -45,7 +45,7 @@ jobs:
runs-on: 2-gpu-runner runs-on: 2-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install rust dependencies - name: Install rust dependencies
run: | run: |
......
...@@ -20,7 +20,7 @@ jobs: ...@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Check clang-format - name: Check clang-format
uses: DoozyX/clang-format-lint-action@v0.18.1 uses: DoozyX/clang-format-lint-action@v0.18.1
......
...@@ -39,7 +39,7 @@ jobs: ...@@ -39,7 +39,7 @@ jobs:
run_tests: ${{ steps.set_run_tests.outputs.run_tests }} run_tests: ${{ steps.set_run_tests.outputs.run_tests }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Filter changes - name: Filter changes
id: filter id: filter
uses: dorny/paths-filter@v2 uses: dorny/paths-filter@v2
...@@ -72,7 +72,7 @@ jobs: ...@@ -72,7 +72,7 @@ jobs:
runs-on: 1-gpu-runner runs-on: 1-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
env: env:
...@@ -98,7 +98,7 @@ jobs: ...@@ -98,7 +98,7 @@ jobs:
part: [0, 1, 2, 3, 4, 5, 6] part: [0, 1, 2, 3, 4, 5, 6]
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
env: env:
...@@ -120,7 +120,7 @@ jobs: ...@@ -120,7 +120,7 @@ jobs:
runs-on: 2-gpu-runner runs-on: 2-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
env: env:
...@@ -172,7 +172,7 @@ jobs: ...@@ -172,7 +172,7 @@ jobs:
runs-on: 1-gpu-runner runs-on: 1-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
env: env:
...@@ -218,7 +218,7 @@ jobs: ...@@ -218,7 +218,7 @@ jobs:
runs-on: 1-gpu-runner runs-on: 1-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
env: env:
...@@ -252,7 +252,7 @@ jobs: ...@@ -252,7 +252,7 @@ jobs:
runs-on: 2-gpu-runner runs-on: 2-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
env: env:
...@@ -294,7 +294,7 @@ jobs: ...@@ -294,7 +294,7 @@ jobs:
runs-on: 1-gpu-runner runs-on: 1-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
env: env:
...@@ -319,7 +319,7 @@ jobs: ...@@ -319,7 +319,7 @@ jobs:
runs-on: 2-gpu-runner runs-on: 2-gpu-runner
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install dependencies - name: Install dependencies
env: env:
......
...@@ -23,7 +23,7 @@ jobs: ...@@ -23,7 +23,7 @@ jobs:
build_type: ['all', 'srt'] build_type: ['all', 'srt']
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: "Set Date" - name: "Set Date"
run: | run: |
......
...@@ -18,7 +18,7 @@ jobs: ...@@ -18,7 +18,7 @@ jobs:
build_type: ['all', 'srt'] build_type: ['all', 'srt']
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Free disk space - name: Free disk space
uses: jlumbroso/free-disk-space@main uses: jlumbroso/free-disk-space@main
......
...@@ -10,7 +10,7 @@ jobs: ...@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Free disk space - name: Free disk space
uses: jlumbroso/free-disk-space@main uses: jlumbroso/free-disk-space@main
......
...@@ -21,7 +21,7 @@ jobs: ...@@ -21,7 +21,7 @@ jobs:
run: rm -rf /opt/hostedtoolcache run: rm -rf /opt/hostedtoolcache
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
......
...@@ -20,7 +20,7 @@ jobs: ...@@ -20,7 +20,7 @@ jobs:
if: github.repository == 'sgl-project/sglang' if: github.repository == 'sgl-project/sglang'
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v4
......
...@@ -17,7 +17,7 @@ jobs: ...@@ -17,7 +17,7 @@ jobs:
environment: 'prod' environment: 'prod'
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Get version - name: Get version
id: get_version id: get_version
......
...@@ -19,7 +19,7 @@ jobs: ...@@ -19,7 +19,7 @@ jobs:
python-version: '3.9' python-version: '3.9'
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Upload to pypi - name: Upload to pypi
run: | run: |
......
...@@ -43,7 +43,7 @@ runtime_common = [ ...@@ -43,7 +43,7 @@ runtime_common = [
srt = [ srt = [
"sglang[runtime_common]", "sglang[runtime_common]",
"sgl-kernel==0.0.5", "sgl-kernel==0.0.5.post1",
"flashinfer_python==0.2.3", "flashinfer_python==0.2.3",
"torch==2.5.1", "torch==2.5.1",
"vllm>=0.6.4.post1,<=0.7.2", "vllm>=0.6.4.post1,<=0.7.2",
......
...@@ -17,7 +17,9 @@ from typing import Callable, Optional ...@@ -17,7 +17,9 @@ from typing import Callable, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from sglang.srt.utils import get_compiler_backend from sglang.srt.utils import get_compiler_backend, is_cuda
_is_cuda = is_cuda()
def fused_topk_native( def fused_topk_native(
...@@ -47,6 +49,9 @@ def fused_topk( ...@@ -47,6 +49,9 @@ def fused_topk(
topk: int, topk: int,
renormalize: bool, renormalize: bool,
): ):
if _is_cuda:
from sgl_kernel import topk_softmax
else:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
...@@ -61,6 +66,14 @@ def fused_topk( ...@@ -61,6 +66,14 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device M, topk, dtype=torch.int32, device=hidden_states.device
) )
if _is_cuda:
topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
else:
ops.topk_softmax( ops.topk_softmax(
topk_weights, topk_weights,
topk_ids, topk_ids,
......
...@@ -26,4 +26,4 @@ pip install transformers==4.45.2 sentence_transformers accelerate==1.4.0 peft pa ...@@ -26,4 +26,4 @@ pip install transformers==4.45.2 sentence_transformers accelerate==1.4.0 peft pa
pip install cuda-python nvidia-cuda-nvrtc-cu12 pip install cuda-python nvidia-cuda-nvrtc-cu12
# reinstall sgl-kernel # reinstall sgl-kernel
pip install sgl-kernel==0.0.5 --force-reinstall --no-deps pip install sgl-kernel==0.0.5.post1 --force-reinstall --no-deps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment