Unverified commit 54479d6f, authored by Lianmin Zheng, committed by GitHub

Fix grammar backend for tensor parallelism (#2020)

parent ba069a24
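
In brief: the separate per-tool caches (BaseToolCache, OutlinesCache, OutlinesJumpForwardCache, XGrammarCache) are folded into a single BaseGrammarBackend that compiles grammars on a thread pool and hands out copies of cached grammar objects, and the scheduler drains its grammar queue through a new move_ready_grammar_requests() method that all-reduces the number of ready requests across tensor-parallel ranks, so every rank admits the same requests and builds identical batches.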
python/sglang/srt/constrained/base_tool_cache.py → python/sglang/srt/constrained/base_grammar_backend.py

@@ -13,90 +13,60 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Base cache class for constrained decoding tools."""
+"""The baseclass of backends for grammar-guided constrained decoding."""

-import time
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from threading import Event, Lock
-from typing import Any, Dict, Tuple
+from typing import Any, Optional, Tuple


 @dataclass
-class MapEntry:
-    event: Event
+class CacheEntry:
     value: Any
-
-    def __iter__(self):
-        return iter((self.event, self.value))
+    event: Event


-class BaseToolCache:
-    def __init__(self, enable=True):
-        self.enable: bool = enable
-        self.cache: Dict[str, MapEntry] = {}
-        self.metrics: Dict[str, Any] = {}
-        self.lock_cache: Lock = Lock()
-        self.lock_metrics: Lock = Lock()
-        self.reset()
+class BaseGrammarObject:
+    pass

-    def reset(self):
-        with self.lock_cache:
-            self.cache = {}
-        with self.lock_metrics:
-            self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}

-    def _init_with_timer(self, key) -> Tuple[Any, float]:
-        start = time.monotonic()
-        val = self.init_value(key)
-        init_time = time.monotonic() - start
-        return val, init_time
+class BaseGrammarBackend:
+    def __init__(self):
+        self.executor = ThreadPoolExecutor()
+        self.cache = {}
+        self.cache_lock = Lock()

-    def update_time(self, init_time):
-        with self.lock_metrics:
-            curr_total = self.metrics["total"]
-            new_total = curr_total + 1
-            # Update average init time without old_avg * old_total to avoid overflow.
-            self.metrics["avg_init_time"] = (init_time / new_total) + (
-                curr_total / new_total
-            ) * self.metrics["avg_init_time"]
-
-    def query(self, key):
-        if not self.enable:
-            value, init_time = self._init_with_timer(key)
-            self.update_time(init_time)
-            return value
-
-        with self.lock_cache:
+    def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
+        with self.cache_lock:
             if key in self.cache:
-                entry = self.cache[key]
                 cache_hit = True
+                entry = self.cache[key]
             else:
-                entry = MapEntry(Event(), None)
-                self.cache[key] = entry
                 cache_hit = False
+                entry = CacheEntry(None, Event())
+                self.cache[key] = entry

-        with self.lock_metrics:
-            self.metrics["total"] += 1
-            if cache_hit:
-                self.metrics["hit"] += 1
-
         if cache_hit:
             entry.event.wait()
         else:
-            entry.value, init_time = self._init_with_timer(key)
-            self.update_time(init_time)
+            entry.value = self.init_value_impl(key)
             entry.event.set()

-        return entry.value
+        return entry.value.copy()

-    def init_value(self, key):
+    def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
         raise NotImplementedError()

-    def get_cache_hit_rate(self):
-        with self.lock_metrics:
-            return self.metrics["hit"] / max(self.metrics["total"], 1)
+    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
+        with self.cache_lock:
+            entry = self.cache.get(key)
+            if not entry or not entry.event.is_set():
+                return None
+            return self.cache[key].value.copy()

-    def get_avg_init_time(self):
-        with self.lock_metrics:
-            return self.metrics["avg_init_time"]
+    def get_future_value(self, key: Tuple[str, str]) -> Future:
+        return self.executor.submit(self.init_value, key)
+
+    def reset(self):
+        with self.cache_lock:
+            self.cache.clear()
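
Note: the new backend deduplicates concurrent compilations. The first caller of init_value() for a key inserts the CacheEntry and compiles; any later caller with the same key blocks on the Event and then receives a copy() of the finished object. Below is a minimal sketch of how a subclass and its callers are expected to interact with this API; EchoGrammar and EchoBackend are hypothetical stand-ins, not part of the commit:

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)


class EchoGrammar(BaseGrammarObject):
    """Hypothetical grammar object; real ones wrap a compiled guide or matcher."""

    def __init__(self, key):
        self.key = key

    def copy(self):
        # The cache only ever hands out copies, so per-request mutable state
        # (e.g. an FSM position) never leaks between requests.
        return EchoGrammar(self.key)


class EchoBackend(BaseGrammarBackend):
    def init_value_impl(self, key):
        # Stands in for an expensive grammar compilation.
        return EchoGrammar(key)


backend = EchoBackend()
key = ("regex", r"\d+")
future = backend.get_future_value(key)  # non-blocking; compiles on the thread pool
print(backend.get_cached_value(key))    # typically None while compilation is in flight
grammar = future.result()               # blocks until the grammar is ready
assert backend.get_cached_value(key).key == key  # later lookups are cache hits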
python/sglang/srt/constrained/outlines_backend.py

@@ -17,20 +17,17 @@ limitations under the License.

 import json
 import logging
-from concurrent.futures import Future, ThreadPoolExecutor
 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from interegular import InvalidSyntax, parse_pattern
 from outlines.fsm.guide import RegexGuide
 from outlines.models.transformers import TransformerTokenizer
-from pydantic import BaseModel

-from sglang.srt.constrained.base_tool_cache import BaseToolCache
-from sglang.srt.constrained.outlines_jump_forward import (
-    OutlinesJumpForwardCache,
-    OutlinesJumpForwardMap,
-)
+from sglang.srt.constrained.base_grammar_backend import (
+    BaseGrammarBackend,
+    BaseGrammarObject,
+)
+from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap

 logger = logging.getLogger(__name__)

@@ -41,6 +38,7 @@ except ImportError:
     # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
     # which only accepts string schema as input.
     from outlines.fsm.json_schema import build_regex_from_schema
+    from pydantic import BaseModel

     def build_regex_from_object(
         object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None

@@ -54,16 +52,15 @@ except ImportError:
         return build_regex_from_schema(schema, whitespace_pattern)


-class OutlinesGrammar:
+class OutlinesGrammar(BaseGrammarObject):
     def __init__(
         self,
         guide: RegexGuide,
-        state: int,
         jump_forward_map: Union[OutlinesJumpForwardMap, None],
     ) -> None:
         self.guide = guide
-        self.state = state
         self.jump_forward_map = jump_forward_map
+        self.state = 0

     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)

@@ -105,46 +102,18 @@ class OutlinesGrammar:
         vocab_mask.fill_(1)
         vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0

-
-class OutlinesGrammarBackend:
-    def __init__(
-        self,
-        tokenizer,
-        whitespace_patterns: bool,
-        allow_jump_forward: bool,
-    ):
-        self.executor = ThreadPoolExecutor()
-        self.grammar_cache = OutlinesCache(
-            tokenizer,
-            whitespace_pattern=whitespace_patterns,
-        )
-        self.jump_forward_cache = (
-            OutlinesJumpForwardCache() if allow_jump_forward else None
-        )
-
-    def _query(self, key: Tuple[str, str]) -> OutlinesGrammar:
-        guide, regex = self.grammar_cache.query(key)
-        jump_forward_map = (
-            self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
-        )
-        return OutlinesGrammar(guide, 0, jump_forward_map)
-
-    def query(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self._query, key)
-
-    def reset(self):
-        self.grammar_cache.reset()
-        if self.jump_forward_cache:
-            self.jump_forward_cache.reset()
+    def copy(self):
+        return OutlinesGrammar(self.guide, self.jump_forward_map)


-class OutlinesCache(BaseToolCache):
+class OutlinesGrammarBackend(BaseGrammarBackend):
     def __init__(
         self,
         tokenizer,
-        whitespace_pattern=None,
+        whitespace_pattern: bool,
+        allow_jump_forward: bool,
     ):
-        super().__init__(enable=True)
+        super().__init__()

         try:
             self.outlines_tokenizer = TransformerTokenizer(tokenizer)

@@ -167,9 +136,10 @@ class OutlinesCache(BaseToolCache):
             self.outlines_tokenizer.vocabulary = (
                 self.outlines_tokenizer.tokenizer.get_vocab()
             )
+        self.allow_jump_forward = allow_jump_forward
         self.whitespace_pattern = whitespace_pattern

-    def init_value(self, key):
+    def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
         key_type, key_string = key
         if key_type == "json":
             try:

@@ -186,18 +156,10 @@ class OutlinesCache(BaseToolCache):
             regex = key_string
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
-        try:
-            parse_pattern(regex)
-        except InvalidSyntax as e:
-            logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
-            return None, regex
-
-        ret = RegexGuide(regex, self.outlines_tokenizer), regex
-        return ret
-
-    def _query(self, key: Tuple[str, str]):
-        guide, regex = self.grammar_cache.query(key)
-        jump_forward_map = (
-            self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
-        )
-        return OutlinesGrammar(guide, 0, jump_forward_map)
+        guide = RegexGuide(regex, self.outlines_tokenizer)
+        if self.allow_jump_forward:
+            jump_forward_map = OutlinesJumpForwardMap(regex)
+        else:
+            jump_forward_map = None
+        return OutlinesGrammar(guide, jump_forward_map)
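
Note: OutlinesGrammar.copy() shares the compiled RegexGuide and the jump-forward map by reference and relies on __init__ resetting state to 0, so every copy starts matching from the initial FSM state. A tiny illustration of that contract, assuming the module path sglang.srt.constrained.outlines_backend; FakeGuide is a hypothetical stand-in and nothing here compiles a real regex:

from sglang.srt.constrained.outlines_backend import OutlinesGrammar


class FakeGuide:
    """Hypothetical stand-in for outlines' RegexGuide; only identity matters here."""


shared_guide = FakeGuide()
grammar = OutlinesGrammar(shared_guide, None)
grammar.state = 42                   # pretend a request advanced through the FSM
fresh = grammar.copy()
assert fresh.guide is grammar.guide  # the expensive compiled guide is shared
assert fresh.state == 0              # the copy restarts from the initial state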
python/sglang/srt/constrained/outlines_jump_forward.py

@@ -27,8 +27,6 @@ from interegular import InvalidSyntax
 from outlines.caching import cache as disk_cache
 from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm

-from sglang.srt.constrained.base_tool_cache import BaseToolCache
-
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

 logger = logging.getLogger(__name__)

@@ -42,20 +40,15 @@ class JumpEdge:
     byte_next_state: int = None


-class OutlinesJumpForwardMap:
-    def __init__(self, regex_string):
-        @disk_cache()
-        def _init_state_to_jump_forward(regex_string):
-            try:
-                regex_pattern = interegular.parse_pattern(regex_string)
-            except InvalidSyntax as e:
-                logger.warning(f"skip invalid regex: {regex_string}, {e=}")
-                self.state_to_jump_forward = None
-                return
-
-            byte_fsm = make_byte_level_fsm(
-                regex_pattern.to_fsm().reduce(), keep_utf8=True
-            )
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            fsm_info: FSMInfo = regex_fsm.fsm_info
+@disk_cache()
+def init_state_to_jump_forward(regex_string):
+    try:
+        regex_pattern = interegular.parse_pattern(regex_string)
+    except InvalidSyntax as e:
+        logger.warning(f"skip invalid regex: {regex_string}, {e=}")
+        return
+
+    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
+    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+    fsm_info: FSMInfo = regex_fsm.fsm_info

@@ -127,7 +120,10 @@ class OutlinesJumpForwardMap:
     return state_to_jump_forward

-        self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
+
+class OutlinesJumpForwardMap:
+    def __init__(self, regex_string):
+        self.state_to_jump_forward = init_state_to_jump_forward(regex_string)

     def jump_forward_symbol(self, state):
         jump_forward_str = ""

@@ -164,18 +160,6 @@ class OutlinesJumpForwardMap:
         )

-class OutlinesJumpForwardCache(BaseToolCache):
-    def __init__(self):
-        super().__init__()
-
-    def init_value(self, regex):
-        forward_map = OutlinesJumpForwardMap(regex)
-        if forward_map.state_to_jump_forward:
-            return forward_map
-        else:
-            return None
-
 
 def test_main(regex_string):
     jump_forward_map = OutlinesJumpForwardMap(regex_string)
     for state, e in jump_forward_map.state_to_jump_forward.items():
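
Note: hoisting the @disk_cache()-decorated helper out of __init__ to module level gives the persistent cache one stable function to key against instead of a fresh closure per construction, and it removes the write to self from inside the cached body, a side effect that a cache hit would have skipped. A minimal sketch of the same pattern, with functools.lru_cache standing in for outlines' disk cache and compile_jump_map/JumpMapWrapper as hypothetical names:

from functools import lru_cache  # stand-in for outlines.caching.cache


@lru_cache(maxsize=None)
def compile_jump_map(regex_string: str) -> dict:
    # Expensive, deterministic work keyed solely by its argument.
    return {"regex": regex_string}


class JumpMapWrapper:
    """Hypothetical wrapper mirroring OutlinesJumpForwardMap's new shape."""

    def __init__(self, regex_string: str):
        # The instance only looks up the shared, cached result.
        self.state_to_jump_forward = compile_jump_map(regex_string)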
python/sglang/srt/constrained/xgrammar_backend.py

@@ -15,38 +15,36 @@ limitations under the License.

 """Constrained decoding with xgrammar backend."""

-from concurrent.futures import Future, ThreadPoolExecutor
 from typing import List, Tuple

 import torch
+from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher

-try:
-    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
-
-    import_error = None
-except ImportError as e:
-    import_error = e
-
-    class Dummy:
-        pass
-
-    GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy
+from sglang.srt.constrained.base_grammar_backend import (
+    BaseGrammarBackend,
+    BaseGrammarObject,
+)

 MAX_ROLLBACK_TOKENS = 10


-class XGrammarGrammar:
-    def __init__(self, matcher: GrammarMatcher, vocab_size: int) -> None:
+class XGrammarGrammar(BaseGrammarObject):
+    def __init__(
+        self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
+    ) -> None:
         self.matcher = matcher
         self.vocab_size = vocab_size
+        self.ctx = ctx

     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)

     def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
-        return [], self.matcher.find_jump_forward_string()
+        s = self.matcher.find_jump_forward_string()
+        if s:
+            return [], s
+        return None

     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         _, data = helper

@@ -77,51 +75,40 @@ class XGrammarGrammar:
             self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
         ] = 1

+    def copy(self):
+        matcher = GrammarMatcher(
+            self.ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            mask_vocab_size=self.vocab_size,
+        )
+        return XGrammarGrammar(matcher, self.vocab_size, self.ctx)

-class XGrammarGrammarBackend:
+
+class XGrammarGrammarBackend(BaseGrammarBackend):
     def __init__(
         self,
         tokenizer,
         vocab_size: int,
     ):
-        if import_error:
-            raise import_error
-
-        self.executor = ThreadPoolExecutor()
-        self.grammar_cache = XGrammarCache(tokenizer, vocab_size)
-        self.vocab_size = vocab_size
-
-    def _query(self, key: Tuple[str, str]) -> XGrammarGrammar:
-        return XGrammarGrammar(self.grammar_cache.query(key), self.vocab_size)
-
-    def query(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self._query, key)
-
-    def reset(self):
-        self.grammar_cache.reset()
-
-
-class XGrammarCache:
-    def __init__(self, tokenizer, vocab_size: int):
+        super().__init__()
         self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
         self.vocab_size = vocab_size

-    def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
+    def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
         key_type, key_string = key
         if key_type == "json":
-            return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
+            ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
         elif key_type == "regex":
             raise ValueError("regex hasn't been supported by xgrammar yet")
         else:
             raise ValueError(f"Invalid key_type: {key_type}")

-    def query(self, key: Tuple[str, str]) -> GrammarMatcher:
-        ctx = self.get_context(key)
-        return GrammarMatcher(
+        matcher = GrammarMatcher(
             ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
             mask_vocab_size=self.vocab_size,
         )
+        return XGrammarGrammar(matcher, self.vocab_size, ctx)

     def reset(self):
         self.grammar_cache.clear()
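
Note: storing the CompiledGrammar (ctx) on XGrammarGrammar is what makes copy() possible here: the compiled grammar appears to be the shareable, immutable artifact, and each copy wraps it in a fresh GrammarMatcher with its own matching position and rollback window, mirroring how init_value_impl() builds the original matcher.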
python/sglang/srt/managers/schedule_batch.py

@@ -37,6 +37,7 @@ import torch

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

@@ -248,7 +249,7 @@ class Req:
         self.embedding = None

         # Constrained decoding
-        self.grammar = None
+        self.grammar: Optional[BaseGrammarObject] = None

         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
python/sglang/srt/managers/scheduler.py

@@ -244,7 +244,7 @@ class Scheduler:
             self.grammar_backend = OutlinesGrammarBackend(
                 self.tokenizer,
-                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                 allow_jump_forward=not server_args.disable_jump_forward,
             )
         elif server_args.grammar_backend == "xgrammar":

@@ -467,21 +467,6 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1

-        # Init grammar cache for this request
-        if (
-            req.sampling_params.json_schema is not None
-            or req.sampling_params.regex is not None
-        ):
-            assert self.grammar_backend is not None
-            if req.sampling_params.json_schema is not None:
-                req.grammar = self.grammar_backend.query(
-                    ("json", req.sampling_params.json_schema),
-                )
-            elif req.sampling_params.regex is not None:
-                req.grammar = self.grammar_backend.query(
-                    ("regex", req.sampling_params.regex)
-                )
-
         # Truncate prompts that are too long
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(

@@ -499,7 +484,24 @@ class Scheduler:
             self.max_req_len - len(req.origin_input_ids) - 1,
         )

-        if req.grammar is not None:
+        # Init grammar cache for this request
+        add_to_grammar_queue = False
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            assert self.grammar_backend is not None
+            if req.sampling_params.json_schema is not None:
+                key = ("json", req.sampling_params.json_schema)
+            elif req.sampling_params.regex is not None:
+                key = ("regex", req.sampling_params.regex)
+
+            req.grammar = self.grammar_backend.get_cached_value(key)
+            if not req.grammar:
+                req.grammar = self.grammar_backend.get_future_value(key)
+                add_to_grammar_queue = True
+
+        if add_to_grammar_queue:
             self.grammar_queue.append(req)
         else:
             self.waiting_queue.append(req)

@@ -650,14 +652,7 @@ class Scheduler:
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
         if self.grammar_queue:
-            new_grammar_queue = []
-            for req in self.grammar_queue:
-                try:
-                    req.grammar = req.grammar.result(timeout=0.05)
-                    self.waiting_queue.append(req)
-                except futures._base.TimeoutError:
-                    new_grammar_queue.append(req)
-            self.grammar_queue = new_grammar_queue
+            self.move_ready_grammar_requests()

         # Handle the cases where prefill is not allowed
         if (

@@ -1145,6 +1140,30 @@ class Scheduler:
             )
         )

+    def move_ready_grammar_requests(self):
+        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+        num_ready_reqs = 0
+        for req in self.grammar_queue:
+            try:
+                req.grammar = req.grammar.result(timeout=0.05)
+                num_ready_reqs += 1
+            except futures._base.TimeoutError:
+                break
+
+        if self.tp_size > 1:
+            # Sync across TP ranks to make sure they have the same number of ready requests
+            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            torch.distributed.all_reduce(
+                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+            )
+            num_ready_reqs_max = tensor.item()
+            for i in range(num_ready_reqs, num_ready_reqs_max):
+                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
+            num_ready_reqs = num_ready_reqs_max
+
+        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (

@@ -1152,9 +1171,8 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            if self.grammar_backend is not None:
+            if self.grammar_backend:
                 self.grammar_backend.reset()
-            # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
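
Note: the all-reduce in move_ready_grammar_requests() is the heart of the tensor-parallelism fix. Grammar compilation can finish at different times on different TP ranks, so each rank counts the futures that resolved within its local timeout, all ranks agree on the maximum count, and any rank that is behind blocks on .result() until it catches up. Every rank therefore moves the identical prefix of grammar_queue into waiting_queue and forms the same batches. A distilled sketch of just the synchronization step; the function name and arguments are illustrative, and it assumes an initialized torch.distributed process group:

import torch
import torch.distributed as dist


def sync_ready_count(num_ready_local: int, cpu_group) -> int:
    # All ranks adopt the maximum local count, so a faster rank never runs
    # ahead of a slower one when draining the grammar queue.
    t = torch.tensor(num_ready_local, dtype=torch.int32)
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=cpu_group)
    return int(t.item())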