sglang commit 61dec545 (unverified)
Authored Dec 08, 2024 by Ke Bao; committed via GitHub on Dec 08, 2024

Remove unused vars in the triton backend (#2401)
Parent: 96db0f66

Showing 3 changed files with 14 additions and 33 deletions (+14 / -33):

  python/sglang/srt/layers/attention/triton_backend.py               +4   -17
  python/sglang/srt/layers/attention/triton_ops/decode_attention.py  +10  -10
  test/srt/test_triton_attention_kernels.py                          +0   -6
python/sglang/srt/layers/attention/triton_backend.py
@@ -35,11 +35,6 @@ class TritonAttnBackend(AttentionBackend):
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )

-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-            self.reduce_dtype = torch.float32
-        else:
-            self.reduce_dtype = torch.float16
-
         self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
@@ -53,9 +48,6 @@ class TritonAttnBackend(AttentionBackend):
         """Init auxiliary variables for triton attention backend."""

         if forward_batch.forward_mode.is_decode():
-            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-
             attn_logits = torch.empty(
                 (
                     forward_batch.batch_size,
@@ -67,13 +59,12 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )

-            max_seq_len = torch.max(forward_batch.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = attn_logits = max_seq_len = None
+            attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

-        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+        self.forward_metadata = attn_logits, max_extend_len

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
@@ -96,9 +87,7 @@ class TritonAttnBackend(AttentionBackend):
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
-            self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
             None,
         )
@@ -137,7 +126,7 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        _, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -175,7 +164,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        attn_logits, _ = self.forward_metadata

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -189,10 +178,8 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.req_pool_indices,
-            start_loc,
             forward_batch.seq_lens,
             attn_logits,
-            max_seq_len,
             self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
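Net effect of the triton_backend.py hunks: forward_metadata shrinks from the 4-tuple (start_loc, attn_logits, max_seq_len, max_extend_len) to the 2-tuple (attn_logits, max_extend_len), because the decode kernels no longer consume start_loc or max_seq_len. A minimal sketch of the resulting unpack pattern (illustrative only; the names mirror the diff above and nothing here is added to the commit):

    # Extend (prefill) path: only the maximum extend length is needed.
    _, max_extend_len = self.forward_metadata

    # Decode path: only the intermediate split-KV logits buffer is needed.
    attn_logits, _ = self.forward_metadata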
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -19,6 +19,9 @@ It supports page size = 1.
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
 # https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+import logging
+
 import triton
 import triton.language as tl
@@ -26,6 +29,13 @@ from sglang.srt.utils import is_hip
 is_hip_ = is_hip()

+logger = logging.getLogger(__name__)
+
+# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
+logger.warn(
+    "The following error message 'operation scheduled before its operands' can be ignored."
+)
+

 @triton.jit
 def tanh(x):
@@ -166,7 +176,6 @@ def _decode_att_m_fwd(
     Req_to_tokens,
     B_req_idx,
     B_Seqlen,
-    max_len_in_batch,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -389,7 +398,6 @@ def _decode_grouped_att_m_fwd(
    Req_to_tokens,
    B_req_idx,
    B_Seqlen,
-    max_len_in_batch,
    num_kv_splits,
    sm_scale,
    logit_cap,
@@ -556,7 +564,6 @@ def decode_attention_fwd_normal(
    b_req_idx,
    b_seq_len,
    attn_logits,
-    max_len_in_batch,
    num_kv_splits,
    sm_scale,
    logit_cap=0.0,
@@ -569,7 +576,6 @@ def decode_attention_fwd_normal(
        req_to_token,
        b_req_idx,
        b_seq_len,
-        max_len_in_batch,
        num_kv_splits,
        sm_scale,
        logit_cap,
@@ -586,7 +592,6 @@ def decode_attention_fwd_grouped(
    b_req_idx,
    b_seq_len,
    attn_logits,
-    max_len_in_batch,
    num_kv_splits,
    sm_scale,
    logit_cap=0.0,
@@ -599,7 +604,6 @@ def decode_attention_fwd_grouped(
        req_to_token,
        b_req_idx,
        b_seq_len,
-        max_len_in_batch,
        num_kv_splits,
        sm_scale,
        logit_cap,
@@ -614,10 +618,8 @@ def decode_attention_fwd(
    o,
    req_to_token,
    b_req_idx,
-    b_start_loc,
    b_seq_len,
    attn_logits,
-    max_len_in_batch,
    num_kv_splits,
    sm_scale,
    logit_cap=0.0,
@@ -636,7 +638,6 @@ def decode_attention_fwd(
            b_req_idx,
            b_seq_len,
            attn_logits,
-            max_len_in_batch,
            num_kv_splits,
            sm_scale,
            logit_cap,
@@ -652,7 +653,6 @@ def decode_attention_fwd(
            b_req_idx,
            b_seq_len,
            attn_logits,
-            max_len_in_batch,
            num_kv_splits,
            sm_scale,
            logit_cap,
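Because b_start_loc and max_len_in_batch are dropped from every wrapper in decode_attention.py, out-of-tree callers built against the old argument list will now fail with a TypeError. A quick sketch (assuming sglang is importable in the current environment) that inspects the new parameter list rather than guessing it:

    import inspect

    from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

    # After this commit the removed names should no longer appear in the wrapper signature.
    params = list(inspect.signature(decode_attention_fwd).parameters)
    assert "b_start_loc" not in params and "max_len_in_batch" not in params, params
    print(params)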
test/srt/test_triton_attention_kernels.py
@@ -196,7 +196,6 @@ class TestTritonAttention(unittest.TestCase):
        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
        b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
        b_seq_len = torch.full((B,), seq_len, device="cuda")

        attn_logits = torch.empty(
@@ -212,10 +211,8 @@ class TestTritonAttention(unittest.TestCase):
            o,
            req_to_token,
            b_req_idx,
-            b_start_loc,
            b_seq_len,
            attn_logits,
-            seq_len,
            num_kv_splits,
            sm_scale,
        )
@@ -255,7 +252,6 @@ class TestTritonAttention(unittest.TestCase):
        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
        b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
        b_seq_len = torch.full((B,), seq_len, device="cuda")

        attn_logits = torch.empty(
@@ -273,7 +269,6 @@ class TestTritonAttention(unittest.TestCase):
            b_req_idx,
            b_seq_len,
            attn_logits,
-            seq_len,
            num_kv_splits,
            sm_scale,
        )
@@ -293,7 +288,6 @@ class TestTritonAttention(unittest.TestCase):
            b_req_idx,
            b_seq_len,
            attn_logits1,
-            seq_len,
            num_kv_splits,
            sm_scale,
        )
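For reference, a call sketch that mirrors the updated test and exercises decode_attention_fwd with the trimmed argument list. The toy sizes, the leading q/k_buffer/v_buffer arguments, and the attn_logits buffer layout (a v_head_dim + 1 trailing slot) are assumptions based on the elided context, not something shown in the hunks above:

    import torch

    from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

    # Assumed toy sizes for illustration; requires a CUDA device.
    B, H, D, seq_len, num_kv_splits = 2, 8, 64, 16, 4
    total_tokens = B * seq_len
    sm_scale = 1.0 / D**0.5

    q = torch.randn(B, H, D, dtype=torch.float16, device="cuda")
    k_buffer = torch.randn(total_tokens, H, D, dtype=torch.float16, device="cuda")
    v_buffer = torch.randn(total_tokens, H, D, dtype=torch.float16, device="cuda")
    o = torch.empty_like(q)

    req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
    b_req_idx = torch.arange(B, device="cuda")
    b_seq_len = torch.full((B,), seq_len, device="cuda")

    # Intermediate buffer for the split-KV reduction; the "+ 1" slot is an assumption
    # (it would hold the running log-sum-exp alongside the value accumulators).
    attn_logits = torch.empty(
        (B, H, num_kv_splits, D + 1), dtype=torch.float32, device="cuda"
    )

    # New signature: no b_start_loc, no max_len_in_batch.
    decode_attention_fwd(
        q, k_buffer, v_buffer, o,
        req_to_token, b_req_idx, b_seq_len,
        attn_logits, num_kv_splits, sm_scale,
    )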