Unverified Commit 60d7beda authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Add split tile size for Triton attention (#10425)

parent 2f8ba6fe
...@@ -94,6 +94,11 @@ class TritonAttnBackend(AttentionBackend): ...@@ -94,6 +94,11 @@ class TritonAttnBackend(AttentionBackend):
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false" "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
) )
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
# Optional tile size (in tokens) for splitting KV in the Triton decode
# attention kernel; None leaves max_kv_splits at the configured
# triton_attention_num_kv_splits value.
self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
if self.split_tile_size is not None:
    # Ceil-divide the maximum context length by the tile size so each
    # split covers at most split_tile_size tokens.
    self.max_kv_splits = (
        self.max_context_len + self.split_tile_size - 1
    ) // self.split_tile_size
# Check arguments # Check arguments
assert not ( assert not (
...@@ -153,6 +158,12 @@ class TritonAttnBackend(AttentionBackend): ...@@ -153,6 +158,12 @@ class TritonAttnBackend(AttentionBackend):
num_kv_splits.fill_(self.max_kv_splits) num_kv_splits.fill_(self.max_kv_splits)
return return
if self.split_tile_size is not None:
    # Fixed tile size: each sequence gets ceil(seq_len / tile_size)
    # splits, so the split layout depends only on the sequence length
    # (elementwise over seq_lens — presumably a tensor; verify against
    # caller). Skips the heuristic scheduling below.
    num_kv_splits[:] = (
        seq_lens + self.split_tile_size - 1
    ) // self.split_tile_size
    return
if num_seq < 256: if num_seq < 256:
SCHEDULE_SEQ = 256 SCHEDULE_SEQ = 256
else: else:
......
...@@ -362,6 +362,7 @@ class ServerArgs: ...@@ -362,6 +362,7 @@ class ServerArgs:
enable_p2p_check: bool = False enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False triton_attention_reduce_in_fp32: bool = False
triton_attention_num_kv_splits: int = 8 triton_attention_num_kv_splits: int = 8
triton_attention_split_tile_size: Optional[int] = None
num_continuous_decode_steps: int = 1 num_continuous_decode_steps: int = 1
delete_ckpt_after_loading: bool = False delete_ckpt_after_loading: bool = False
enable_memory_saver: bool = False enable_memory_saver: bool = False
...@@ -2100,6 +2101,12 @@ class ServerArgs: ...@@ -2100,6 +2101,12 @@ class ServerArgs:
default=ServerArgs.triton_attention_num_kv_splits, default=ServerArgs.triton_attention_num_kv_splits,
help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.", help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
) )
# CLI flag backing ServerArgs.triton_attention_split_tile_size.
parser.add_argument(
    "--triton-attention-split-tile-size",
    type=int,
    default=ServerArgs.triton_attention_split_tile_size,
    help="The size of split KV tile in flash decoding Triton kernel. Used for deterministic inference.",
)
parser.add_argument( parser.add_argument(
"--num-continuous-decode-steps", "--num-continuous-decode-steps",
type=int, type=int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment