"tools/python/vscode:/vscode.git/clone" did not exist on "2a26521b0be75c5a68f3d4c28898a119347cc597"
Unverified Commit 621dfb88, authored by Baizhou Zhang and committed by GitHub

Import flash_mla from sgl-kernel (#12135)

parent fb52d35f
@@ -747,7 +747,7 @@ jobs:
       - name: Install dependencies
         run: |
-          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} RUN_DEEPSEEK_V32=1 bash scripts/ci/ci_install_dependency.sh
+          CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/ci_install_dependency.sh
       - name: Run test
         timeout-minutes: 20
...
@@ -8,8 +8,6 @@ ARG GRACE_BLACKWELL=0
 ARG GRACE_BLACKWELL_DEEPEP_BRANCH=gb200_blog_part_2
 ARG DEEPEP_COMMIT=9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee
-ARG FLASHMLA_COMMIT=1408756a88e52a25196b759eaf8db89d2b51b5a1
 ARG TRITON_LANG_COMMIT=4caa0328bf8df64896dd5f6fb9df41b0eb2e750a
 ARG SGL_KERNEL_VERSION=0.3.16.post4
@@ -179,17 +177,6 @@ RUN cd /sgl-workspace/DeepEP && \
     fi && \
     NVSHMEM_DIR=${NVSHMEM_DIR} TORCH_CUDA_ARCH_LIST="${CHOSEN_TORCH_CUDA_ARCH_LIST}" pip install --no-build-isolation .
-# Install flashmla
-RUN if [ "$CUDA_VERSION" != "13.0.1" ]; then \
-    git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla && \
-    cd flash-mla && \
-    git checkout ${FLASHMLA_COMMIT} && \
-    git submodule update --init --recursive && \
-    if [ "$CUDA_VERSION" = "12.6.1" ]; then \
-        export FLASH_MLA_DISABLE_SM100=1; \
-    fi && \
-    pip install --no-build-isolation -v . ; \
-    fi
 # In order to use flashinfer_cutedsl without IMA for WideEP configs we must install
 # latest flashinfer_cutedsl. Once 0.4.3 is officially released, remove this
...
@@ -27,13 +27,7 @@ docker pull lmsysorg/sglang:dsv32-a3
 git clone https://github.com/sgl-project/sglang
 cd sglang
 pip3 install pip --upgrade
-pip3 install -e "python[all]"
+pip3 install -e "python"
-# Install flash_mla
-git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
-cd flash-mla
-git submodule update --init --recursive
-pip install -v .
 ```
 ## Launch DeepSeek V3.2 with SGLang
...
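After this docs change, FlashMLA is no longer built from source; the kernels are expected to be provided by the sgl-kernel package already present in the environment (e.g. the Docker image above). A quick import check written against the `sgl_kernel.flash_mla` path used elsewhere in this diff (the check itself is an illustration, not part of the docs):

```python
# Sanity check (illustrative, not from the docs): after installing sglang,
# the FlashMLA entry points should come from sgl-kernel rather than a
# separately built flash_mla package.
from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata

print(flash_mla_with_kvcache)
print(get_mla_metadata)
```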
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
 import torch
 import triton
-from flash_mla import flash_mla_with_kvcache, get_mla_metadata
+from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
...
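For context, a hedged sketch of how these two entry points are typically wired together for MLA decode. Argument order follows the upstream FlashMLA README; whether the `sgl_kernel.flash_mla` re-export keeps exactly that signature is an assumption, so treat this as illustrative usage rather than the backend's actual code path:

```python
# Hedged sketch of MLA decode with the re-exported kernels (argument order per
# the upstream FlashMLA README; the sgl-kernel wrapper is assumed compatible).
import torch
from sgl_kernel.flash_mla import flash_mla_with_kvcache, get_mla_metadata


def mla_decode(
    q: torch.Tensor,              # [batch, s_q, num_q_heads, head_dim]
    kv_cache: torch.Tensor,       # paged latent KV cache
    block_table: torch.Tensor,    # [batch, max_blocks] page indices
    cache_seqlens: torch.Tensor,  # int32 [batch] cached sequence lengths
    head_dim_v: int,
    num_kv_heads: int = 1,        # MLA decode uses a single latent KV head
):
    s_q, h_q = q.shape[1], q.shape[2]

    # Split/scheduling metadata is computed once per set of cache lengths
    # and reused across layers.
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // num_kv_heads, num_kv_heads
    )

    # Returns the attention output and the log-sum-exp values.
    out, lse = flash_mla_with_kvcache(
        q,
        kv_cache,
        block_table,
        cache_seqlens,
        head_dim_v,
        tile_scheduler_metadata,
        num_splits,
        causal=True,
    )
    return out, lse
```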
@@ -1098,7 +1098,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         page_table_1: torch.Tensor,
         sm_scale: float,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_sparse_fwd
+        from sgl_kernel.flash_mla import flash_mla_sparse_fwd
         o, _, _ = flash_mla_sparse_fwd(
             q=q_all,
@@ -1119,7 +1119,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         metadata: NSAMetadata,
         page_table_1,
     ) -> torch.Tensor:
-        from flash_mla import flash_mla_with_kvcache
+        from sgl_kernel.flash_mla import flash_mla_with_kvcache
         cache_seqlens = metadata.nsa_cache_seqlens_int32
@@ -1261,7 +1261,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
     def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
-        from flash_mla import get_mla_metadata
+        from sgl_kernel.flash_mla import get_mla_metadata
         flashmla_metadata, num_splits = get_mla_metadata(
             cache_seqlens=cache_seqlens,
...
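Note that this backend imports the kernels inside the methods rather than at module scope, which presumably keeps the module importable on sgl-kernel builds that lack FlashMLA. A minimal sketch of that deferred-import pattern (the helper name and error message are illustrative assumptions, not code from this change):

```python
# Minimal sketch of the deferred-import pattern used above (helper name and
# error text are illustrative assumptions, not part of this diff).
def _load_flash_mla_kernels():
    """Import the FlashMLA entry points at first use instead of module import
    time, so an environment whose sgl-kernel build lacks FlashMLA can still
    import the backend module and fall back to another attention backend."""
    try:
        from sgl_kernel.flash_mla import (
            flash_mla_sparse_fwd,
            flash_mla_with_kvcache,
            get_mla_metadata,
        )
    except ImportError as exc:  # assumed failure mode on unsupported builds
        raise RuntimeError(
            "sgl-kernel was built without FlashMLA support; "
            "rebuild sgl-kernel or choose a different attention backend"
        ) from exc
    return flash_mla_sparse_fwd, flash_mla_with_kvcache, get_mla_metadata
```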
@@ -3,7 +3,6 @@
 set -euxo pipefail
 IS_BLACKWELL=${IS_BLACKWELL:-0}
-RUN_DEEPSEEK_V32=${RUN_DEEPSEEK_V32:-0}
 CU_VERSION="cu129"
 if [ "$CU_VERSION" = "cu130" ]; then
@@ -113,22 +112,6 @@ if [ "$IS_BLACKWELL" != "1" ]; then
     $PIP_CMD install xformers --index-url https://download.pytorch.org/whl/${CU_VERSION} --no-deps $PIP_INSTALL_SUFFIX
 fi
-# Install dependencies for deepseek-v3.2
-if [ "$RUN_DEEPSEEK_V32" = "1" ]; then
-    # Install flashmla
-    FLASHMLA_COMMIT="1408756a88e52a25196b759eaf8db89d2b51b5a1"
-    FLASH_MLA_DISABLE_SM100="0"
-    if [ "$IS_BLACKWELL" != "1" ]; then
-        FLASH_MLA_DISABLE_SM100="1"
-    fi
-    git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
-    cd flash-mla
-    git checkout ${FLASHMLA_COMMIT}
-    git submodule update --init --recursive
-    FLASH_MLA_DISABLE_SM100=${FLASH_MLA_DISABLE_SM100} $PIP_CMD install -v . $PIP_INSTALL_SUFFIX --no-build-isolation
-    cd ..
-fi
 # Show current packages
 $PIP_CMD list
 python3 -c "import torch; print(torch.version.cuda)"
...
@@ -82,7 +82,7 @@ suites = {
     TestFile("test_ebnf_constrained.py", 108),
     TestFile("test_eval_fp8_accuracy.py", 303),
     TestFile("test_fa3.py", 376),
-    # TestFile("test_flashmla.py", 352),
+    TestFile("test_flashmla.py", 352),
     TestFile("rotary_embedding/test_mrope.py", 300),
     TestFile("test_function_call_parser.py", 10),
     TestFile("test_fused_moe.py", 30),
...
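With `test_flashmla.py` back in the suite, it should also be runnable on its own; the direct invocation below is an assumption about how these standalone test files are usually executed, not something stated in this change:

```python
# Hypothetical local run of the re-enabled test file from the repo root;
# assumes the file has the usual `unittest.main()` entry point.
import subprocess
import sys

subprocess.run([sys.executable, "test/srt/test_flashmla.py"], check=True)
```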
@@ -11,6 +11,7 @@ import torch
 from sglang.srt.utils import kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.send_one import BenchArgs, send_one_prompt
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST_MLA,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -18,7 +19,7 @@ from sglang.test.test_utils import (
     CustomTestCase,
     is_in_ci,
     popen_launch_server,
-    run_bench_one_batch,
+    write_github_step_summary,
 )
@@ -31,7 +32,6 @@ class TestFlashMLAAttnBackend(unittest.TestCase):
         if torch.cuda.is_available() and torch.version.cuda:
             other_args.extend(
                 [
-                    "--enable-torch-compile",
                     "--cuda-graph-max-bs",
                     "2",
                     "--attention-backend",
@@ -65,24 +65,6 @@ class TestFlashMLAAttnBackend(unittest.TestCase):
         self.assertGreater(metrics["accuracy"], 0.60)
-class TestFlashMLAAttnLatency(unittest.TestCase):
-    def test_latency(self):
-        _, output_throughput, _ = run_bench_one_batch(
-            DEFAULT_MODEL_NAME_FOR_TEST_MLA,
-            [
-                "--attention-backend",
-                "flashmla",
-                "--enable-torch-compile",
-                "--cuda-graph-max-bs",
-                "16",
-                "--trust-remote-code",
-            ],
-        )
-        if is_in_ci():
-            self.assertGreater(output_throughput, 100)
 class TestFlashMLAMTP(CustomTestCase):
     @classmethod
     def setUpClass(cls):
...