change / sglang / Commits / 3a6e8b6d

Unverified commit 3a6e8b6d, authored Sep 10, 2024 by Lianmin Zheng, committed via GitHub on Sep 10, 2024.

[Minor] move triton attention kernels into a separate folder (#1379)
Parent: fbb4754c

Showing 13 changed files with 24 additions and 15 deletions (+24 -15):
- python/sglang/bench_latency.py (+1 -1)
- python/sglang/srt/configs/model_config.py (+0 -0)
- python/sglang/srt/constrained/__init__.py (+2 -0)
- python/sglang/srt/conversation.py (+1 -1)
- python/sglang/srt/hf_transformers_utils.py (+1 -3)
- python/sglang/srt/layers/radix_attention.py (+11 -4)
- python/sglang/srt/layers/triton_attention/decode_attention.py (+0 -0)
- python/sglang/srt/layers/triton_attention/extend_attention.py (+1 -1)
- python/sglang/srt/layers/triton_attention/prefill_attention.py (+0 -0)
- python/sglang/srt/managers/tp_worker.py (+1 -1)
- python/sglang/srt/model_executor/forward_batch_info.py (+1 -1)
- python/sglang/srt/model_executor/model_runner.py (+1 -2)
- scripts/deprecated/test_flashinfer.py (+4 -1)
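Because three kernel modules move into the new triton_attention subpackage, downstream code that imported them directly needs updated paths. A minimal before/after sketch, with module and function names taken from the diffs below:

```python
# Before this commit, the Triton kernels lived directly under sglang.srt.layers:
# from sglang.srt.layers.decode_attention import decode_attention_fwd
# from sglang.srt.layers.extend_attention import extend_attention_fwd
# from sglang.srt.layers.prefill_attention import context_attention_fwd

# After this commit, they live in the triton_attention subpackage:
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
```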
python/sglang/bench_latency.py

```diff
@@ -57,9 +57,9 @@ import pandas as pd
 import torch
 import torch.distributed as dist
 
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
-from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
```
python/sglang/srt/model_config.py → python/sglang/srt/configs/model_config.py

File moved without changes.
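Since model_config.py moves into the new sglang/srt/configs/ subpackage, imports of ModelConfig (and AttentionArch, as the model_runner.py diff below shows) change in the same way. A minimal sketch:

```python
# Before: ModelConfig was imported from the srt package root module
# from sglang.srt.model_config import AttentionArch, ModelConfig

# After: it is imported from the new configs subpackage
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
```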
python/sglang/srt/constrained/__init__.py

```diff
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""For constrained decoding."""
+
 import json
 from typing import Dict, Optional, Union
```
python/sglang/srt/conversation.py

```diff
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Conversation templates."""
+"""Conversation chat templates."""
 
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
```
python/sglang/srt/hf_transformers_utils.py

```diff
@@ -16,11 +16,9 @@ limitations under the License.
 """Utilities for Huggingface Transformers."""
 
 import contextlib
-import functools
 import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
+from typing import Dict, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
```
python/sglang/srt/layers/radix_attention.py

```diff
@@ -22,13 +22,20 @@ from flashinfer.cascade import merge_state
 from torch import nn
 
 from sglang.global_config import global_config
-from sglang.srt.layers.decode_attention import decode_attention_fwd
-from sglang.srt.layers.extend_attention import extend_attention_fwd
+from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
+from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict
 
 
 class RadixAttention(nn.Module):
+    """
+    The attention layer implementation.
+    Now it has two backends: FlashInfer and Triton.
+    FlashInfer is faster and Triton is easier to customize.
+    It supports two operators: extend (i.e. prefill with cached prefix) and decode.
+    """
+
     def __init__(
         self,
         num_heads: int,
@@ -49,8 +56,10 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         self.sliding_window_size = sliding_window_size if sliding_window_size else -1
 
+        # Choose backend
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
             and self.qk_head_dim == self.v_head_dim
@@ -61,8 +70,6 @@ class RadixAttention(nn.Module):
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
-
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         if self.qk_head_dim != self.v_head_dim:
             o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
```
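The new docstring and the moved logit_cap line summarize how RadixAttention picks its backend: the constructor binds extend_forward/decode_forward to either the FlashInfer or the Triton implementation once, and normalizes logit_cap up front rather than only in the Triton branch. A minimal, self-contained sketch of that pattern (the class name and the FlashInfer method names here are illustrative, not copied from the file; only the Triton method names appear in the diff):

```python
class BackendDispatch:
    """Sketch of the pattern used by RadixAttention: choose a backend once
    in __init__, then call self.extend_forward / self.decode_forward
    unconditionally on the hot path."""

    def __init__(self, use_flashinfer: bool, logit_cap=None):
        # Normalize logit_cap the way the diff does:
        # None or non-positive values mean "no cap" (0).
        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

        # Bind the chosen backend's implementations once, so later calls
        # need no branching.
        if use_flashinfer:
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton

    # Stub bodies stand in for the real kernels.
    def extend_forward_flashinfer(self, q, k, v):
        return "flashinfer-extend"

    def decode_forward_flashinfer(self, q, k, v):
        return "flashinfer-decode"

    def extend_forward_triton(self, q, k, v):
        return "triton-extend"

    def decode_forward_triton(self, q, k, v):
        return "triton-decode"


layer = BackendDispatch(use_flashinfer=False, logit_cap=-1.0)
assert layer.logit_cap == 0
assert layer.decode_forward(None, None, None) == "triton-decode"
```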
python/sglang/srt/layers/decode_attention.py → python/sglang/srt/layers/triton_attention/decode_attention.py

File moved without changes.
python/sglang/srt/layers/extend_attention.py → python/sglang/srt/layers/triton_attention/extend_attention.py

```diff
@@ -22,7 +22,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.prefill_attention import context_attention_fwd
+from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
```
python/sglang/srt/layers/prefill_attention.py → python/sglang/srt/layers/triton_attention/prefill_attention.py

File moved without changes.
python/sglang/srt/managers/tp_worker.py

```diff
@@ -29,6 +29,7 @@ import torch.distributed
 import torch.distributed as dist
 
 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -52,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
-from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
```
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -15,7 +15,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-"""ModelRunner runs the forward passes of the models."""
+"""Meta data for a forward pass."""
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List
```
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -18,7 +18,6 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
-import json
 import logging
 import pkgutil
 from functools import lru_cache
@@ -45,6 +44,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
@@ -53,7 +53,6 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
```
scripts/deprecated/test_flashinfer.py

```diff
@@ -6,8 +6,11 @@ from flashinfer import (
 )
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
-from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.layers.triton_attention.extend_attention import (
+    extend_attention_fwd,
+    redundant_attention,
+)
 
 flashinfer_prefill_wrapper = None
 flashinfer_decode_wrapper = None
```