Unverified Commit 7ff740a6 authored by Liangsheng Yin, committed by GitHub

Remove dp balance metadata and minimum token balance. (#11170)

parent bfcd9b24
...@@ -812,7 +812,6 @@ def _launch_subprocesses( ...@@ -812,7 +812,6 @@ def _launch_subprocesses(
pp_rank, pp_rank,
None, None,
writer, writer,
None,
), ),
) )
......
...@@ -120,11 +120,8 @@ message GenerateRequest { ...@@ -120,11 +120,8 @@ message GenerateRequest {
// Data parallel routing // Data parallel routing
int32 data_parallel_rank = 16; int32 data_parallel_rank = 16;
// For load balancing
int32 dp_balance_id = 17;
// Whether client wants streaming response // Whether client wants streaming response
bool stream = 18; bool stream = 17;
} }
message TokenizedInput { message TokenizedInput {
......
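Side note on the proto change above: `stream` is renumbered from field 18 to 17 to close the gap left by the removed `dp_balance_id`. Renumbering a field is only safe when every client and server is regenerated from the same .proto in lockstep; when older binaries may still be deployed, the usual protobuf practice is to keep `stream = 18` and mark the vacated number with `reserved 17;` so it can never be reused with a different meaning on the wire.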
...@@ -82,7 +82,7 @@ class DisaggregatedParams(_message.Message): ...@@ -82,7 +82,7 @@ class DisaggregatedParams(_message.Message):
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ... def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message): class GenerateRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream") __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "stream")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int] REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int] TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int] MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
...@@ -99,7 +99,6 @@ class GenerateRequest(_message.Message): ...@@ -99,7 +99,6 @@ class GenerateRequest(_message.Message):
INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int] INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
LORA_ID_FIELD_NUMBER: _ClassVar[int] LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int] STREAM_FIELD_NUMBER: _ClassVar[int]
request_id: str request_id: str
tokenized: TokenizedInput tokenized: TokenizedInput
...@@ -117,9 +116,8 @@ class GenerateRequest(_message.Message): ...@@ -117,9 +116,8 @@ class GenerateRequest(_message.Message):
input_embeds: _containers.RepeatedScalarFieldContainer[float] input_embeds: _containers.RepeatedScalarFieldContainer[float]
lora_id: str lora_id: str
data_parallel_rank: int data_parallel_rank: int
dp_balance_id: int
stream: bool stream: bool
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ... def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., stream: bool = ...) -> None: ...
class TokenizedInput(_message.Message): class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids") __slots__ = ("original_text", "input_ids")
......
...@@ -17,14 +17,11 @@ import faulthandler ...@@ -17,14 +17,11 @@ import faulthandler
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import signal import signal
import struct
import sys
import threading import threading
import time import time
from collections import deque from collections import deque
from enum import Enum, auto from enum import Enum, auto
from multiprocessing import shared_memory from typing import List
from typing import Dict, List
import psutil import psutil
import setproctitle import setproctitle
...@@ -39,7 +36,6 @@ from sglang.srt.managers.io_struct import ( ...@@ -39,7 +36,6 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.utils import DPBalanceMeta
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import ( from sglang.srt.utils import (
...@@ -108,15 +104,9 @@ class DPBudget: ...@@ -108,15 +104,9 @@ class DPBudget:
class DataParallelController: class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers.""" """A controller that dispatches requests to multiple data parallel workers."""
def __init__( def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
self,
server_args: ServerArgs,
port_args: PortArgs,
dp_balance_meta: DPBalanceMeta,
) -> None:
# for dp balance # for dp balance
self.global_balance_id = 0 self.global_balance_id = 0
self.balance_meta = dp_balance_meta
# Parse args # Parse args
self.max_total_num_tokens = None self.max_total_num_tokens = None
...@@ -322,7 +312,6 @@ class DataParallelController: ...@@ -322,7 +312,6 @@ class DataParallelController:
pp_rank, pp_rank,
dp_rank, dp_rank,
writer, writer,
self.balance_meta,
), ),
) )
with memory_saver_adapter.configure_subprocess(): with memory_saver_adapter.configure_subprocess():
...@@ -370,31 +359,11 @@ class DataParallelController: ...@@ -370,31 +359,11 @@ class DataParallelController:
if self.maybe_external_dp_rank_routing(req): if self.maybe_external_dp_rank_routing(req):
return return
# This variable corresponds to the balance_id in TokenizedGenerateReqInput.
# We use it to control the number of onfly tokens (requests dispatched to workers but not yet received).
def get_next_global_balance_id() -> int:
    INT32_MAX = 2147483647
    current_id = self.global_balance_id
    self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
    return current_id

req.dp_balance_id = get_next_global_balance_id()
with self.balance_meta.mutex:
    # 1. local_tokens represents the tokens currently being processed on the worker,
    # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
    onfly_info = self.balance_meta.get_shared_onfly()
    local_tokens = self.balance_meta.get_shared_local_tokens()
    total_tokens = [
        local_token + sum(onfly_dict.values())
        for local_token, onfly_dict in zip(local_tokens, onfly_info)
    ]
    target_worker = total_tokens.index(min(total_tokens))
    onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
    # 2. write the new onfly info to the shm
    self.balance_meta.set_shared_onfly_info(onfly_info)
    # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
self.workers[target_worker].send_pyobj(req)
logger.warning(
    "The 'minimum_tokens' load balancing method is deprecated for now and will be reintroduced later. "
    "Falling back to 'round_robin_scheduler'."
)
self.round_robin_scheduler(req)
def event_loop(self): def event_loop(self):
while True: while True:
...@@ -416,12 +385,9 @@ def run_data_parallel_controller_process( ...@@ -416,12 +385,9 @@ def run_data_parallel_controller_process(
faulthandler.enable() faulthandler.enable()
configure_logger(server_args) configure_logger(server_args)
parent_process = psutil.Process().parent() parent_process = psutil.Process().parent()
balance_meta = DPBalanceMeta(server_args.dp_size)
try: try:
controller = DataParallelController( controller = DataParallelController(server_args, port_args)
server_args, port_args, dp_balance_meta=balance_meta
)
pipe_writer.send( pipe_writer.send(
{ {
"status": "ready", "status": "ready",
...@@ -440,6 +406,3 @@ def run_data_parallel_controller_process( ...@@ -440,6 +406,3 @@ def run_data_parallel_controller_process(
traceback = get_exception_traceback() traceback = get_exception_traceback()
logger.error(f"DataParallelController hit an exception: {traceback}") logger.error(f"DataParallelController hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT) parent_process.send_signal(signal.SIGQUIT)
finally:
# we must shut down the mp.Manager() owned by balance_meta
balance_meta.destructor()
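The removed branch above implemented the "minimum_tokens" policy: estimate each DP worker's load as the tokens it currently holds plus the input lengths of requests already dispatched to it but not yet received ("onfly"), then send the new request to the lightest worker. A minimal, self-contained sketch of that selection step, with illustrative names rather than SGLang's actual classes:

```python
from typing import Dict, List


def pick_min_token_worker(
    local_tokens: List[int], onfly: List[Dict[int, int]]
) -> int:
    """Return the index of the worker with the smallest estimated load.

    local_tokens[i]: tokens currently being processed by worker i.
    onfly[i]: {balance_id: num_input_tokens} for requests dispatched to
    worker i but not yet picked up by its scheduler.
    """
    totals = [
        local + sum(pending.values())
        for local, pending in zip(local_tokens, onfly)
    ]
    return totals.index(min(totals))


# Worker 0 holds more running tokens, but worker 1 has a large request
# in flight, so worker 0 is still the lighter choice overall.
assert pick_min_token_worker([120, 40], [{7: 10}, {8: 300}]) == 0
```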
...@@ -606,9 +606,6 @@ class TokenizedGenerateReqInput: ...@@ -606,9 +606,6 @@ class TokenizedGenerateReqInput:
# For data parallel rank routing # For data parallel rank routing
data_parallel_rank: Optional[int] = None data_parallel_rank: Optional[int] = None
# For dp balance
dp_balance_id: int = -1
# Priority for the request # Priority for the request
priority: Optional[int] = None priority: Optional[int] = None
...@@ -778,8 +775,6 @@ class TokenizedEmbeddingReqInput: ...@@ -778,8 +775,6 @@ class TokenizedEmbeddingReqInput:
sampling_params: SamplingParams sampling_params: SamplingParams
# For data parallel rank routing # For data parallel rank routing
data_parallel_rank: Optional[int] = None data_parallel_rank: Optional[int] = None
# For dp balance
dp_balance_id: int = -1
# Priority for the request # Priority for the request
priority: Optional[int] = None priority: Optional[int] = None
......
...@@ -145,7 +145,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import ( ...@@ -145,7 +145,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
from sglang.srt.managers.session_controller import Session from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.radix_cache import RadixCache
...@@ -271,7 +271,6 @@ class Scheduler( ...@@ -271,7 +271,6 @@ class Scheduler(
moe_ep_rank: int, moe_ep_rank: int,
pp_rank: int, pp_rank: int,
dp_rank: Optional[int], dp_rank: Optional[int],
dp_balance_meta: Optional[DPBalanceMeta] = None,
): ):
# Parse args # Parse args
self.server_args = server_args self.server_args = server_args
...@@ -600,7 +599,6 @@ class Scheduler( ...@@ -600,7 +599,6 @@ class Scheduler(
# Init metrics stats # Init metrics stats
self.init_metrics(tp_rank, pp_rank, dp_rank) self.init_metrics(tp_rank, pp_rank, dp_rank)
self.init_dp_balance(dp_balance_meta)
if self.enable_kv_cache_events: if self.enable_kv_cache_events:
self.init_kv_events(server_args.kv_events_config) self.init_kv_events(server_args.kv_events_config)
...@@ -1270,8 +1268,6 @@ class Scheduler( ...@@ -1270,8 +1268,6 @@ class Scheduler(
self, self,
recv_req: TokenizedGenerateReqInput, recv_req: TokenizedGenerateReqInput,
): ):
self.maybe_update_dp_balance_data(recv_req)
# Create a new request # Create a new request
if ( if (
recv_req.session_params is None recv_req.session_params is None
...@@ -1797,7 +1793,6 @@ class Scheduler( ...@@ -1797,7 +1793,6 @@ class Scheduler(
# Handle DP attention # Handle DP attention
if need_dp_attn_preparation: if need_dp_attn_preparation:
self.maybe_handle_dp_balance_data()
ret = self.prepare_mlp_sync_batch(ret) ret = self.prepare_mlp_sync_batch(ret)
return ret return ret
...@@ -2803,7 +2798,6 @@ def run_scheduler_process( ...@@ -2803,7 +2798,6 @@ def run_scheduler_process(
pp_rank: int, pp_rank: int,
dp_rank: Optional[int], dp_rank: Optional[int],
pipe_writer, pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None,
): ):
# Generate the logger prefix # Generate the logger prefix
prefix = "" prefix = ""
...@@ -2852,7 +2846,6 @@ def run_scheduler_process( ...@@ -2852,7 +2846,6 @@ def run_scheduler_process(
moe_ep_rank, moe_ep_rank,
pp_rank, pp_rank,
dp_rank, dp_rank,
dp_balance_meta=balance_meta,
) )
pipe_writer.send( pipe_writer.send(
{ {
......
...@@ -12,7 +12,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode ...@@ -12,7 +12,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_policy import PrefillAdder from sglang.srt.managers.schedule_policy import PrefillAdder
from sglang.srt.managers.scheduler import Req, ScheduleBatch from sglang.srt.managers.scheduler import Req, ScheduleBatch
from sglang.srt.managers.utils import DPBalanceMeta
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.utils import get_bool_env_var from sglang.srt.utils import get_bool_env_var
...@@ -64,16 +63,6 @@ class SchedulerMetricsMixin: ...@@ -64,16 +63,6 @@ class SchedulerMetricsMixin:
labels["dp_rank"] = dp_rank labels["dp_rank"] = dp_rank
self.metrics_collector = SchedulerMetricsCollector(labels=labels) self.metrics_collector = SchedulerMetricsCollector(labels=labels)
def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
self.balance_meta = dp_balance_meta
if (
self.server_args.enable_dp_attention
and self.server_args.load_balance_method == "minimum_tokens"
):
assert dp_balance_meta is not None
self.recv_dp_balance_id_this_term = []
def init_kv_events(self: Scheduler, kv_events_config: Optional[str]): def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
if self.enable_kv_cache_events: if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
...@@ -319,91 +308,6 @@ class SchedulerMetricsMixin: ...@@ -319,91 +308,6 @@ class SchedulerMetricsMixin:
batch = KVEventBatch(ts=time.time(), events=events) batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch) self.kv_event_publisher.publish(batch)
def maybe_update_dp_balance_data(
self: Scheduler, recv_req: TokenizedGenerateReqInput
):
if (
self.server_args.enable_dp_attention
and self.server_args.load_balance_method == "minimum_tokens"
):
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
def maybe_handle_dp_balance_data(self: Scheduler):
if (
self.server_args.load_balance_method == "minimum_tokens"
and self.forward_ct % 40 == 0
):
holding_tokens = self.get_load().num_tokens
new_recv_dp_balance_id_list, holding_token_list = (
self.gather_dp_balance_info(holding_tokens)
)
self.recv_dp_balance_id_this_term.clear()
if self.tp_rank == 0: # only first worker write info
self.write_shared_dp_balance_info(
new_recv_dp_balance_id_list, holding_token_list
)
def gather_dp_balance_info(
self: Scheduler, holding_tokens_list
) -> Union[None, List[List[int]]]:
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
recv_list = self.recv_dp_balance_id_this_term
assert len(recv_list) <= 511, (
"The number of requests received this round is too large. "
"Please increase gather_tensor_size and onfly_info_size."
)
# The maximum size of the tensor used for gathering data from all workers.
gather_tensor_size = 512
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
recv_tensor[0] = holding_tokens_list
recv_tensor[1] = len(recv_list) # The first element is the length of the list.
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
if self.tp_rank == 0:
gathered_list = [
torch.zeros(gather_tensor_size, dtype=torch.int32)
for _ in range(self.balance_meta.num_workers)
]
else:
gathered_list = None
torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
gathered_id_list_per_worker = None
if self.tp_rank == 0:
gathered_id_list_per_worker = []
holding_tokens_list = []
for tensor in gathered_list:
holding_tokens_list.append(tensor[0].item())
list_length = tensor[1].item()
gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
return gathered_id_list_per_worker, holding_tokens_list
def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
meta = self.balance_meta
with meta.mutex:
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
# 1. Check if the rid received by each worker this round is present in onfly.
# If it is, remove the corresponding onfly item.
worker_id = 0
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
for new_recv_rid in new_recv_rids:
assert (
new_recv_rid in on_fly_reqs
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
del on_fly_reqs[new_recv_rid]
worker_id += 1
# 2. Atomically write local_tokens and onfly into shm under the mutex
meta.set_shared_onfly_info(onfly_list)
meta.set_shared_local_tokens(local_tokens)
def calculate_utilization(self): def calculate_utilization(self):
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.stats.utilization = -1 self.stats.utilization = -1
......
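The deleted gather_dp_balance_info packed each worker's state into a fixed-size int32 tensor laid out as `| holding_tokens | len(ids) | ids... |`, because torch.distributed.gather requires every rank to contribute a tensor of identical shape (hence the 512-slot buffer and the assert on at most 511 ids). A small sketch of just that pack/unpack layout, with the distributed call omitted and hypothetical function names:

```python
import torch

GATHER_TENSOR_SIZE = 512  # fixed size so all ranks gather equal shapes


def pack(holding_tokens: int, balance_ids: list) -> torch.Tensor:
    # Layout: [0] holding_tokens, [1] id count, [2:2+n] the ids.
    assert len(balance_ids) <= GATHER_TENSOR_SIZE - 2
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(balance_ids)
    t[2 : 2 + len(balance_ids)] = torch.tensor(balance_ids, dtype=torch.int32)
    return t


def unpack(t: torch.Tensor):
    n = int(t[1].item())
    return int(t[0].item()), t[2 : 2 + n].tolist()


# Round trip; in the removed code, rank 0 ran the unpacking step over the
# tensors collected from every worker via torch.distributed.gather.
assert unpack(pack(1024, [3, 5, 8])) == (1024, [3, 5, 8])
```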
...@@ -96,46 +96,3 @@ def get_logprob_from_pp_outputs( ...@@ -96,46 +96,3 @@ def get_logprob_from_pp_outputs(
] ]
return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
class DPBalanceMeta:
"""
This class will be used in the scheduler and the dp controller
"""
def __init__(self, num_workers: int):
self.num_workers = num_workers
self._manager = mp.Manager()
self.mutex = self._manager.Lock()
init_local_tokens = [0] * self.num_workers
init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
self.shared_state = self._manager.Namespace()
self.shared_state.local_tokens = self._manager.list(init_local_tokens)
self.shared_state.onfly_info = self._manager.list(init_onfly_info)
def destructor(self):
# we must shut down the manager of this class manually
self._manager.shutdown()
def get_shared_onfly(self) -> List[Dict[int, int]]:
return [dict(d) for d in self.shared_state.onfly_info]
def set_shared_onfly_info(self, data: List[Dict[int, int]]):
self.shared_state.onfly_info = data
def get_shared_local_tokens(self) -> List[int]:
return list(self.shared_state.local_tokens)
def set_shared_local_tokens(self, data: List[int]):
self.shared_state.local_tokens = data
def __getstate__(self):
state = self.__dict__.copy()
del state["_manager"]
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._manager = None
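DPBalanceMeta's __getstate__/__setstate__ pair exists because a multiprocessing.Manager itself cannot be pickled, while the proxies it hands out can: child processes receive only the proxies, and the owning process must shut the manager down explicitly (the destructor() call removed above), or the manager's server process leaks. A minimal sketch of that pattern with hypothetical names:

```python
import multiprocessing as mp


class SharedCounter:
    """Manager-backed state shared between a parent and its workers."""

    def __init__(self) -> None:
        self._manager = mp.Manager()
        self.mutex = self._manager.Lock()
        self.state = self._manager.Namespace()
        self.state.value = 0

    def increment(self) -> None:
        with self.mutex:  # proxies are shared, so guard updates
            self.state.value += 1

    def shutdown(self) -> None:
        # Must be called manually in the owning process, like
        # DPBalanceMeta.destructor(), to reap the manager process.
        self._manager.shutdown()

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["_manager"]  # only the proxies travel to children
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._manager = None  # children never own the manager


if __name__ == "__main__":
    counter = SharedCounter()
    worker = mp.Process(target=counter.increment)
    worker.start()
    worker.join()
    print(counter.state.value)  # 1
    counter.shutdown()
```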
...@@ -120,11 +120,8 @@ message GenerateRequest { ...@@ -120,11 +120,8 @@ message GenerateRequest {
// Data parallel routing // Data parallel routing
int32 data_parallel_rank = 16; int32 data_parallel_rank = 16;
// For load balancing
int32 dp_balance_id = 17;
// Whether client wants streaming response // Whether client wants streaming response
bool stream = 18; bool stream = 17;
} }
message TokenizedInput { message TokenizedInput {
......
...@@ -124,47 +124,5 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase): ...@@ -124,47 +124,5 @@ class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
self.assertGreater(avg_spec_accept_length, 2.5) self.assertGreater(avg_spec_accept_length, 2.5)
class TestDPAttentionMinimumTokenLoadBalance(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--enable-dp-attention",
"--dp",
"2",
"--enable-torch-compile",
"--torch-compile-max-bs",
"2",
"--load-balance-method",
"minimum_tokens",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()