sglang · commit 62c516ac (unverified)

Authored Dec 01, 2024 by Qun Yang; committed by GitHub on Dec 01, 2024

Add a simple torch native attention backend (#2241)

Parent: fc78640e

Showing 7 changed files with 388 additions and 26 deletions (+388 / -26)
Files changed:

  python/sglang/srt/layers/attention/torch_native_backend.py   +285   -0
  python/sglang/srt/managers/schedule_batch.py                   +18  -14
  python/sglang/srt/model_executor/forward_batch_info.py          +9   -4
  python/sglang/srt/model_executor/model_runner.py                +3   -0
  python/sglang/srt/server_args.py                                +14   -8
  test/srt/run_suite.py                                            +1   -0
  test/srt/test_torch_native_attention_backend.py                 +58   -0
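As a quick way to exercise the new backend end to end, the sketch below reuses the test helpers that the commit's own test file imports (popen_launch_server, kill_process_tree, and the DEFAULT_* constants). It is an illustration, not part of the commit; note that the server_args change below makes the server disable CUDA graph automatically for this backend, so no extra flag is needed.

# Illustrative only: launch a server with the new backend using sglang's test helpers.
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--attention-backend", "torch_native"],  # choice added by this commit
)
# ... send requests against DEFAULT_URL_FOR_TEST ...
kill_process_tree(process.pid)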
python/sglang/srt/layers/attention/torch_native_backend.py (new file, mode 0 → 100644)
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch
from torch.nn.functional import scaled_dot_product_attention

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner

class TorchNativeAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.forward_metadata = None
        self.device = model_runner.device

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass."""
        pass

    def init_cuda_graph_state(self, max_bs: int):
        # TODO: Support CUDA graph
        raise ValueError(
            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
        )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor] = None,
    ):
        # TODO: Support CUDA graph
        raise ValueError(
            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor] = None,
    ):
        # TODO: Support CUDA graph
        raise ValueError(
            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
        )

    def get_cuda_graph_seq_len_fill_value(self):
        # TODO: Support CUDA graph
        raise ValueError(
            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
        )
    def _run_sdpa_forward_extend(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        extend_prefix_lens: torch.Tensor,
        extend_seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Run the extend forward by using torch native sdpa op.

        Args:
            query: [num_tokens, num_heads, head_size]
            output: [num_tokens, num_heads, head_size]
            k_cache: [max_total_num_tokens, num_heads, head_size]
            v_cache: [max_total_num_tokens, num_heads, head_size]
            req_to_token: [max_num_reqs, max_context_len]
            req_pool_indices: [num_seqs]
            seq_lens: [num_seqs]
            extend_prefix_lens: [num_seqs]
            extend_seq_lens: [num_seqs]
            scaling: float or None
            enable_gqa: bool
            causal: bool

        Returns:
            output: [num_tokens, num_heads, head_size]
        """
        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
        assert seq_lens.shape[0] == extend_seq_lens.shape[0]

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            # TODO: this loop process a sequence per iter, this is inefficient.
            # Need optimize the performance later.
            extend_seq_len_q = extend_seq_lens[seq_idx]
            prefill_seq_len_q = extend_prefix_lens[seq_idx]

            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + extend_seq_len_q
            end_kv = start_kv + seq_len_kv

            per_req_query = query[:, start_q:end_q, :]
            per_req_query_redudant = torch.empty(
                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
                dtype=per_req_query.dtype,
                device=per_req_query.device,
            )

            per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query

            # get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]

            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            per_req_out_redudant = (
                scaled_dot_product_attention(
                    per_req_query_redudant.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
            start_q, start_kv = end_q, end_kv

        return output
    def _run_sdpa_forward_decode(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Run the decode forward by using torch native sdpa op.

        Args:
            query: [num_tokens, num_heads, head_size]
            output: [num_tokens, num_heads, head_size]
            k_cache: [max_total_num_tokens, num_heads, head_size]
            v_cache: [max_total_num_tokens, num_heads, head_size]
            req_to_token: [max_num_reqs, max_context_len]
            req_pool_indices: [num_seqs]
            seq_lens: [num_seqs]
            scaling: float or None
            enable_gqa: bool
            causal: bool

        Returns:
            output: [num_tokens, num_heads, head_size]
        """

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            # TODO: this loop process a sequence per iter, this is inefficient.
            # Need optimize the performance later.
            seq_len_q = 1
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv

            per_req_query = query[:, start_q:end_q, :]

            # get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]

            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            per_req_out = (
                scaled_dot_product_attention(
                    per_req_query.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out
            start_q, start_kv = end_q, end_kv

        return output
    def forward_extend(
        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
    ):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer, forward_batch.out_cache_loc, k, v
        )

        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

        self._run_sdpa_forward_extend(
            q_,
            o_,
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_prefix_lens,
            forward_batch.extend_seq_lens,
            scaling=layer.scaling,
            enable_gqa=use_gqa,
            causal=not layer.is_cross_attention,
        )
        return o
    def forward_decode(
        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
    ):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer, forward_batch.out_cache_loc, k, v
        )

        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num

        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

        self._run_sdpa_forward_decode(
            q_,
            o_,
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            scaling=layer.scaling,
            enable_gqa=use_gqa,
            causal=False,
        )
        return o
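The core pattern of this backend is: for each request, gather its K/V rows from the paged cache through req_to_token, then call torch's scaled_dot_product_attention on that request alone. The standalone sketch below is not part of the commit; all shapes and values are made up, and req_pool_indices is taken to be the identity. It mirrors the decode path (_run_sdpa_forward_decode) on toy tensors.

# Illustrative only: per-request gather from a paged KV cache, then torch SDPA.
import torch
from torch.nn.functional import scaled_dot_product_attention

num_heads, head_size = 4, 8
max_total_num_tokens, max_context_len = 64, 16

# Paged KV cache: one row per cached token, addressed indirectly via req_to_token.
k_cache = torch.randn(max_total_num_tokens, num_heads, head_size)
v_cache = torch.randn(max_total_num_tokens, num_heads, head_size)

# Two requests with 5 and 3 cached tokens; req_to_token maps (request, position) -> cache slot.
seq_lens = torch.tensor([5, 3])
req_to_token = torch.zeros(2, max_context_len, dtype=torch.long)
req_to_token[0, :5] = torch.arange(0, 5)
req_to_token[1, :3] = torch.arange(10, 13)

query = torch.randn(2, num_heads, head_size)  # one new token per request (decode step)
output = torch.empty_like(query)

q = query.movedim(0, query.dim() - 2)  # -> [num_heads, num_tokens, head_size]
for i in range(seq_lens.shape[0]):
    per_req_tokens = req_to_token[i, : seq_lens[i]]
    k = k_cache[per_req_tokens].movedim(0, 1)  # -> [num_heads, seq_len_kv, head_size]
    v = v_cache[per_req_tokens].movedim(0, 1)
    out = scaled_dot_product_attention(
        q[:, i : i + 1, :].unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), is_causal=False
    ).squeeze(0)
    output[i : i + 1, :, :] = out.movedim(1, 0)

print(output.shape)  # torch.Size([2, 4, 8])

One caveat: the forward_extend/forward_decode calls above pass the enable_gqa keyword to scaled_dot_product_attention, which appears to require a recent PyTorch release (2.5 or later).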

python/sglang/srt/managers/schedule_batch.py
@@ -743,20 +743,24 @@ class ScheduleBatch:
         extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        write_req_to_token_pool_triton[(bs,)](
-            self.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            pre_lens,
-            self.seq_lens,
-            extend_lens,
-            self.out_cache_loc,
-            self.req_to_token_pool.req_to_token.shape[1],
-        )
-        # The triton kernel is equivalent to the following python code.
-        # self.req_to_token_pool.write(
-        #     (req.req_pool_idx, slice(pre_len, seq_len)),
-        #     out_cache_loc[pt : pt + req.extend_input_len],
-        # )
+        if global_server_args_dict["attention_backend"] != "torch_native":
+            write_req_to_token_pool_triton[(bs,)](
+                self.req_to_token_pool.req_to_token,
+                self.req_pool_indices,
+                pre_lens,
+                self.seq_lens,
+                extend_lens,
+                self.out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(bs):
+                self.req_to_token_pool.write(
+                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
+                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
+                )
+                pt += self.extend_lens[i]

         # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
         if self.model_config.is_encoder_decoder:
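For reference, the pure-Python fallback above simply writes each request's newly allocated kv-cache slots into its row of req_to_token. A toy illustration (not from the commit; values invented, plain tensors standing in for req_to_token_pool):

# Illustrative only: what the per-request fallback write does to the req_to_token table.
import torch

max_num_reqs, max_context_len = 4, 8
req_to_token = torch.zeros(max_num_reqs, max_context_len, dtype=torch.long)

req_pool_indices = torch.tensor([2, 0])                  # pool rows used by this batch
pre_lens = torch.tensor([3, 0])                          # tokens already cached per request
seq_lens = torch.tensor([5, 4])                          # total tokens per request afterwards
extend_lens = [2, 4]                                     # newly added tokens per request
out_cache_loc = torch.tensor([17, 18, 30, 31, 32, 33])   # newly allocated kv-cache slots

pt = 0
for i in range(len(extend_lens)):
    # Equivalent to self.req_to_token_pool.write((idx, slice(pre_len, seq_len)), locs)
    req_to_token[req_pool_indices[i], pre_lens[i] : seq_lens[i]] = out_cache_loc[
        pt : pt + extend_lens[i]
    ]
    pt += extend_lens[i]

print(req_to_token[2])  # tensor([ 0,  0,  0, 17, 18,  0,  0,  0])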

python/sglang/srt/model_executor/forward_batch_info.py
@@ -256,10 +256,15 @@ class ForwardBatch:
         ret.extend_prefix_lens = torch.tensor(
             batch.extend_prefix_lens, dtype=torch.int32
         ).to(device, non_blocking=True)
-        ret.extend_num_tokens = batch.extend_num_tokens
-        ret.positions, ret.extend_start_loc = compute_position_triton(
-            ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
-        )
+        if model_runner.server_args.attention_backend != "torch_native":
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
+        else:
+            ret.positions, ret.extend_start_loc = compute_position_torch(
+                ret.extend_prefix_lens, ret.extend_seq_lens
+            )
         ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
         ret.extend_seq_lens_cpu = batch.extend_seq_lens
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
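The diff calls compute_position_torch but does not show its body. As a rough mental model only (a hypothetical sketch, not the actual implementation): in extend mode, positions are the absolute indices of the newly added tokens, and extend_start_loc is the running offset of each request within the flattened batch.

# Hypothetical illustration; compute_position_torch_sketch is an invented name.
import torch

def compute_position_torch_sketch(extend_prefix_lens, extend_seq_lens):
    # Positions of the new tokens: prefix_len, prefix_len + 1, ... per request.
    positions = torch.cat(
        [
            torch.arange(prefix, prefix + extend)
            for prefix, extend in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )
    # Start offset of each request's tokens in the flattened batch (exclusive cumsum).
    extend_start_loc = torch.cumsum(extend_seq_lens, dim=0) - extend_seq_lens
    return positions, extend_start_loc

positions, start_loc = compute_position_torch_sketch(
    torch.tensor([4, 0]), torch.tensor([2, 3])
)
print(positions)  # tensor([4, 5, 0, 1, 2])
print(start_loc)  # tensor([0, 2])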

python/sglang/srt/model_executor/model_runner.py
@@ -40,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler

@@ -570,6 +571,8 @@ class ModelRunner:
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
                 self.attn_backend = TritonAttnBackend(self)
+        elif self.server_args.attention_backend == "torch_native":
+            self.attn_backend = TorchNativeAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"

python/sglang/srt/server_args.py
@@ -180,15 +180,21 @@ class ServerArgs:
         else:
             self.cuda_graph_max_bs = 160

-        # Set kernel backends
-        if not is_flashinfer_available():
-            self.attention_backend = "triton"
-            self.sampling_backend = "pytorch"
-
+        # Choose kernel backends
         if self.attention_backend is None:
-            self.attention_backend = "flashinfer"
+            self.attention_backend = (
+                "flashinfer" if is_flashinfer_available() else "triton"
+            )
         if self.sampling_backend is None:
-            self.sampling_backend = "flashinfer"
+            self.sampling_backend = (
+                "flashinfer" if is_flashinfer_available() else "pytorch"
+            )
+
+        if self.attention_backend == "torch_native":
+            logger.info(
+                "Cuda graph is disabled because of using torch native attention backend"
+            )
+            self.disable_cuda_graph = True

         # Others
         if self.enable_dp_attention:

@@ -586,7 +592,7 @@
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton"],
+            choices=["flashinfer", "triton", "torch_native"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
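With the extra choice registered above, the new value surfaces on the command line of python -m sglang.launch_server. A minimal argparse reproduction of just this one option (an illustration, not sglang's full parser; the flag name and choices come from the diff itself):

# Illustrative only: how the --attention-backend choice behaves after this change.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention-backend",
    type=str,
    choices=["flashinfer", "triton", "torch_native"],
    default=None,
    help="Choose the kernels for attention layers.",
)
args = parser.parse_args(["--attention-backend", "torch_native"])
print(args.attention_backend)  # torch_native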

test/srt/run_suite.py
@@ -34,6 +34,7 @@ suites = {
         "test_srt_endpoint.py",
         "test_torch_compile.py",
         "test_torch_compile_moe.py",
+        "test_torch_native_attention_backend.py",
         "test_torchao.py",
         "test_triton_attention_kernels.py",
         "test_triton_attention_backend.py",

test/srt/test_torch_native_attention_backend.py (new file, mode 0 → 100644)
"""
Usage:
python3 -m unittest test_triton_attention_backend.TestTritonAttnBackend.test_mmlu
"""

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    is_in_ci,
    popen_launch_server,
    run_bench_one_batch,
)


class TestTorchNativeAttnBackend(unittest.TestCase):
    def test_latency(self):
        output_throughput = run_bench_one_batch(
            DEFAULT_MODEL_NAME_FOR_TEST,
            ["--attention-backend", "torch_native"],
        )

        if is_in_ci():
            # Torch native backend is expected to be slower
            assert output_throughput > 50, f"{output_throughput=}"

    def test_mmlu(self):
        model = DEFAULT_MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--attention-backend", "torch_native"],
        )

        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
            )

            metrics = run_eval(args)
            self.assertGreaterEqual(metrics["score"], 0.65)
        finally:
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()