Unverified Commit 5c0b38f3 authored by HAI, committed by GitHub

aiter attention-backend (default enabled on AMD/ROCm) (#6381)
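On ROCm builds the backend is now picked automatically (see the model_runner change below); it can also be forced explicitly. A minimal launch sketch, assuming the standard sglang entry point and with the model path left as a placeholder:

    python3 -m sglang.launch_server --model-path <your-model> --attention-backend aiter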

parent 30ca18f4
@@ -44,7 +44,7 @@ jobs:
run: bash scripts/amd_ci_install_dependency.sh
- name: Evaluate Accuracy
timeout-minutes: 20
timeout-minutes: 30
run: |
bash scripts/amd_ci_exec.sh python3 test_eval_accuracy_large.py
bash scripts/amd_ci_exec.sh python3 test_eval_fp8_accuracy.py
@@ -70,7 +70,7 @@ jobs:
run: bash scripts/amd_ci_install_dependency.sh
- name: Evaluate accuracy (TP=2)
timeout-minutes: 20
timeout-minutes: 30
run: |
bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py
@@ -94,7 +94,7 @@ jobs:
run: bash scripts/amd_ci_install_dependency.sh
- name: MLA TEST
timeout-minutes: 20
timeout-minutes: 30
run: |
bash scripts/amd_ci_exec.sh python3 test_mla.py
@@ -118,28 +118,28 @@ jobs:
run: bash scripts/amd_ci_install_dependency.sh
- name: Benchmark single latency
timeout-minutes: 10
timeout-minutes: 20
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1_small
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1_default
- name: Benchmark online latency
timeout-minutes: 10
timeout-minutes: 15
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_default
- name: Benchmark offline throughput
timeout-minutes: 10
timeout-minutes: 15
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default
- name: Benchmark offline throughput (Non-streaming, small batch size)
timeout-minutes: 10
timeout-minutes: 15
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
- name: Benchmark online latency (EAGLE)
timeout-minutes: 10
timeout-minutes: 15
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_eagle
@@ -163,17 +163,17 @@ jobs:
run: bash scripts/amd_ci_install_dependency.sh
- name: Benchmark offline throughput (w/o RadixAttention)
timeout-minutes: 10
timeout-minutes: 15
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache
- name: Benchmark offline throughput (w/ Triton)
timeout-minutes: 10
timeout-minutes: 15
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with_triton_attention_backend
- name: Benchmark offline throughput (w/ FP8)
timeout-minutes: 10
timeout-minutes: 15
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8
@@ -197,27 +197,27 @@ jobs:
run: bash scripts/amd_ci_install_dependency.sh
- name: Benchmark dummy grok (TP=2)
timeout-minutes: 20
timeout-minutes: 30
run: |
bash scripts/amd_ci_exec.sh python3 models/test_dummy_grok_models.py
- name: Benchmark single latency (TP=2)
timeout-minutes: 20
timeout-minutes: 25
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1
- name: Benchmark single latency + torch.compile (TP=2)
timeout-minutes: 20
timeout-minutes: 25
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_torch_compile_tp2_bs1
- name: Benchmark offline throughput (TP=2)
timeout-minutes: 20
timeout-minutes: 25
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_default
- name: Benchmark offline throughput (w/o RadixAttention) (TP=2)
timeout-minutes: 20
timeout-minutes: 25
run: |
bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
@@ -241,7 +241,7 @@ jobs:
run: bash scripts/amd_ci_install_dependency.sh
- name: Run test
timeout-minutes: 30
timeout-minutes: 40
run: |
bash scripts/amd_ci_exec.sh python3 run_suite.py --suite per-commit-amd
@@ -265,7 +265,7 @@ jobs:
run: bash scripts/amd_ci_install_dependency.sh
- name: Run test
timeout-minutes: 30
timeout-minutes: 40
run: |
bash scripts/amd_ci_exec.sh python3 run_suite.py --suite per-commit-2-gpu-amd
@@ -289,7 +289,7 @@ jobs:
run: bash scripts/amd_ci_install_dependency.sh
- name: Run test
timeout-minutes: 30
timeout-minutes: 40
run: |
bash scripts/amd_ci_exec.sh python3 run_suite.py --suite per-commit-8-gpu-amd
......
@@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG AITER_COMMIT="v0.1.1"
ARG AITER_COMMIT="v0.1.2"
RUN git clone ${SGL_REPO} \
&& cd sglang \
......
from __future__ import annotations
"""
End-to-end attention solution with aiter kernels.
"""
import math
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
try:
from aiter import mha_batch_prefill_func, paged_attention_ragged
except ImportError:
print(
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)
class WrapperDispatch(Enum):
SLIDING_WINDOW = auto()
CROSS_ATTENTION = auto()
@dataclass
class ForwardMetadata:
kv_indptr: torch.Tensor
kv_indices: torch.Tensor
max_q_len: int
max_kv_len: int
global_workspace_buffer = None
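# KV tokens handled per partition by the aiter split-KV decode kernel; used below
# to bound max_num_partitions and to size the workspace buffer.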
_AITER_PARTITION_SIZE_ROCM = 256
class AiterAttnBackend(AttentionBackend):
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
super().__init__()
self.device = model_runner.device
self.is_multimodal = model_runner.model_config.is_multimodal
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.head_dim = model_runner.model_config.head_dim
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.req_to_token = model_runner.req_to_token_pool.req_to_token
# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
else:
self.kv_indptr = kv_indptr_buf
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
# Create prefill indices updater
if not skip_prefill:
self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
model_runner, self
)
# aiter kernel related initialization
self.max_num_partitions = (
self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
) // _AITER_PARTITION_SIZE_ROCM
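        # Workspace for the split-KV decode: an fp32 partial-output buffer of shape
        # (max_bs, num_head, max_num_partitions, head_dim) plus two fp32 reduction
        # buffers per (batch, head, partition), presumably the exp-sums and max-logits.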
        nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8
self.workspace_buffer = torch.empty(
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
            * nbytes_per_qo_elem
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
dtype=torch.uint8,
device=self.device,
)
self.scale = float(1.0 / (self.head_dim**0.5))
self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
self.device
)
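        # The KV cache is addressed with a page size of 1, so every sequence's last
        # page always holds exactly one token.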
self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
self.device
)
self.logits_soft_cap = 0.0
self.forward_metadata: ForwardMetadata = None
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
# update for aiter
            # create kv_indices and kv_indptr
bs = forward_batch.batch_size
kv_indptr = self.kv_indptr
spec_info = forward_batch.spec_info
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
)
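                # Fill kv_indices with the flattened KV-cache slot ids taken from
                # req_to_token, using kv_indptr as per-request offsets; e.g. seq_lens
                # [3, 2] gives kv_indptr [0, 3, 5] and five slot ids in kv_indices.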
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
elif forward_batch.forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
elif forward_batch.forward_mode.is_target_verify():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
else:
prefix_lens = forward_batch.extend_prefix_lens
if self.is_multimodal:
extend_no_prefix = False
else:
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
else:
self.cuda_graph_kv_indices = kv_indices_buf
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
if spec_info is None:
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
elif forward_mode.is_target_verify():
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
encoder_lens=encoder_lens,
spec_info=spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
kv_indptr = self.kv_indptr
kv_indices = self.cuda_graph_kv_indices
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
kv_indptr = kv_indptr[: bs + 1]
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens[:bs],
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
elif forward_mode.is_target_verify():
self.indices_updater_prefill.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
prefix_lens=None,
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
spec_info=spec_info,
)
else:
raise ValueError("Invalid forward mode")
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
self.logits_soft_cap = layer.logit_cap
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
bs0 = forward_batch.batch_size + 1
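        # Ragged batched prefill against the paged KV cache: qo_indptr/kv_indptr mark
        # per-request query and key/value boundaries, so no padding is required.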
o = mha_batch_prefill_func(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache,
v_cache,
self.qo_indptr[:bs0],
self.forward_metadata.kv_indptr[:bs0],
self.forward_metadata.kv_indices,
self.forward_metadata.max_q_len,
self.forward_metadata.max_kv_len,
causal=True,
logits_soft_cap=self.logits_soft_cap,
alibi_slopes=None,
return_lse=False,
return_attn_probs=False,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
):
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
)
self.logits_soft_cap = layer.logit_cap
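        # Split-KV paged decode: the K/V buffers are viewed with a page size of 1 and
        # the fp32 workspace allocated in __init__ receives the per-partition partials.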
paged_attention_ragged(
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
self.workspace_buffer,
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-1, 1, layer.tp_v_head_num, layer.v_head_dim
),
self.scale,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.kv_last_page_lens,
1,
self.max_num_partitions,
None,
"auto",
"NHD",
self.logits_soft_cap,
self.k_scale,
self.v_scale,
None,
_AITER_PARTITION_SIZE_ROCM,
)
return o
class AiterIndicesUpdaterPrefill:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Parse Constants
self.num_qo_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.head_dim = model_runner.model_config.head_dim
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.sliding_window_size = model_runner.sliding_window_size
self.attn_backend = attn_backend
# Buffers and wrappers
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.update = self.update_single_wrapper
self.kv_indices = None
self.max_q_len = 0
self.max_kv_len = 0
def update(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
def update_single_wrapper(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
kv_start_idx = None
kv_indptr = self.kv_indptr
qo_indptr = self.qo_indptr
paged_kernel_lens = seq_lens
paged_kernel_lens_sum = seq_lens_sum
bs = len(req_pool_indices)
if spec_info is None:
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum + 256,
dtype=torch.int32,
device=req_pool_indices.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
kv_indices,
self.req_to_token.shape[1],
)
self.max_kv_len = torch.max(paged_kernel_lens).item()
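            # New (uncached) tokens per request = total length minus cached prefix;
            # they define qo_indptr, while kv_indptr spans the full prefix + extend range.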
extend_lens = seq_lens - prefix_lens
self.max_q_len = torch.max(extend_lens).item()
qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
self.req_to_token,
)
)
self.kv_indices = kv_indices
@@ -103,6 +103,8 @@ from sglang.srt.utils import (
set_cuda_arch,
)
_is_hip = is_hip()
# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -318,6 +320,8 @@ class ModelRunner:
and is_fa3_default_architecture(self.model_config.hf_config)
):
server_args.attention_backend = "fa3"
elif _is_hip:
server_args.attention_backend = "aiter"
else:
server_args.attention_backend = (
"flashinfer" if is_flashinfer_available() else "triton"
@@ -794,7 +798,7 @@ class ModelRunner:
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
if is_hip(): # Using natively supported format
if _is_hip: # Using natively supported format
self.kv_cache_dtype = torch.float8_e5m2fnuz
else:
self.kv_cache_dtype = torch.float8_e5m2
@@ -972,6 +976,10 @@ class ModelRunner:
)
self.attn_backend = FlashInferMLAAttnBackend(self)
elif self.server_args.attention_backend == "aiter":
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
self.attn_backend = AiterAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
......
@@ -957,6 +957,7 @@ class ServerArgs:
"--attention-backend",
type=str,
choices=[
"aiter",
"flashinfer",
"triton",
"torch_native",
......
@@ -5,7 +5,6 @@ set -euo pipefail
docker exec ci_sglang pip install --upgrade pip
docker exec ci_sglang pip uninstall sgl-kernel -y || true
docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"
docker exec ci_sglang pip install -e "python[dev_hip]"
docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git
docker exec -w /human-eval ci_sglang pip install -e .
......
@@ -9,7 +9,7 @@ else
fi
# Pull the image
IMAGE="lmsysorg/sglang:v0.4.6.post3-rocm630"
IMAGE="ghcr.io/saienduri/sglang-aiter-backend-v0.1.2:518"
echo "Pulling Docker image: $IMAGE"
docker pull "$IMAGE"
......
@@ -4,6 +4,7 @@ from sglang.test.test_utils import CustomTestCase, is_in_ci, run_bench_one_batch
class TestDummyGrok1(CustomTestCase):
def test_dummy_grok_1(self):
output_throughput = run_bench_one_batch(
None,
......
@@ -3,6 +3,8 @@ Usage:
python -m unittest test_eval_accuracy_large.TestEvalAccuracyLarge.test_mmlu
"""
import os
import time
import unittest
from types import SimpleNamespace
@@ -35,6 +37,11 @@ class TestEvalAccuracyLarge(CustomTestCase):
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def tearDown(self):
# Delay between tests to allow GPU memory cleanup
if os.getenv("SGLANG_AMD_CI") == "1":
time.sleep(180)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
......