change / sglang / Commits / 22630ca2

Unverified commit 22630ca2, authored May 30, 2025 by Jianan Ji, committed by GitHub on May 30, 2025.

Support sliding window in triton backend (#6509)

Parent: d279d499
Showing 6 changed files with 350 additions and 13 deletions:

    python/sglang/srt/layers/attention/triton_backend.py               +198  -5
    python/sglang/srt/layers/attention/triton_ops/extend_attention.py   +12  -4
    python/sglang/srt/model_executor/model_runner.py                      +0  -4
    python/sglang/srt/models/gemma3_causal.py                             +7  -0
    test/srt/run_suite.py                                                 +1  -0
    test/srt/test_triton_sliding_window.py (new)                        +132  -0
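For orientation before the diff: sliding-window attention restricts each query token to itself plus at most `sliding_window_size` preceding keys, rather than the full causal prefix. A minimal PyTorch sketch of the band mask this commit implements inside the Triton kernels (illustrative only, not part of the commit):

    import torch

    def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
        # True where query q may attend to key k: causal (k <= q)
        # and within the window (q <= k + window).
        q = torch.arange(seq_len).unsqueeze(1)  # [seq_len, 1]
        k = torch.arange(seq_len).unsqueeze(0)  # [1, seq_len]
        return (k <= q) & (q <= k + window)

    # window=2: each token sees itself plus the two previous tokens.
    print(sliding_window_mask(5, 2).int())
    # tensor([[1, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0],
    #         [1, 1, 1, 0, 0],
    #         [0, 1, 1, 1, 0],
    #         [0, 0, 1, 1, 1]])

Each row has at most window + 1 ones, which is why the buffer helpers below clamp per-request KV lengths to `sliding_window_size + 1`.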
python/sglang/srt/layers/attention/triton_backend.py
@@ -72,6 +72,65 @@ def get_num_kv_splits_triton(
     tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
 
 
+def update_sliding_window_buffer(
+    window_kv_indptr,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+    device,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_indices = torch.empty(
+        window_kv_indptr[-1], dtype=torch.int32, device=device
+    )
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+def update_sliding_window_buffer_cuda_graph(
+    window_kv_indptr,
+    window_kv_indices,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_lens
+
+
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
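To make the helpers concrete, here is what they compute for a toy batch (a standalone sketch with made-up numbers; the actual index gathering is done by the `create_flashinfer_kv_indices_triton` kernel above):

    import torch

    sliding_window_size = 4
    seq_lens = torch.tensor([2, 9, 6])  # per-request KV-cache lengths
    bs = seq_lens.numel()

    # Each request keeps at most the last (window + 1) KV entries.
    window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
    # -> tensor([2, 5, 5])

    window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)
    # -> tensor([0, 2, 7, 12]): request i's indices live at [indptr[i], indptr[i+1])

    window_kv_start_idx = seq_lens - window_kv_lens
    # -> tensor([0, 4, 1]): how many leading tokens fell out of each window

The CUDA-graph variant differs only in writing into a caller-provided `window_kv_indices` buffer, so the pointer stays stable across graph replays.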
@@ -83,6 +142,10 @@ class ForwardMetadata:
     qo_indptr: torch.Tensor
     custom_mask: torch.Tensor
     mask_indptr: torch.Tensor
+    # Sliding window
+    window_kv_indptr: torch.Tensor
+    window_kv_indices: torch.Tensor
+    window_num_kv_splits: torch.Tensor
 
 
 class TritonAttnBackend(AttentionBackend):
@@ -109,6 +172,13 @@ class TritonAttnBackend(AttentionBackend):
         max_bs = model_runner.req_to_token_pool.size
 
+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.model_config.is_encoder_decoder
+        ), "Sliding window and cross attention are not supported together"
+
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
         if kv_indptr_buf is None:
             self.kv_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
@@ -116,6 +186,18 @@ class TritonAttnBackend(AttentionBackend):
         else:
             self.kv_indptr = kv_indptr_buf
 
+        # If sliding window is enabled, we might need two sets of buffers
+        # because of interleaved attention types (e.g. for Gemma3)
+        self.window_kv_indptr = None
+        if self.sliding_window_size is not None and self.sliding_window_size > 0:
+            if kv_indptr_buf is None:
+                self.window_kv_indptr = torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+            else:
+                # When provided a buffer, create a clone for the second buffer
+                self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
+
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
 
         if not self.skip_prefill:
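The "interleaved attention types" comment refers to models such as Gemma3 that alternate sliding-window layers with full-attention layers in the same network, so the backend must keep the windowed and the full KV index sets alive at the same time. A hypothetical sketch of such a layer schedule (the 5:1 ratio is illustrative, not taken from this diff):

    # Hypothetical: five sliding-window layers, then one global layer, repeated.
    # Window layers would read window_kv_indptr/window_kv_indices;
    # global layers would read the regular kv_indptr/kv_indices.
    def is_sliding_window_layer(layer_id: int, pattern_len: int = 6) -> bool:
        return (layer_id + 1) % pattern_len != 0

    print(["win" if is_sliding_window_layer(i) else "glb" for i in range(12)])
    # ['win', 'win', 'win', 'win', 'win', 'glb',
    #  'win', 'win', 'win', 'win', 'win', 'glb']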
@@ -191,6 +273,9 @@ class TritonAttnBackend(AttentionBackend):
         bs = forward_batch.batch_size
         kv_indptr = self.kv_indptr
+        window_kv_indptr = self.window_kv_indptr
+        window_kv_indices = None
+        window_num_kv_splits = None
         spec_info = forward_batch.spec_info
 
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -209,6 +294,26 @@ class TritonAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
+                # Sliding window
+                if (
+                    self.sliding_window_size is not None
+                    and self.sliding_window_size > 0
+                ):
+                    window_kv_indptr, window_kv_indices, window_kv_lens = (
+                        update_sliding_window_buffer(
+                            self.window_kv_indptr,
+                            self.req_to_token,
+                            self.sliding_window_size,
+                            forward_batch.seq_lens,
+                            forward_batch.req_pool_indices,
+                            bs,
+                            self.device,
+                        )
+                    )
+                    window_num_kv_splits = torch.empty(
+                        (bs,), dtype=torch.int32, device=self.device
+                    )
+                    self.get_num_kv_splits(window_num_kv_splits, window_kv_lens)
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1
@@ -224,7 +329,6 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )
-
             num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
             self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
 
             qo_indptr = None
@@ -232,6 +336,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             max_extend_len = None
         elif forward_batch.forward_mode.is_target_verify():
+            # TODO: Support sliding window in spec inference
             bs = len(forward_batch.req_pool_indices)
             qo_indptr = torch.arange(
                 0,
@@ -303,6 +408,17 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indices,
                 self.req_to_token.stride(0),
             )
+            # Sliding window
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
+                    self.window_kv_indptr,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    forward_batch.extend_prefix_lens,
+                    forward_batch.req_pool_indices,
+                    bs,
+                    self.device,
+                )
 
             qo_indptr = self.qo_indptr
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
@@ -324,6 +440,9 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr,
             custom_mask,
             mask_indptr,
+            window_kv_indptr,
+            window_kv_indices,
+            window_num_kv_splits,
         )
 
     def init_cuda_graph_state(
@@ -358,6 +477,20 @@ class TritonAttnBackend(AttentionBackend):
             device=self.device,
         )
 
+        if self.sliding_window_size is not None and self.sliding_window_size > 0:
+            if kv_indices_buf is None:
+                self.cuda_graph_window_kv_indices = torch.zeros(
+                    (max_bs * self.sliding_window_size),
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+            else:
+                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
+
+            self.cuda_graph_window_num_kv_splits = torch.full(
+                (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+            )
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -369,6 +502,9 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         assert encoder_lens is None, "Not supported"
+        window_kv_indptr = self.window_kv_indptr
+        window_kv_indices = None
+        window_num_kv_splits = None
 
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -385,6 +521,21 @@ class TritonAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
+                if (
+                    self.sliding_window_size is not None
+                    and self.sliding_window_size > 0
+                ):
+                    window_kv_indices = self.cuda_graph_window_kv_indices
+                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                    window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                    )
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
@@ -468,6 +619,9 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr,
             custom_mask,
             mask_indptr,
+            window_kv_indptr,
+            window_kv_indices,
+            window_num_kv_splits,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -500,11 +654,31 @@ class TritonAttnBackend(AttentionBackend):
                     self.req_to_token.stride(0),
                 )
                 num_token = bs
+                if (
+                    self.sliding_window_size is not None
+                    and self.sliding_window_size > 0
+                ):
+                    window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                    window_kv_indices = self.cuda_graph_window_kv_indices
+                    _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices[:bs],
+                        bs,
+                    )
+                    self.get_num_kv_splits(
+                        window_num_kv_splits[:num_token], window_kv_lens[:bs]
+                    )
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
                 num_token = spec_info.kv_indptr.shape[0] - 1
             self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
 
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)
@@ -582,6 +756,17 @@ class TritonAttnBackend(AttentionBackend):
         if layer.attn_type == AttentionType.ENCODER_ONLY:
             causal = False
 
+        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
+            sliding_window_size = (
+                layer.sliding_window_size
+            )  # Needed for sliding window mask
+            kv_indptr = self.forward_metadata.window_kv_indptr
+            kv_indices = self.forward_metadata.window_kv_indices
+        else:
+            sliding_window_size = -1
+            kv_indptr = self.forward_metadata.kv_indptr
+            kv_indices = self.forward_metadata.kv_indices
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -590,14 +775,15 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             self.forward_metadata.qo_indptr,
-            self.forward_metadata.kv_indptr,
-            self.forward_metadata.kv_indices,
+            kv_indptr,
+            kv_indices,
             self.forward_metadata.custom_mask,
             causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
+            sliding_window_size,
         )
         return o
@@ -625,13 +811,20 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
+        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
+            kv_indptr = self.forward_metadata.window_kv_indptr
+            kv_indices = self.forward_metadata.window_kv_indices
+        else:
+            kv_indptr = self.forward_metadata.kv_indptr
+            kv_indices = self.forward_metadata.kv_indices
+
         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            self.forward_metadata.kv_indptr,
-            self.forward_metadata.kv_indices,
+            kv_indptr,
+            kv_indices,
             self.forward_metadata.attn_logits,
             self.forward_metadata.attn_lse,
             self.forward_metadata.num_kv_splits,
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -65,6 +65,7 @@ def _fwd_kernel(
     stride_buf_kh,
     stride_buf_vbs,
     stride_buf_vh,
+    SLIDING_WINDOW_SIZE: tl.constexpr,
     logit_cap: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
@@ -163,6 +164,7 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        final_mask = mask_m[:, None] & mask_n[None, :]
         if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
             custom_mask = tl.load(
                 mask_ptr
@@ -173,10 +175,14 @@ def _fwd_kernel(
                 mask=(mask_m[:, None] & mask_n[None, :]),
                 other=0,
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
-            qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+            final_mask &= custom_mask
+        if SLIDING_WINDOW_SIZE > 0:
+            # Add mask where q_id <= kv_id + sliding_window_size
+            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+            )
+            final_mask &= window_mask
+        qk = tl.where(final_mask, qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
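The kernel works on tiles, so the window predicate is evaluated on absolute positions reconstructed from block coordinates. A host-side sketch of the same predicate for one query tile against one key tile (illustrative values):

    import torch

    BLOCK_M, BLOCK_N, SLIDING_WINDOW_SIZE = 4, 4, 2
    cur_block_m, start_n = 1, 0             # second query tile, first key tile
    offs_m = torch.arange(BLOCK_M)          # intra-tile query offsets
    offs_n = torch.arange(BLOCK_N)          # intra-tile key offsets

    q_id = cur_block_m * BLOCK_M + offs_m[:, None]  # absolute query ids 4..7
    kv_id = start_n + offs_n[None, :]               # absolute key ids 0..3
    print((q_id <= kv_id + SLIDING_WINDOW_SIZE).int())
    # tensor([[0, 0, 1, 1],    # q=4 keeps k >= 2
    #         [0, 0, 0, 1],    # q=5 keeps k >= 3
    #         [0, 0, 0, 0],    # q=6 and q=7 see nothing in this key tile
    #         [0, 0, 0, 0]])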
@@ -314,6 +320,7 @@ def extend_attention_fwd(
     sm_scale=None,
     logit_cap=0.0,
     skip_prefix_custom_mask=True,
+    sliding_window_size=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -412,6 +419,7 @@ def extend_attention_fwd(
         k_buffer.stride(1),
         v_buffer.stride(0),
         v_buffer.stride(1),
+        SLIDING_WINDOW_SIZE=sliding_window_size,
         logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
python/sglang/srt/model_executor/model_runner.py
@@ -1025,10 +1025,6 @@ class ModelRunner:
             return AiterAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
-            assert self.sliding_window_size is None, (
-                "Window attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
             assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
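With this assertion removed, sliding-window models are no longer rejected by the triton backend. Judging from the flags the new test passes, a launch along the lines of `python3 -m sglang.launch_server --model-path google/gemma-3-4b-it --attention-backend triton --context-length 8192` should now be accepted; the exact command is an assumption, only the backend flag and model come from this commit.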
python/sglang/srt/models/gemma3_causal.py
@@ -277,6 +277,13 @@ class Gemma3Attention(nn.Module):
         k = k.permute(0, 2, 1, 3)
 
         attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+
+        # Compatible with triton backend which returns [1, s, h, head_dim]
+        if attn_output.dim() == 4 and attn_output.shape[0] == 1:
+            attn_output = attn_output.squeeze(0)
+            attn_output = attn_output.flatten(-2, -1)
+            # [s, h * head_dim]
+
         output, _ = self.o_proj(attn_output)
         return output
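A quick shape check of the added compatibility branch (standalone sketch; the tensor here is a stand-in for the attention output):

    import torch

    s, h, head_dim = 7, 8, 64
    attn_output = torch.randn(1, s, h, head_dim)  # triton backend layout

    if attn_output.dim() == 4 and attn_output.shape[0] == 1:
        attn_output = attn_output.squeeze(0)       # [s, h, head_dim]
        attn_output = attn_output.flatten(-2, -1)  # [s, h * head_dim]

    assert attn_output.shape == (s, h * head_dim)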
test/srt/run_suite.py
@@ -78,6 +78,7 @@ suites = {
         TestFile("test_triton_attention_kernels.py", 4),
         TestFile("test_triton_attention_backend.py", 134),
         TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
+        TestFile("test_triton_sliding_window.py", 250),
         TestFile("test_update_weights_from_disk.py", 114),
         TestFile("test_update_weights_from_tensor.py", 48),
         TestFile("test_vertex_endpoint.py", 31),
test/srt/test_triton_sliding_window.py (new file, mode 100644)
import time
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestSlidingWindowAttentionTriton(CustomTestCase):
    """Test sliding window attention functionality with triton backend."""

    @classmethod
    def setUpClass(cls):
        """Set up the test server with Gemma3 model and triton backend."""
        # Gemma3 model supports sliding window attention
        cls.model = "google/gemma-3-4b-it"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "triton",
            "--context-length",
            "8192",
            "--random-seed",
            "42",
        ]
        cls.short_context_prompt = "The capital of France is"
        # Test prompt longer than window size
        cls.long_context_prompt = (
            """
Once upon a time, there was a mountain. In the mountain, there was a temple. In the temple, there was an old monk telling a story. The story was:
"""
            * 100
        )
        cls.long_context_prompt += "\nNow, summarize the story in one sentence:"

    @classmethod
    def tearDownClass(cls):
        pass

    def _test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        print(f"MMLU metrics with sliding window: {metrics}")
        self.assertGreaterEqual(metrics["score"], 0.64)

    def _test_short_context_generation(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.short_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn("paris", result["text"].lower())
        print(f"Short context generation result: {result['text']}")

    def _test_long_context_generation(self):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.long_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )
        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertGreater(len(result["text"].strip()), 0)
        print(f"Long context generation result: {result['text'][:100]}...")

    def test_no_cuda_graph(self):
        self.no_cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args + ["--disable-cuda-graph"],
        )
        self._test_short_context_generation()
        self._test_long_context_generation()
        self._test_mmlu()
        kill_process_tree(self.no_cuda_graph_process.pid)
        time.sleep(5)

    def test_cuda_graph(self):
        self.cuda_graph_process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args,
        )
        self._test_short_context_generation()
        self._test_long_context_generation()
        self._test_mmlu()
        kill_process_tree(self.cuda_graph_process.pid)
        time.sleep(5)


if __name__ == "__main__":
    unittest.main()
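The run_suite.py entry above budgets an estimated 250 seconds for this file. Because the module ends with unittest.main(), it can presumably also be run on its own, e.g. `python3 test/srt/test_triton_sliding_window.py`, given a GPU environment where the Gemma3 checkpoint and the sglang test utilities are available.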