Commit caaad53b (unverified) · sglang

Support gpt-bigcode model class (#681)

Authored Jul 20, 2024 by Liangsheng Yin, committed via GitHub on Jul 20, 2024.
Parent: 69d19188

Showing 6 changed files with 341 additions and 12 deletions (+341, -12).
python/sglang/srt/layers/logits_processor.py                  +4    -5
python/sglang/srt/managers/controller/cuda_graph_runner.py    +29   -6
python/sglang/srt/managers/controller/model_runner.py         +3    -1
python/sglang/srt/models/gpt_bigcode.py                       +282  -0
python/sglang/srt/server.py                                   +17   -0
python/sglang/srt/server_args.py                              +6    -0
python/sglang/srt/layers/logits_processor.py
@@ -34,12 +34,11 @@ class LogitProcessorOutput:
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
 
     # For logprobs
     return_logprob: bool
-    top_logprobs_nums: List[int]
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    top_logprobs_nums: List[int] = None
 
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
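The fields above gain `= None` defaults so a `LogitsMetadata` can be built for decode-only forwards, which the torch.compile path below exercises. A minimal sketch (illustrative only, not part of the commit; `ForwardMode.DECODE` is assumed from sglang's infer_batch module):

# With the new defaults, a decode-mode metadata object no longer needs the
# extend_* tensors or top_logprobs_nums.
decode_metadata = LogitsMetadata(
    forward_mode=ForwardMode.DECODE,
    return_logprob=False,
)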
python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -6,6 +6,7 @@ import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
+from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
@@ -16,8 +17,28 @@ from sglang.srt.managers.controller.infer_batch import (
 )
 
 
+def _to_torch(model: torch.nn.Module, reverse=False):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub._forward_method = sub.forward_cuda
+            else:
+                sub._forward_method = sub.forward_native
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse)
+
+
+def get_forward(model: torch.nn.Module, use_torch: bool):
+    if use_torch:
+        _to_torch(model, reverse=False)
+        return torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+    else:
+        _to_torch(model, reverse=True)
+        return model.forward
+
+
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture):
+    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
@@ -55,6 +76,8 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
+
     def can_run(self, batch_size):
         return batch_size < self.max_bs
@@ -63,18 +86,19 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
+                forward = get_forward(self.model_runner.model, bs in self.compile_bs)
                 (
                     graph,
                     input_buffers,
                     output_buffers,
                     flashinfer_handler,
-                ) = self.capture_one_batch_size(bs)
+                ) = self.capture_one_batch_size(bs, forward)
                 self.graphs[bs] = graph
                 self.input_buffers[bs] = input_buffers
                 self.output_buffers[bs] = output_buffers
                 self.flashinfer_handlers[bs] = flashinfer_handler
 
-    def capture_one_batch_size(self, bs):
+    def capture_one_batch_size(self, bs, forward):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
@@ -127,9 +151,8 @@ class CudaGraphRunner:
                 skip_flashinfer_init=True,
             )
             input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
 
-            return self.model_runner.model.forward(
-                input_ids, input_metadata.positions, input_metadata
-            )
+            return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
             run_once()
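To illustrate the idea behind get_forward and capture_one_batch_size, here is a self-contained sketch of my own (not from the diff): compile a stand-in module with mode="max-autotune-no-cudagraphs", warm it up twice as run_once() is warmed above, then capture the CUDA graph manually. It assumes a CUDA device and uses a plain nn.Linear in place of the real model.

import torch

model = torch.nn.Linear(16, 16).cuda()
compiled_forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")

static_x = torch.zeros(8, 16, device="cuda")
# Warm up twice so compilation/autotuning happens outside the capture,
# mirroring the "for _ in range(2): run_once()" loop above.
for _ in range(2):
    compiled_forward(static_x)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = compiled_forward(static_x)

# Replay with new data written into the static input buffer.
static_x.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_out.shape)  # torch.Size([8, 16])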
python/sglang/srt/managers/controller/model_runner.py
@@ -244,7 +244,9 @@ class ModelRunner:
         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
-            self, max_batch_size_to_capture=max(batch_size_list)
+            self,
+            max_batch_size_to_capture=max(batch_size_list),
+            use_torch_compile=self.server_args.enable_torch_compile,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
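For reference, the captured batch sizes expand as follows (simple arithmetic on the list comprehension above):

batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
# == [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128]
# while compile_bs in CudaGraphRunner only covers [1, 2, 4, 8, 16, 24, 32] when
# --enable-torch-compile is set, so larger batches keep the non-compiled path.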
python/sglang/srt/models/gpt_bigcode.py
new file (mode 100644)
# Adapted from:
# https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import GPTBigCodeConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.infer_batch import InputMetadata


class GPTBigCodeAttention(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % self.tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // self.tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        self.scale = self.head_dim**-0.5

        self.multi_query = config.multi_query
        if self.multi_query:
            total_num_kv_heads = 1
            self.num_kv_heads = 1
        else:
            total_num_kv_heads = total_num_heads
            self.num_kv_heads = self.num_heads
        self.kv_dim = self.head_dim * self.num_kv_heads
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            scaling=self.scale,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split(
            [
                self.hidden_size // self.tensor_model_parallel_world_size,
                self.kv_dim,
                self.kv_dim,
            ],
            dim=-1,
        )
        attn_output = self.attn(q, k, v, input_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class GPTBigMLP(nn.Module):
    def __init__(
        self,
        intermediate_size: int,
        config: GPTBigCodeConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.activation_function, quant_config, intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class GPTBigCodeBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPTBigMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class GPTBigCodeModel(nn.Module):
    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention

        self.embed_dim = config.hidden_size
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.wte = VocabParallelEmbedding(
            self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
        )
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.h = nn.ModuleList(
            [
                GPTBigCodeBlock(i, config, cache_config, quant_config)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(hidden_states, input_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTBigCodeForCausalLM(nn.Module):
    packed_modules_mapping = {"c_attn": ["c_attn"]}

    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]

    embedding_modules = {
        "wte": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    embedding_padding_modules = []

    def __init__(
        self,
        config: GPTBigCodeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.transformer = GPTBigCodeModel(config, cache_config, quant_config, lora_config)
        self.lm_head = self.transformer.wte
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, input_metadata)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
                weight_loader(param, loaded_weight, "q")
                weight_loader(param, loaded_weight, "k")
                weight_loader(param, loaded_weight, "v")
            else:
                weight_loader(param, loaded_weight)


EntryClass = GPTBigCodeForCausalLM
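A worked example of the multi-query split in GPTBigCodeAttention.forward, using StarCoder-like numbers (illustrative, not taken from the diff): with multi_query=True there is a single shared KV head, so the fused c_attn output splits into one large query block and two head_dim-sized key/value blocks.

# Assumed StarCoder-style dimensions: hidden_size=6144, num_attention_heads=48,
# multi_query=True, tensor-parallel world size 1.
hidden_size, total_num_heads, tp_size = 6144, 48, 1
head_dim = hidden_size // total_num_heads      # 128
num_kv_heads = 1                               # multi-query attention
kv_dim = head_dim * num_kv_heads               # 128
split_sizes = [hidden_size // tp_size, kv_dim, kv_dim]
print(split_sizes)  # [6144, 128, 128] -> q, k, v widths passed to qkv.split(...)

Checkpoints whose HuggingFace architecture is GPTBigCodeForCausalLM (for example, bigcode/starcoder or bigcode/santacoder) should resolve to this EntryClass when loaded by the sglang model runner.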
python/sglang/srt/server.py
@@ -157,6 +157,19 @@ def _set_global_server_args(server_args: ServerArgs):
     }
 
 
+def _set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 128
+
+
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
@@ -190,6 +203,10 @@ def launch_server(
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
 
+    if server_args.enable_torch_compile:
+        _set_torch_compile_config()
+
     _set_global_server_args(server_args)
 
     # Allocate ports
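A quick sanity check of the new hook (a sketch of my own, assuming sglang.srt.server imports cleanly in your environment): the inductor and dynamo knobs set by _set_torch_compile_config are plain module-level attributes and can be read back directly.

import torch._dynamo.config
import torch._inductor.config

from sglang.srt.server import _set_torch_compile_config  # module path as of this commit

_set_torch_compile_config()
assert torch._inductor.config.coordinate_descent_tuning is True
assert torch._inductor.config.triton.unique_kernel_names is True
assert torch._inductor.config.fx_graph_cache is True
assert torch._dynamo.config.accumulated_cache_size_limit == 128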
python/sglang/srt/server_args.py
@@ -55,6 +55,7 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
+    enable_torch_compile: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
     efficient_weight_load: bool = False
@@ -317,6 +318,11 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--enable-torch-compile",
+            action="store_true",
+            help="Optimize the model with torch.compile, experimental feature.",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
             action="store_true",
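To show how the new dataclass field and the --enable-torch-compile flag wire together, a short sketch (illustrative; add_cli_args/from_cli_args are assumed to be the existing ServerArgs helpers this argparse block belongs to, and --model-path is assumed to be the only other required argument):

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--model-path", "bigcode/santacoder", "--enable-torch-compile"]
)
server_args = ServerArgs.from_cli_args(args)
assert server_args.enable_torch_compile is True  # defaults to False otherwise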