Commit fb1f28cb (unverified)
Clean up the comments and names under python/sglang/srt/layers (#1047)
Authored Aug 11, 2024 by Lianmin Zheng; committed by GitHub on Aug 12, 2024
Parent: fb7421db

Showing 9 changed files with 26 additions and 1633 deletions (+26, -1633):
  python/sglang/srt/layers/activation.py               +2    -0
  python/sglang/srt/layers/decode_attention.py         +9    -5
  python/sglang/srt/layers/extend_attention.py         +6    -1
  python/sglang/srt/layers/layernorm.py                +2    -0
  python/sglang/srt/layers/linear.py                   +0  -884
  python/sglang/srt/layers/prefill_attention.py        +5    -0
  python/sglang/srt/layers/quantization/__init__.py    +0   -64
  python/sglang/srt/layers/quantization/fp8.py         +0  -677
  python/sglang/srt/layers/radix_attention.py          +2    -2
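Two of the modules are renamed rather than edited, so downstream imports change with them. A minimal sketch of the before/after import paths, taken from the extend_attention.py and radix_attention.py hunks below:

# Before this commit:
# from sglang.srt.layers.token_attention import token_attention_fwd
# from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd

# After this commit:
from sglang.srt.layers.decode_attention import decode_attention_fwd
from sglang.srt.layers.prefill_attention import context_attention_fwd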
python/sglang/srt/layers/activation.py

@@ -11,6 +11,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""Fused operators for activation layers."""
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
...
python/sglang/srt/layers/token_attention.py → python/sglang/srt/layers/decode_attention.py

@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""
+Memory-efficient attention for decoding.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
...
@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)


-def _token_att_m_fwd(
+def _decode_att_m_fwd(
     q,
     k_buffer,
     att_out,
...
@@ -254,7 +258,7 @@ def _token_att_m_fwd(
     )


-def _token_softmax_reducev_fwd(
+def _decode_softmax_reducev_fwd(
     logics,
     v_buffer,
     o,
...
@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
     )


-def token_attention_fwd(
+def decode_attention_fwd(
     q,
     k_buffer,
     v_buffer,
...
@@ -312,7 +316,7 @@ def token_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )

-    _token_att_m_fwd(
+    _decode_att_m_fwd(
         q,
         k_buffer,
         att_m,
...
@@ -324,7 +328,7 @@ def token_attention_fwd(
         sm_scale,
         logit_cap,
     )
-    _token_softmax_reducev_fwd(
+    _decode_softmax_reducev_fwd(
         att_m,
         v_buffer,
         o,
...
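The Triton kernels themselves are unchanged; only the Python-level names move from the token_* prefix to decode_*. A small reference mapping for callers updating to this commit (a convenience summary, not code from the repository):

# Old name (pre-commit)        -> new name (this commit)
RENAMED_IN_DECODE_ATTENTION = {
    "_token_att_m_fwd": "_decode_att_m_fwd",
    "_token_softmax_reducev_fwd": "_decode_softmax_reducev_fwd",
    "token_attention_fwd": "decode_attention_fwd",
}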
python/sglang/srt/layers/extend_attention.py

@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""
+Memory-efficient attention for prefill.
+It supports page size = 1 and prefill with KV cache (i.e. extend).
+"""
+
 import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.layers.prefill_attention import context_attention_fwd

 CUDA_CAPABILITY = torch.cuda.get_device_capability()
...
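CUDA_CAPABILITY is read once at import time; a common pattern (and presumably its purpose here) is to gate Triton tile sizes on the GPU architecture. A hypothetical sketch of that pattern; the threshold and block sizes below are illustrative assumptions, not values taken from this file:

import torch

CUDA_CAPABILITY = torch.cuda.get_device_capability()  # e.g. (8, 0) on an A100

# Hypothetical gate: pick larger tiles on Ampere (SM 8.x) and newer GPUs.
if CUDA_CAPABILITY[0] >= 8:
    BLOCK_M, BLOCK_N = 128, 128  # assumed values, for illustration only
else:
    BLOCK_M, BLOCK_N = 64, 64    # assumed values, for illustration only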
python/sglang/srt/layers/layernorm.py

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""Fused operators for normalization layers."""
+
 from typing import Optional, Tuple, Union

 import torch
...
python/sglang/srt/layers/linear.py (deleted, 100644 → 0)

(Diff collapsed: the whole 884-line file is removed.)
python/sglang/srt/layers/context_flashattention_nopad.py → python/sglang/srt/layers/prefill_attention.py

@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""
+Memory-efficient attention for prefill.
+It supports page size = 1.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
 import torch
...
python/sglang/srt/layers/quantization/__init__.py (deleted, 100644 → 0)

Deleted file contents:

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# temporarily adapted from vLLM
# FIXME: in progress of refactoring the model loader

from typing import Dict, Type

from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsConfig,
)
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

from sglang.srt.layers.quantization.fp8 import Fp8Config

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "fp8": Fp8Config,
    # The order of gptq methods is important for config.py iteration over
    # override_quantization_method(..)
    "marlin": MarlinConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "gptq_marlin": GPTQMarlinConfig,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "bitsandbytes": BitsAndBytesConfig,
}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return QUANTIZATION_METHODS[quantization]


__all__ = [
    "QuantizationConfig",
    "get_quantization_config",
    "QUANTIZATION_METHODS",
]
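For context, the deleted module was a thin registry mapping method names to vLLM config classes. A minimal usage sketch of the removed helper, as it behaved before this commit:

from sglang.srt.layers.quantization import get_quantization_config

# Look up the config class registered for a quantization method name.
fp8_config_cls = get_quantization_config("fp8")  # -> Fp8Config
awq_config_cls = get_quantization_config("awq")  # -> AWQConfig (re-exported from vLLM)

# Unknown names raise ValueError, as defined in the file above.
try:
    get_quantization_config("int3")
except ValueError as err:
    print(err)  # Invalid quantization method: int3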
python/sglang/srt/layers/quantization/fp8.py (deleted, 100644 → 0)

(Diff collapsed: the whole 677-line file is removed.)
python/sglang/srt/layers/radix_attention.py

@@ -20,8 +20,8 @@ from flashinfer.cascade import merge_state
 from torch import nn

 from sglang.global_config import global_config
+from sglang.srt.layers.decode_attention import decode_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict
...
@@ -95,7 +95,7 @@ class RadixAttention(nn.Module):
         o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

-        token_attention_fwd(
+        decode_attention_fwd(
             q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
...
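With these renames, the attention kernel modules mirror the forward paths that RadixAttention dispatches between. A rough summary sketch; the mode-to-kernel pairing is inferred from the module docstrings and imports above, not spelled out in the diff:

# Inferred pairing of forward paths to Triton kernel entry points after this commit.
FORWARD_PATH_TO_KERNEL = {
    "prefill (no KV cache)": "sglang.srt.layers.prefill_attention.context_attention_fwd",
    "extend (prefill with KV cache)": "sglang.srt.layers.extend_attention.extend_attention_fwd",
    "decode": "sglang.srt.layers.decode_attention.decode_attention_fwd",
}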