Unverified Commit db0cc57e authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[PD] Support decode retract and update decode.py (#7196)

parent 349bb2c9
......@@ -31,7 +31,7 @@ import numpy as np
import torch
from torch.distributed import ProcessGroup
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
......@@ -45,9 +45,17 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
KVCache,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
......@@ -145,7 +153,11 @@ class DecodePreallocQueue:
gloo_group: ProcessGroup,
tp_rank: int,
tp_size: int,
dp_size: int,
gpu_id: int,
bootstrap_port: int,
max_total_num_tokens: int,
prefill_pp_size: int,
transfer_backend: TransferBackend,
):
self.req_to_token_pool = req_to_token_pool
......@@ -161,25 +173,35 @@ class DecodePreallocQueue:
self.gloo_group = gloo_group
self.tp_rank = tp_rank
self.tp_size = tp_size
self.dp_size = dp_size
self.gpu_id = gpu_id
self.bootstrap_port = bootstrap_port
self.max_total_num_tokens = max_total_num_tokens
self.prefill_pp_size = prefill_pp_size
self.num_reserved_decode_tokens = int(
os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
)
self.transfer_backend = transfer_backend
# Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = []
self.transfer_backend = transfer_backend
self.retracted_queue: List[Req] = []
self.prefill_pp_size = prefill_pp_size
self.kv_manager = self._init_kv_manager()
def _init_kv_manager(self) -> BaseKVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
kv_args = kv_args_class()
attn_tp_size = self.tp_size // self.dp_size
kv_args.engine_rank = self.tp_rank % (attn_tp_size)
kv_args.decode_tp_size = attn_tp_size
kv_args.prefill_pp_size = self.prefill_pp_size
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
if self.draft_token_to_kv_pool is not None:
# We should also transfer draft model kv cache. The indices are
# always shared with a target model.
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
)
......@@ -194,6 +216,7 @@ class DecodePreallocQueue:
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
self.metadata_buffers.get_buf_infos()
)
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
......@@ -205,27 +228,83 @@ class DecodePreallocQueue:
)
return kv_manager
def add(self, req: Req) -> None:
def add(self, req: Req, is_retracted: bool = False) -> None:
"""Add a request to the pending queue."""
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
# Fake transfer for warmup reqs
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
if self._check_if_req_exceed_kv_capacity(req):
return
if is_retracted:
self.retracted_queue.append(req)
else:
kv_receiver_class = get_kv_class(
self.transfer_backend, KVClassType.RECEIVER
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
kv_receiver_class = get_kv_class(
TransferBackend.FAKE, KVClassType.RECEIVER
)
else:
kv_receiver_class = get_kv_class(
self.transfer_backend, KVClassType.RECEIVER
)
kv_receiver = kv_receiver_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
)
kv_receiver = kv_receiver_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
data_parallel_rank=req.data_parallel_rank,
)
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
def extend(self, reqs: List[Req]) -> None:
self.queue.append(
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
)
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
if len(req.origin_input_ids) > self.max_total_num_tokens:
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
logger.error(message)
prepare_abort(req, message)
self.scheduler.stream_output([req], req.return_logprob)
return True
return False
def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
"""Add a request to the pending queue."""
for req in reqs:
self.add(req)
self.add(req, is_retracted=is_retracted)
def resume_retracted_reqs(self) -> List[Req]:
# TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
# allocate memory
resumed_reqs = []
indices_to_remove = set()
allocatable_tokens = self._allocatable_tokens(count_retracted=False)
for i, req in enumerate(self.retracted_queue):
if self.req_to_token_pool.available_size() <= 0:
break
required_tokens_for_request = (
len(req.origin_input_ids)
+ len(req.output_ids)
+ self.num_reserved_decode_tokens
)
if required_tokens_for_request > allocatable_tokens:
break
resumed_reqs.append(req)
indices_to_remove.add(i)
req.is_retracted = False
self._pre_alloc(req)
allocatable_tokens -= required_tokens_for_request
# load from cpu, release the cpu copy
req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
self.retracted_queue = [
entry
for i, entry in enumerate(self.retracted_queue)
if i not in indices_to_remove
]
return resumed_reqs
def _update_handshake_waiters(self) -> None:
if not self.queue:
......@@ -255,6 +334,8 @@ class DecodePreallocQueue:
error_message,
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
else:
raise ValueError(f"Unexpected poll case: {poll}")
def pop_preallocated(self) -> List[DecodeRequest]:
"""Pop the preallocated requests from the pending queue (FIFO)."""
......@@ -262,8 +343,16 @@ class DecodePreallocQueue:
preallocated_reqs = []
indices_to_remove = set()
allocatable_tokens = self._allocatable_tokens()
# We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
# Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
retractable_tokens = sum(
len(r.origin_input_ids) + len(r.output_ids)
for r in self.scheduler.running_batch.reqs
)
allocatable_tokens = self._allocatable_tokens(
retractable_tokens=retractable_tokens, count_retracted=True
)
# First, remove all failed requests from the queue
for i, decode_req in enumerate(self.queue):
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
......@@ -272,6 +361,7 @@ class DecodePreallocQueue:
)
indices_to_remove.add(i)
# Then, preallocate the remaining requests if possible
for i, decode_req in enumerate(self.queue):
if i in indices_to_remove:
continue
......@@ -285,10 +375,23 @@ class DecodePreallocQueue:
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
break
# Memory estimation: don't add if the projected memory cannot be met
# TODO: add new_token ratio
origin_input_len = len(decode_req.req.origin_input_ids)
required_tokens_for_request = (
len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
origin_input_len + self.num_reserved_decode_tokens
)
if (
max(
required_tokens_for_request,
origin_input_len
+ decode_req.req.sampling_params.max_new_tokens
- retractable_tokens,
)
> allocatable_tokens
):
break
if required_tokens_for_request > allocatable_tokens:
break
......@@ -321,15 +424,35 @@ class DecodePreallocQueue:
return preallocated_reqs
def _allocatable_tokens(self) -> int:
allocatable_tokens = (
self.token_to_kv_pool_allocator.available_size()
- self.num_reserved_decode_tokens
def _allocatable_tokens(
self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
) -> int:
need_space_for_single_req = (
max(
[
x.sampling_params.max_new_tokens
+ len(x.origin_input_ids)
- retractable_tokens
for x in self.scheduler.running_batch.reqs
]
)
if retractable_tokens is not None
and len(self.scheduler.running_batch.reqs) > 0
else 0
)
available_size = self.token_to_kv_pool_allocator.available_size()
allocatable_tokens = available_size - max(
# preserve some space for future decode
self.num_reserved_decode_tokens
* (
len(self.scheduler.running_batch.reqs)
+ len(self.transfer_queue.queue)
+ len(self.scheduler.waiting_queue)
)
),
# make sure each request can finish if reach max_tokens with all other requests retracted
need_space_for_single_req,
)
# Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
......@@ -342,15 +465,27 @@ class DecodePreallocQueue:
self.scheduler.last_batch.reqs
)
if count_retracted:
allocatable_tokens -= sum(
[
len(req.origin_input_ids)
+ len(req.output_ids)
+ self.num_reserved_decode_tokens
for req in self.retracted_queue
]
)
return allocatable_tokens
def _pre_alloc(self, req: Req) -> torch.Tensor:
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
req_pool_indices = self.req_to_token_pool.alloc(1)
assert req_pool_indices is not None
assert (
req_pool_indices is not None
), "req_pool_indices is full! There is a bug in memory estimation."
req.req_pool_idx = req_pool_indices[0]
if self.token_to_kv_pool_allocator.page_size == 1:
kv_loc = self.token_to_kv_pool_allocator.alloc(
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
......@@ -375,7 +510,10 @@ class DecodePreallocQueue:
),
extend_num_tokens=num_tokens,
)
assert kv_loc is not None
assert (
kv_loc is not None
), "KV cache is full! There is a bug in memory estimation."
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
......@@ -395,6 +533,7 @@ class DecodeTransferQueue:
self,
gloo_group: ProcessGroup,
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
tp_rank: int,
metadata_buffers: MetadataBuffers,
scheduler: Scheduler,
tree_cache: BasePrefixCache,
......@@ -402,6 +541,7 @@ class DecodeTransferQueue:
self.queue: List[DecodeRequest] = []
self.gloo_group = gloo_group
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.tp_rank = tp_rank
self.metadata_buffers = metadata_buffers
self.scheduler = scheduler
self.tree_cache = tree_cache
......@@ -412,10 +552,9 @@ class DecodeTransferQueue:
def extend(self, decode_reqs: List[DecodeRequest]) -> None:
self.queue.extend(decode_reqs)
def pop_transferred(self) -> List[DecodeRequest]:
def pop_transferred(self) -> List[Req]:
if not self.queue:
return []
polls = poll_and_all_reduce(
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
)
......@@ -424,7 +563,7 @@ class DecodeTransferQueue:
indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.Failed:
error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
try:
decode_req.kv_receiver.failure_exception()
except Exception as e:
......@@ -543,7 +682,8 @@ class SchedulerDisaggregationDecodeMixin:
batch, _ = self._prepare_idle_batch_and_run(None)
if batch is None and (
len(self.disagg_decode_transfer_queue.queue)
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
......@@ -622,7 +762,8 @@ class SchedulerDisaggregationDecodeMixin:
self.process_batch_result(tmp_batch, tmp_result)
if batch is None and (
len(self.disagg_decode_transfer_queue.queue)
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
......@@ -716,6 +857,13 @@ class SchedulerDisaggregationDecodeMixin:
return new_batch
def process_decode_queue(self: Scheduler):
# try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
self.waiting_queue.extend(resumed_reqs)
if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
# if there are still retracted requests, we do not allocate new requests
return
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
self.disagg_decode_transfer_queue.extend(req_conns)
alloc_reqs = (
......
......@@ -25,6 +25,7 @@ from collections import deque
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional
import numpy as np
import torch
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
......@@ -575,6 +576,7 @@ class SchedulerDisaggregationPrefillMixin:
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
.cpu()
.numpy()
.astype(np.int64)
)
req.start_send_idx = end_idx
if last_chunk:
......
......@@ -1415,6 +1415,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
req = self.reqs[idx]
retracted_reqs.append(req)
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(
self.req_to_token_pool, self.token_to_kv_pool_allocator
)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
......@@ -1446,6 +1451,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
req.reset_for_retract()
if len(retracted_reqs) == 0:
# Corner case: only one request left
raise ValueError(
"Failed to retract any request. No space left for only one request."
)
self.filter_batch(keep_indices=sorted_indices)
# Reqs in batch are filtered
......
......@@ -628,6 +628,7 @@ class Scheduler(
self.disagg_decode_transfer_queue = DecodeTransferQueue(
gloo_group=self.attn_tp_cpu_group,
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
tp_rank=self.tp_rank,
metadata_buffers=self.disagg_metadata_buffers,
scheduler=self,
tree_cache=self.tree_cache,
......@@ -650,7 +651,11 @@ class Scheduler(
gloo_group=self.attn_tp_cpu_group,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
dp_size=self.server_args.dp_size,
gpu_id=self.gpu_id,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
max_total_num_tokens=self.max_total_num_tokens,
prefill_pp_size=self.server_args.disaggregation_prefill_pp,
transfer_backend=self.transfer_backend,
)
......@@ -1124,14 +1129,14 @@ class Scheduler(
else:
self.waiting_queue.append(req)
def _extend_requests_to_queue(self, reqs: List[Req]):
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.extend(
reqs, self.model_config.num_key_value_heads
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs)
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
else:
self.waiting_queue.extend(reqs)
......@@ -1274,6 +1279,7 @@ class Scheduler(
if self.disaggregation_mode == DisaggregationMode.DECODE:
msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
msg += (
f"cuda graph: {can_run_cuda_graph}, "
......@@ -1575,7 +1581,7 @@ class Scheduler(
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
self._extend_requests_to_queue(retracted_reqs)
self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
......
......@@ -234,6 +234,12 @@ class TokenToKVPoolAllocator:
self.is_not_in_free_group = True
self.free_group = []
def get_cpu_copy(self, indices):
return self._kvcache.get_cpu_copy(indices)
def load_cpu_copy(self, kv_cache_cpu, indices):
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
class MHATokenToKVPool(KVCache):
......@@ -265,6 +271,8 @@ class MHATokenToKVPool(KVCache):
self.head_dim = head_dim
self._create_buffers()
# used for chunked cpu-offloading
self.chunk_size = 8192
self.layer_transfer_counter = None
self.device_module = torch.get_device_module(self.device)
self.alt_stream = self.device_module.Stream() if _is_cuda else None
......@@ -329,6 +337,39 @@ class MHATokenToKVPool(KVCache):
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def get_cpu_copy(self, indices):
torch.cuda.synchronize()
kv_cache_cpu = []
for layer_id in range(self.layer_num):
kv_cache_cpu.append([])
for i in range(0, len(indices), self.chunk_size):
chunk_indices = indices[i : i + self.chunk_size]
k_cpu = self.k_buffer[layer_id][chunk_indices].to(
"cpu", non_blocking=True
)
v_cpu = self.v_buffer[layer_id][chunk_indices].to(
"cpu", non_blocking=True
)
kv_cache_cpu[-1].append([k_cpu, v_cpu])
torch.cuda.synchronize()
return kv_cache_cpu
def load_cpu_copy(self, kv_cache_cpu, indices):
torch.cuda.synchronize()
for layer_id in range(self.layer_num):
for i in range(0, len(indices), self.chunk_size):
chunk_indices = indices[i : i + self.chunk_size]
k_cpu, v_cpu = (
kv_cache_cpu[layer_id][i // self.chunk_size][0],
kv_cache_cpu[layer_id][i // self.chunk_size][1],
)
assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
self.k_buffer[layer_id][chunk_indices] = k_chunk
self.v_buffer[layer_id][chunk_indices] = v_chunk
torch.cuda.synchronize()
# Todo: different memory layout
def get_flat_data(self, indices):
# prepare a large chunk of contiguous data for efficient transfer
......
......@@ -469,5 +469,132 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.20)
class TestDisaggregationSimulatedRetract(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_TEST_RETRACT"] = "true"
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
# Non blocking start servers
cls.start_prefill()
cls.start_decode()
# Block until both
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang.srt.disaggregation.mini_lb",
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = subprocess.Popen(
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp",
"1",
"--disaggregation-ib-device",
"mlx5_roce0",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"1",
"--base-gpu-id",
"1",
"--disaggregation-ib-device",
"mlx5_roce1",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
os.environ.pop("SGLANG_TEST_RETRACT")
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
self.assertGreater(metrics["accuracy"], 0.62)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment