Unverified Commit 326df4ba, authored by Lianmin Zheng, committed by GitHub

Use a single workspace for flashinfer (#1077)

parent 6767e222
@@ -64,7 +64,7 @@ def main(args):
     @sgl.function
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
-        s += sgl.gen("answer", max_tokens=512, stop="Question")
+        s += sgl.gen("answer", max_tokens=512, stop=["Question", "Assistant:"])
     #####################################
     ########## SGL Program End ##########
...
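
The benchmark change above passes a list of stop strings to sgl.gen instead of a single string, so decoding also halts if the model begins a spurious "Assistant:" turn. A minimal sketch of the same call pattern; the endpoint URL and the one-line prompt are illustrative, not part of this commit:

import sglang as sgl

few_shot_examples = "Question: 1 + 1 = ?\nAnswer: 2\n"  # illustrative prompt prefix

@sgl.function
def few_shot_gsm8k(s, question):
    s += few_shot_examples + question
    # Generation stops at whichever stop string appears first.
    s += sgl.gen("answer", max_tokens=512, stop=["Question", "Assistant:"])

# Hypothetical usage against a locally running SGLang server.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = few_shot_gsm8k.run(question="Question: 2 + 2 = ?\nAnswer:")
print(state["answer"])
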
@@ -27,7 +27,7 @@ class GlobalConfig:
         # Runtime constants: others
         self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = 192 * 1024 * 1024
+        self.flashinfer_workspace_size = 384 * 1024 * 1024
         # Output tokenization configs
         self.skip_special_tokens_in_output = True
...
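
Doubling flashinfer_workspace_size from 192 MiB to 384 MiB accompanies the switch to one shared buffer; judging by the allocations changed below, total workspace memory stays flat in the common case and shrinks in the sliding-window case. A quick back-of-the-envelope check:

# Before: one 192 MiB buffer per wrapper slot (2 without a sliding window, 4 with).
old_no_window = 2 * 192 * 1024 * 1024
old_window = 4 * 192 * 1024 * 1024
# After: a single shared 384 MiB buffer in both configurations.
new_single = 384 * 1024 * 1024

print(old_no_window // 2**20, old_window // 2**20, new_single // 2**20)
# -> 384 768 384  (MiB): flat without a sliding window, halved with one.
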
@@ -120,13 +120,13 @@ class CudaGraphRunner:
         )
         if model_runner.sliding_window_size is None:
             self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffers[0]
+                self.model_runner.flashinfer_workspace_buffer
             )
         else:
-            self.flashinfer_workspace_buffers = [
-                self.model_runner.flashinfer_workspace_buffers[0],
-                self.model_runner.flashinfer_workspace_buffers[2],
-            ]
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffer
+            )
             self.flashinfer_kv_indptr = [
                 self.flashinfer_kv_indptr,
                 self.flashinfer_kv_indptr.clone(),
@@ -200,7 +200,7 @@ class CudaGraphRunner:
             for i in range(2):
                 flashinfer_decode_wrapper.append(
                     BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffers[i],
+                        self.flashinfer_workspace_buffer,
                         "NHD",
                         use_cuda_graph=True,
                         use_tensor_cores=use_tensor_cores,
...
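
Both CUDA-graph decode wrappers are now constructed over the same workspace tensor. A minimal sketch of that sharing, assuming flashinfer >= 0.1.5 is installed (the CUDA-graph-specific constructor arguments are omitted for brevity):

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

# One workspace tensor, aliased by both wrapper instances rather than
# one 192 MiB buffer each as before this commit.
workspace = torch.empty(384 * 1024 * 1024, dtype=torch.uint8, device="cuda")

decode_wrappers = [
    BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD") for _ in range(2)
]

Sharing appears safe here because the wrappers run one at a time rather than concurrently, so a single scratch region suffices; that seems to be the premise of this PR.
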
@@ -318,28 +318,26 @@ class ModelRunner:
             use_tensor_cores = False
         if self.sliding_window_size is None:
-            self.flashinfer_workspace_buffers = torch.empty(
-                2,
+            self.flashinfer_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
                 device="cuda",
             )
             self.flashinfer_prefill_wrapper_ragged = (
                 BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffers[0], "NHD"
+                    self.flashinfer_workspace_buffer, "NHD"
                 )
             )
             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffers[1], "NHD"
+                self.flashinfer_workspace_buffer, "NHD"
             )
             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffers[0],
+                self.flashinfer_workspace_buffer,
                 "NHD",
                 use_tensor_cores=use_tensor_cores,
             )
         else:
-            self.flashinfer_workspace_buffers = torch.empty(
-                4,
+            self.flashinfer_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
                 device="cuda",
@@ -350,17 +348,17 @@ class ModelRunner:
             for i in range(2):
                 self.flashinfer_prefill_wrapper_ragged.append(
                     BatchPrefillWithRaggedKVCacheWrapper(
-                        self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
+                        self.flashinfer_workspace_buffer, "NHD"
                     )
                 )
                 self.flashinfer_prefill_wrapper_paged.append(
                     BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
+                        self.flashinfer_workspace_buffer, "NHD"
                     )
                 )
                 self.flashinfer_decode_wrapper.append(
                     BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffers[2 * i + 0],
+                        self.flashinfer_workspace_buffer,
                         "NHD",
                         use_tensor_cores=use_tensor_cores,
                     )
...
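
In ModelRunner the same single tensor now backs the ragged-prefill, paged-prefill, and decode wrappers. A condensed sketch of the resulting allocation pattern (constructor arguments mirror the diff; this is not the full ModelRunner code):

import torch
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)

# One 384 MiB scratch buffer for all flashinfer kernels.
workspace = torch.empty(384 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# All three wrapper kinds share the single workspace.
prefill_ragged = BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")
prefill_paged = BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")
decode = BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD", use_tensor_cores=False)
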
@@ -381,7 +381,7 @@ def _set_envs_and_config(server_args: ServerArgs):
    if not server_args.disable_flashinfer:
        assert_pkg_version(
            "flashinfer",
-            "0.1.4",
+            "0.1.5",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
...
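
assert_pkg_version is an sglang helper whose body is not shown in this diff; a rough, illustrative equivalent of such a minimum-version check, built on the standard importlib.metadata module and the packaging library, might look like:

from importlib.metadata import PackageNotFoundError, version
from packaging.version import parse


def check_min_version(pkg: str, min_version: str, hint: str) -> None:
    # Illustrative sketch, not sglang's actual assert_pkg_version implementation.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if parse(installed) < parse(min_version):
        raise RuntimeError(f"{pkg}=={installed} is older than {min_version}. {hint}")


check_min_version(
    "flashinfer",
    "0.1.5",
    "Please reinstall it following https://docs.flashinfer.ai/installation.html.",
)
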