"src/runtime/vscode:/vscode.git/clone" did not exist on "115ac0b9a3dbd806cc52f2a428048b79502f2350"
Unverified Commit 13bf565d authored by Jinwu, committed by GitHub

[2/N] Support DeepSeek-R1 w4a8 low-latency DeepEP (#8464)


Co-authored-by: Hank Han <hanhan7630@outlook.com>
Co-authored-by: Shangchuan Huang <2510421000@qq.com>
parent e51046be
......@@ -11,12 +11,14 @@ from sgl_kernel import (
)
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_ll_get_cutlass_w4a8_moe_mm_data,
deepep_permute_triton_kernel,
deepep_post_reorder_triton_kernel,
deepep_run_moe_deep_preprocess,
post_reorder_triton_kernel_for_cutlass_moe,
pre_reorder_triton_kernel_for_cutlass_moe,
run_moe_ep_preproess,
silu_and_mul_masked_post_per_tensor_quant_fwd,
)
......@@ -396,3 +398,139 @@ def cutlass_w4a8_moe_deepep_normal(
)
return output
def cutlass_w4a8_moe_deepep_ll(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_ids_: torch.Tensor,
masked_m: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
a_strides2: torch.Tensor,
b_strides2: torch.Tensor,
c_strides2: torch.Tensor,
s_strides13: torch.Tensor,
s_strides2: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a w4a8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
- w1_q (torch.Tensor): The first set of int4-quantized expert weights.
Shape: [num_experts, N * 2, K // 2]
(the weights are passed transposed and int4-packed)
- w2_q (torch.Tensor): The second set of int4-quantized expert weights.
Shape: [num_experts, K, N // 2]
(the weights are passed transposed and int4-packed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts, K // 512, N * 8]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts, N // 512, K * 4]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
- b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
- s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [1, K]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [1, N]
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
Returns:
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
"""
assert w1_q.dtype == torch.int8
assert w2_q.dtype == torch.int8
assert a.shape[2] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
num_experts = w1_q.size(0)
m = a.size(1)
k = w1_q.size(2) * 2 # w1_q is transposed and packed
n = w2_q.size(2) * 2 # w2_q is transposed and packed
topk = topk_ids_.size(1)
device = a.device
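# Fill the per-expert (N, M, K) problem sizes of both grouped gemms from the masked token counts.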
problem_sizes1, problem_sizes2 = deepep_ll_get_cutlass_w4a8_moe_mm_data(
masked_m,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
)
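# Quantize the dispatched activations to fp8 with the per-tensor input scale a1_scale.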
gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn, device=device)
sgl_per_tensor_quant_fp8(a, gateup_input, a1_scale.float(), True)
c1 = torch.empty((num_experts, m, n * 2), device=device, dtype=torch.bfloat16)
c2 = torch.empty((num_experts, m, k), device=device, dtype=torch.bfloat16)
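# First grouped gemm: fused gate/up projection, writing bf16 results into c1.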
cutlass_w4a8_moe_mm(
c1,
gateup_input,
w1_q,
a1_scale.float(),
w1_scale,
expert_offsets[:-1],
problem_sizes1,
a_strides1,
b_strides1,
c_strides1,
s_strides13,
128,
topk,
)
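# Fused SiLU(gate) * up with per-tensor fp8 quantization into intermediate_q.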
intermediate_q = torch.empty(
(num_experts, m, n), device=a.device, dtype=torch.float8_e4m3fn
)
silu_and_mul_masked_post_per_tensor_quant_fwd(
c1, intermediate_q, masked_m, a2_scale
)
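# Second grouped gemm: down projection, writing bf16 results into c2.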
cutlass_w4a8_moe_mm(
c2,
intermediate_q,
w2_q,
a2_scale.float(),
w2_scale,
expert_offsets[:-1],
problem_sizes2,
a_strides2,
b_strides2,
c_strides2,
s_strides2,
128,
topk,
)
return c2
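For orientation, here is a minimal shape sketch of what this entry point expects. The sizes E, M, K, N are hypothetical; the sketch only reproduces the relationships asserted above and does not launch the CUTLASS kernels.

import torch

# Hypothetical sizes for illustration: E local experts, M padded tokens per expert,
# K hidden size, N intermediate size.
E, M, K, N = 8, 128, 7168, 2048
a = torch.empty(E, M, K, dtype=torch.bfloat16)          # dispatched activations
w1_q = torch.empty(E, 2 * N, K // 2, dtype=torch.int8)  # int4-packed gate/up weights
w2_q = torch.empty(E, K, N // 2, dtype=torch.int8)      # int4-packed down weights

assert a.shape[2] // 2 == w1_q.shape[2]      # hidden size matches w1
assert w1_q.shape[2] * 2 == w2_q.shape[1]    # w1 hidden dim feeds w2
assert w1_q.shape[0] == w2_q.shape[0] == E   # one weight set per local expert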
......@@ -1014,3 +1014,197 @@ def zero_experts_compute_triton(
)
return output
@triton.jit
def compute_problem_sizes_w4a8_kernel(
masked_m_ptr,
problem_sizes1_ptr,
problem_sizes2_ptr,
n,
k,
num_experts,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = pid < num_experts
final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
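# Each expert writes one (N, M, K) triple at offset 3 * expert_id into each problem_sizes buffer.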
ps1_idx_0 = pid * 3
ps1_idx_1 = ps1_idx_0 + 1
ps1_idx_2 = ps1_idx_0 + 2
ps2_idx_0 = pid * 3
ps2_idx_1 = ps2_idx_0 + 1
ps2_idx_2 = ps2_idx_0 + 2
ps1_mask_0 = ps1_idx_0 < num_experts * 3
ps1_mask_1 = ps1_idx_1 < num_experts * 3
ps1_mask_2 = ps1_idx_2 < num_experts * 3
ps2_mask_0 = ps2_idx_0 < num_experts * 3
ps2_mask_1 = ps2_idx_1 < num_experts * 3
ps2_mask_2 = ps2_idx_2 < num_experts * 3
tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
def compute_problem_sizes_w4a8(
masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
):
BLOCK_SIZE = 256
grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
compute_problem_sizes_w4a8_kernel[grid](
masked_m,
problem_sizes1,
problem_sizes2,
n,
k,
num_experts,
BLOCK_SIZE=BLOCK_SIZE,
)
return problem_sizes1, problem_sizes2
def deepep_ll_get_cutlass_w4a8_moe_mm_data(
masked_m,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
):
problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
)
return (
problem_sizes1.to(torch.int32),
problem_sizes2.to(torch.int32),
)
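As a sanity reference (not part of the kernel code; a minimal sketch assuming [num_experts, 3] int32 buffers), the Triton kernel above is equivalent to the following pure-PyTorch computation:

import torch

def reference_problem_sizes_w4a8(masked_m: torch.Tensor, n: int, k: int):
    # Per-expert (N, M, K) triples for the two grouped gemms, mirroring the Triton kernel.
    num_experts = masked_m.shape[0]
    ps1 = torch.empty(num_experts, 3, dtype=torch.int32)
    ps2 = torch.empty(num_experts, 3, dtype=torch.int32)
    ps1[:, 0] = 2 * n        # first gemm N: concatenated gate/up output
    ps1[:, 1] = masked_m     # first gemm M: valid tokens for this expert
    ps1[:, 2] = k            # first gemm K: hidden size
    ps2[:, 0] = k            # second gemm N: hidden size
    ps2[:, 1] = masked_m     # second gemm M: valid tokens
    ps2[:, 2] = n            # second gemm K: intermediate size
    return ps1, ps2

# Example: four local experts with different valid-token counts.
ps1, ps2 = reference_problem_sizes_w4a8(torch.tensor([5, 0, 17, 3]), n=2048, k=7168)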
@triton.jit
def _silu_and_mul_post_per_tensor_quant_kernel(
input_ptr,
stride_input_expert,
stride_input_token,
stride_input_dim,
output_ptr,
stride_output_expert,
stride_output_token,
stride_output_dim,
scale_ptr,
masked_m_ptr,
inner_dim,
fp8_max,
fp8_min,
BLOCK_N: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
"""
Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
Shape:
input: [E, T_padded, 2*D]; gate = input[:, :, :D], up = input[:, :, D:]
output: [E, T_padded, D], dtype=float8_e4m3fn
"""
expert_id = tl.program_id(2)
block_id_token = tl.program_id(1)
block_id_dim = tl.program_id(0)
num_token_blocks = tl.num_programs(1)
token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
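# Precompute the reciprocal of the per-tensor quantization scale so the loop multiplies instead of divides.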
scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
stride_input_expert = tl.cast(stride_input_expert, tl.int32)
stride_output_expert = tl.cast(stride_output_expert, tl.int32)
stride_input_token = tl.cast(stride_input_token, tl.int32)
stride_output_token = tl.cast(stride_output_token, tl.int32)
offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
mask_d = offset_d < inner_dim
# base pointers for current expert and dim block
input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
for token_idx in tl.range(
block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
):
gate_ptr = input_base_offs + token_idx * stride_input_token
up_ptr = gate_ptr + inner_dim
gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
# SiLU: x * sigmoid(x)
gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)
gate_up = up * gate
scaled = gate_up * scale
output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
out_ptr = output_base_offs + token_idx * stride_output_token
tl.store(out_ptr, output_q, mask=mask_d)
def silu_and_mul_masked_post_per_tensor_quant_fwd(
input: torch.Tensor,
output: torch.Tensor,
masked_m: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
"""
Fused SiLU + Mul + Per-Tensor Quantization to FP8.
Args:
input: [expert_num, token_num_padded, 2 * inner_dim]
output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
masked_m: [expert_num], actual token count for each expert
scale: [1] or [expert_num]; the kernel reads only element 0, so quantization is per-tensor
Returns:
output tensor
"""
assert input.is_contiguous()
assert output.is_contiguous()
assert output.dtype == torch.float8_e4m3fn
assert input.ndim == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0
assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
expert_num = input.shape[0]
# half of the concatenated gate/up last dim (e.g. 3584)
inner_dim = input.shape[-1] // 2
BLOCK_N = 256
BLOCK_M = 64 if expert_num < 4 else 32
NUM_STAGES = 3
hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
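# Grid: (hidden-dim blocks, token blocks, experts); each program strides over its expert's valid tokens.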
grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max
_silu_and_mul_post_per_tensor_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
scale,
masked_m,
inner_dim,
fp8_max,
fp8_min,
BLOCK_N=BLOCK_N,
NUM_STAGE=NUM_STAGES,
)
return output
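For clarity, a small pure-PyTorch reference of what the fused kernel computes (a sketch only: it ignores masked_m and processes padded rows too, and it uses fp32 SiLU rather than the kernel's bf16 cast):

import torch
import torch.nn.functional as F

def reference_silu_mul_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x: [E, T_padded, 2 * D]; scale: per-tensor quantization scale.
    gate, up = x.float().chunk(2, dim=-1)
    act = F.silu(gate) * up
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (act / scale.float()).clamp(-finfo.max, finfo.max).to(torch.float8_e4m3fn)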
......@@ -100,6 +100,7 @@ class DeepEPMoE(FusedMoE):
self.use_fp8_w8a8 = False
self.use_block_quant = False
else:
self.use_w4afp8 = False
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.use_w4afp8 = False
......@@ -199,6 +200,8 @@ class DeepEPMoE(FusedMoE):
return self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
elif self.use_w4afp8:
return self.forward_cutlass_w4afp8_masked(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
assert down_gemm_overlap_args is None
return self.forward_deepgemm_masked(dispatch_output)
......@@ -514,6 +517,20 @@ class DeepEPMoE(FusedMoE):
return down_output
def forward_cutlass_w4afp8_masked(
self,
dispatch_output: DeepEPLLOutput,
):
assert self.moe_runner_config.activation == "silu"
assert isinstance(self.quant_method, W4AFp8MoEMethod)
assert get_bool_env_var(
"SGLANG_DEEPEP_BF16_DISPATCH"
), "W4AFP8 does not support FP8 dispatch; please set SGLANG_DEEPEP_BF16_DISPATCH=1."
return self.quant_method.apply_deepep_ll(
layer=self,
dispatch_output=dispatch_output,
)
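For context, a minimal sketch (hypothetical sizes, not repo code) of the masked layout this path consumes: the DeepEP low-latency dispatcher delivers a dense per-expert buffer plus a per-expert count of valid rows.

import torch

num_local_experts, max_dispatch_tokens, hidden = 8, 128, 7168  # hypothetical sizes
hidden_states = torch.zeros(num_local_experts, max_dispatch_tokens, hidden, dtype=torch.bfloat16)
masked_m = torch.tensor([3, 0, 12, 7, 1, 0, 5, 2], dtype=torch.int32)  # valid rows per expert
# Rows at index >= masked_m[e] in hidden_states[e] are padding and are skipped by the masked kernels.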
def forward_npu(
self,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
......
......@@ -23,6 +23,7 @@ if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
DeepEPLLOutput,
DeepEPNormalOutput,
StandardDispatchOutput,
)
......@@ -328,6 +329,41 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
output *= self.moe_runner_config.routed_scaling_factor
return StandardCombineInput(hidden_states=output)
def apply_deepep_ll(
self,
layer: DeepEPMoE,
dispatch_output: DeepEPLLOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe_deepep_ll
hidden_states, _, topk_ids, _, masked_m, _ = dispatch_output
output = cutlass_w4a8_moe_deepep_ll(
hidden_states,
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale_inv,
layer.w2_weight_scale_inv,
topk_ids,
masked_m,
layer.quant_method.a_strides1,
layer.quant_method.b_strides1,
layer.quant_method.c_strides1,
layer.quant_method.a_strides2,
layer.quant_method.b_strides2,
layer.quant_method.c_strides2,
layer.quant_method.s_strides13,
layer.quant_method.s_strides2,
layer.quant_method.expert_offsets,
layer.quant_method.problem_sizes1,
layer.quant_method.problem_sizes2,
layer.w13_input_scale,
layer.w2_input_scale,
)
return output
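To summarize the wiring of the call above (comments only, mirroring the argument list):

# weights / weight scales  -> layer.w13_weight, layer.w2_weight,
#                             layer.w13_weight_scale_inv, layer.w2_weight_scale_inv
# routing / masking        -> topk_ids and masked_m from the low-latency dispatch output
# grouped-gemm metadata    -> stride, expert_offsets, and problem_sizes buffers preallocated
#                             on the quant method
# activation input scales  -> layer.w13_input_scale, layer.w2_input_scale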
def apply_deepep_normal(
self,
layer: DeepEPMoE,
......
......@@ -34,6 +34,40 @@ __global__ void int4_fp8_get_group_gemm_starts(
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * k / 128 : expert_id);
}
template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
__global__ void int4_fp8_get_group_gemm_starts_3d(
ElementA** a_offsets,
ElementB** b_offsets,
ElementC** out_offsets,
ElementAccumulator** a_scales_offsets,
cutlass::bfloat16_t** b_scales_offsets,
ElementA* a_base_as_int,
ElementB* b_base_as_int,
ElementC* out_base_as_int,
ElementAccumulator* a_scales_base_as_int,
cutlass::bfloat16_t* b_scales_base_as_int,
int64_t n,
int64_t m,
int64_t k,
bool per_act_token,
bool per_out_ch,
int num_experts) {
int expert_id = blockIdx.x * blockDim.x + threadIdx.x;
if (expert_id >= num_experts) return;
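// Per-expert element offsets into the dense 3D layouts: activations advance by m * k fp8
// elements, int4-packed weights by k * n / 2 bytes (two values per byte), outputs by m * n.
// The activation scale is a single per-tensor scalar (offset 0); weight scales advance by
// n * 4 * k / 512 per expert when per-out-channel scales are used, otherwise by one per expert.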
int64_t a_offset = expert_id * m * k;
int64_t b_offset = expert_id * k * n / 2;
int64_t out_offset = expert_id * m * n;
int64_t a_scales_offset = 0;
int64_t b_scales_offset = per_out_ch ? expert_id * n * 4 * k / 512 : expert_id;
a_offsets[expert_id] = a_base_as_int + a_offset;
b_offsets[expert_id] = b_base_as_int + b_offset;
out_offsets[expert_id] = out_base_as_int + out_offset;
a_scales_offsets[expert_id] = a_scales_base_as_int + a_scales_offset;
b_scales_offsets[expert_id] = b_scales_base_as_int + b_scales_offset;
}
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
......@@ -55,6 +89,28 @@ __global__ void int4_fp8_get_group_gemm_starts(
per_out_ch); \
}
#define __CALL_W4A8_GET_STARTS_KERNEL_3D(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
int4_fp8_get_group_gemm_starts_3d<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::int8_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::int8_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()), \
out_tensors.size(2), \
a_tensors.size(1), \
a_tensors.size(2), \
per_act_token, \
per_out_ch, \
num_experts); \
}
namespace {
void run_int4_fp8_get_group_gemm_starts(
......@@ -80,12 +136,22 @@ void run_int4_fp8_get_group_gemm_starts(
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
if (false) {
}
__CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
if (a_tensors.dim() == 3) {
if (false) {
}
__CALL_W4A8_GET_STARTS_KERNEL_3D(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_W4A8_GET_STARTS_KERNEL_3D(torch::kFloat16, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
} else {
if (false) {
}
__CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
}
......
......@@ -174,7 +174,7 @@ void cutlass_w4a8_group_gemm_caller(
bool per_out_ch = b_scales.numel() != num_experts;
// Check inputs
TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
TORCH_CHECK(a_tensors.dim() == 2 or a_tensors.dim() == 3, "A tensor must be 2D/3D");
TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
......@@ -186,7 +186,9 @@ void cutlass_w4a8_group_gemm_caller(
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
TORCH_CHECK(
b_tensors.size(2) * 2 == a_tensors.size(1) or b_tensors.size(2) * 2 == a_tensors.size(2),
"B tensor K/2 dimension must match A tensor K dimension");
// Check tensor types
TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
......
import os
import unittest
from types import SimpleNamespace
......@@ -173,5 +174,73 @@ class TestDeepseekV3W4Afp8DeepepNormal(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.92)
class TestDeepseekV3W4Afp8DeepepAutoMtp(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = try_cached_model(DEFAULT_DEEPSEEK_W4AFP8_MODEL_FOR_TEST)
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--tp",
"8",
"--trust-remote-code",
"--ep-size",
"8",
"--cuda-graph-bs",
"256",
"--disable-radix-cache",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"auto",
"--dp",
"8",
"--enable-dp-attention",
"--moe-runner-backend",
"cutlass",
"--speculative-algorithm",
"EAGLE",
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"4",
]
if not is_in_amd_ci():
other_args += ["--mem-frac", "0.7"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
env={
**os.environ,
"SGLANG_DEEPEP_BF16_DISPATCH": "1",
"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256",
},
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(
self,
):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Eval accuracy of GSM8K: {metrics=}")
self.assertGreater(metrics["accuracy"], 0.92)
if __name__ == "__main__":
unittest.main()
......@@ -172,7 +172,7 @@ suites = {
TestFile("test_disaggregation_hybrid_attention.py", 200),
],
"per-commit-8-gpu-h20": [
TestFile("quant/test_w4a8_deepseek_v3.py", 371),
TestFile("quant/test_w4a8_deepseek_v3.py", 520),
TestFile("test_disaggregation_different_tp.py", 600),
TestFile("test_disaggregation_pp.py", 140),
],
......