zhaoyu6 / sglang / Commits / 61dec545
"docs/vscode:/vscode.git/clone" did not exist on "b79fffdcb5c52ba8fdc72a9f18aabc3cd50bc7ff"
Unverified commit 61dec545 · authored Dec 08, 2024 by Ke Bao · committed by GitHub on Dec 08, 2024
Remove unused vars in the triton backend (#2401)
Parent: 96db0f66
Showing 3 changed files, with 14 additions and 33 deletions (+14 / -33):
python/sglang/srt/layers/attention/triton_backend.py (+4 / -17)
python/sglang/srt/layers/attention/triton_ops/decode_attention.py (+10 / -10)
test/srt/test_triton_attention_kernels.py (+0 / -6)
python/sglang/srt/layers/attention/triton_backend.py  (view file @ 61dec545)

@@ -35,11 +35,6 @@ class TritonAttnBackend(AttentionBackend):
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-            self.reduce_dtype = torch.float32
-        else:
-            self.reduce_dtype = torch.float16
-
        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

@@ -53,9 +48,6 @@ class TritonAttnBackend(AttentionBackend):
        """Init auxiliary variables for triton attention backend."""

        if forward_batch.forward_mode.is_decode():
-            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,

@@ -67,13 +59,12 @@ class TritonAttnBackend(AttentionBackend):
                device=self.device,
            )
-            max_seq_len = torch.max(forward_batch.seq_lens).item()
            max_extend_len = None
        else:
-            start_loc = attn_logits = max_seq_len = None
+            attn_logits = None
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

-        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+        self.forward_metadata = attn_logits, max_extend_len

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

@@ -96,9 +87,7 @@ class TritonAttnBackend(AttentionBackend):
    ):
        # NOTE: encoder_lens expected to be zeros or None
        self.forward_metadata = (
-            self.cuda_graph_start_loc,
            self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
            None,
        )

@@ -137,7 +126,7 @@ class TritonAttnBackend(AttentionBackend):
                layer, forward_batch.out_cache_loc, k, v
            )

-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        _, max_extend_len = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),

@@ -175,7 +164,7 @@ class TritonAttnBackend(AttentionBackend):
        else:
            o = torch.empty_like(q)

-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        attn_logits, _ = self.forward_metadata

        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(

@@ -189,10 +178,8 @@ class TritonAttnBackend(AttentionBackend):
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
-            start_loc,
            forward_batch.seq_lens,
            attn_logits,
-            max_seq_len,
            self.num_kv_splits,
            layer.scaling,
            layer.logit_cap,
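Taken together, these hunks shrink TritonAttnBackend.forward_metadata from a four-element tuple (start_loc, attn_logits, max_seq_len, max_extend_len) to a two-element tuple (attn_logits, max_extend_len) and drop the unused reduce_dtype setting. Below is a minimal, standalone sketch of the resulting metadata flow. It is not the project's code: the ToyBatch container, the float32 dtype, and the exact attn_logits buffer shape are illustrative assumptions (only forward_batch.batch_size is visible in the hunk above).

# Sketch of the post-commit metadata flow; buffer shape and dtype are assumptions.
import torch


class ToyBatch:
    """Hypothetical stand-in for ForwardBatch, holding only what this sketch needs."""

    def __init__(self, seq_lens, extend_seq_lens, is_decode):
        self.seq_lens = seq_lens
        self.extend_seq_lens = extend_seq_lens
        self.is_decode = is_decode
        self.batch_size = seq_lens.numel()


def init_forward_metadata(batch, num_heads, v_head_dim, num_kv_splits, device="cpu"):
    if batch.is_decode:
        # Decode path: only the split-KV logits buffer is needed now;
        # start_loc and max_seq_len are no longer computed.
        attn_logits = torch.empty(
            (batch.batch_size, num_heads, num_kv_splits, v_head_dim + 1),
            dtype=torch.float32,
            device=device,
        )
        max_extend_len = None
    else:
        # Extend/prefill path: attn_logits is unused.
        attn_logits = None
        max_extend_len = torch.max(batch.extend_seq_lens).item()
    # Two elements instead of the old (start_loc, attn_logits, max_seq_len, max_extend_len).
    return attn_logits, max_extend_len


if __name__ == "__main__":
    decode_batch = ToyBatch(torch.tensor([5, 7]), torch.tensor([0, 0]), is_decode=True)
    attn_logits, max_extend_len = init_forward_metadata(
        decode_batch, num_heads=8, v_head_dim=64, num_kv_splits=4
    )
    print(attn_logits.shape, max_extend_len)  # torch.Size([2, 8, 4, 65]) None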
python/sglang/srt/layers/attention/triton_ops/decode_attention.py  (view file @ 61dec545)

@@ -19,6 +19,9 @@ It supports page size = 1.
# Adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+
+import logging
+
import triton
import triton.language as tl

@@ -26,6 +29,13 @@ from sglang.srt.utils import is_hip
is_hip_ = is_hip()

+logger = logging.getLogger(__name__)
+
+# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
+logger.warn(
+    "The following error message 'operation scheduled before its operands' can be ignored."
+)
+

@triton.jit
def tanh(x):

@@ -166,7 +176,6 @@ def _decode_att_m_fwd(
    Req_to_tokens,
    B_req_idx,
    B_Seqlen,
-    max_len_in_batch,
    num_kv_splits,
    sm_scale,
    logit_cap,

@@ -389,7 +398,6 @@ def _decode_grouped_att_m_fwd(
    Req_to_tokens,
    B_req_idx,
    B_Seqlen,
-    max_len_in_batch,
    num_kv_splits,
    sm_scale,
    logit_cap,

@@ -556,7 +564,6 @@ def decode_attention_fwd_normal(
    b_req_idx,
    b_seq_len,
    attn_logits,
-    max_len_in_batch,
    num_kv_splits,
    sm_scale,
    logit_cap=0.0,

@@ -569,7 +576,6 @@ def decode_attention_fwd_normal(
        req_to_token,
        b_req_idx,
        b_seq_len,
-        max_len_in_batch,
        num_kv_splits,
        sm_scale,
        logit_cap,

@@ -586,7 +592,6 @@ def decode_attention_fwd_grouped(
    b_req_idx,
    b_seq_len,
    attn_logits,
-    max_len_in_batch,
    num_kv_splits,
    sm_scale,
    logit_cap=0.0,

@@ -599,7 +604,6 @@ def decode_attention_fwd_grouped(
        req_to_token,
        b_req_idx,
        b_seq_len,
-        max_len_in_batch,
        num_kv_splits,
        sm_scale,
        logit_cap,

@@ -614,10 +618,8 @@ def decode_attention_fwd(
    o,
    req_to_token,
    b_req_idx,
-    b_start_loc,
    b_seq_len,
    attn_logits,
-    max_len_in_batch,
    num_kv_splits,
    sm_scale,
    logit_cap=0.0,

@@ -636,7 +638,6 @@ def decode_attention_fwd(
        b_req_idx,
        b_seq_len,
        attn_logits,
-        max_len_in_batch,
        num_kv_splits,
        sm_scale,
        logit_cap,

@@ -652,7 +653,6 @@ def decode_attention_fwd(
        b_req_idx,
        b_seq_len,
        attn_logits,
-        max_len_in_batch,
        num_kv_splits,
        sm_scale,
        logit_cap,
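On the kernel side, max_len_in_batch disappears from every decode entry point and b_start_loc from decode_attention_fwd; per the commit title, these arguments were simply unused by the kernels. The following is a hedged usage sketch of the updated decode_attention_fwd call. The trailing argument order follows the updated test file; the three leading arguments (q, k_buffer, v_buffer) are not visible in the hunks above and are named here as an assumption, as are the tensor shapes, the float32 attn_logits dtype, and the extra +1 column (assumed to hold the per-split log-sum-exp). Running it requires a CUDA GPU with sglang and triton installed.

# Hedged usage sketch of decode_attention_fwd after this commit:
# no b_start_loc and no max_len_in_batch in the argument list anymore.
import torch

from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

B, seq_len, H_Q, H_KV, D = 2, 16, 8, 8, 64  # H_Q == H_KV is assumed to take the non-grouped path
num_kv_splits = 8
total_tokens = B * seq_len
sm_scale = 1.0 / (D ** 0.5)
dtype = torch.float16

q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
o = torch.empty(B, H_Q, D, dtype=dtype, device="cuda")

req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")
# Split-KV intermediate buffer; shape and dtype are assumptions, not shown in this diff.
attn_logits = torch.empty(
    (B, H_Q, num_kv_splits, D + 1), dtype=torch.float32, device="cuda"
)

decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_req_idx,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
)
print(o.shape)  # torch.Size([2, 8, 64])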
test/srt/test_triton_attention_kernels.py  (view file @ 61dec545)

@@ -196,7 +196,6 @@ class TestTritonAttention(unittest.TestCase):
        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
        b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
        b_seq_len = torch.full((B,), seq_len, device="cuda")
        attn_logits = torch.empty(

@@ -212,10 +211,8 @@ class TestTritonAttention(unittest.TestCase):
            o,
            req_to_token,
            b_req_idx,
-            b_start_loc,
            b_seq_len,
            attn_logits,
-            seq_len,
            num_kv_splits,
            sm_scale,
        )

@@ -255,7 +252,6 @@ class TestTritonAttention(unittest.TestCase):
        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
        b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
        b_seq_len = torch.full((B,), seq_len, device="cuda")
        attn_logits = torch.empty(

@@ -273,7 +269,6 @@ class TestTritonAttention(unittest.TestCase):
            b_req_idx,
            b_seq_len,
            attn_logits,
-            seq_len,
            num_kv_splits,
            sm_scale,
        )

@@ -293,7 +288,6 @@ class TestTritonAttention(unittest.TestCase):
            b_req_idx,
            b_seq_len,
            attn_logits1,
-            seq_len,
            num_kv_splits,
            sm_scale,
        )
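The tests drop their b_start_loc tensors and the seq_len argument for the same reason: the kernels never read them. If downstream code still needs per-request start offsets, they can be rebuilt on demand from the sequence lengths; the short sketch below mirrors the start_loc computation deleted from triton_backend.py and is not project code.

# Recomputing start offsets from sequence lengths (mirrors the deleted start_loc lines).
import torch

seq_lens = torch.tensor([5, 3, 7])
start_loc = torch.zeros_like(seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
print(start_loc)  # tensor([0, 5, 8], dtype=torch.int32)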