sglang / Commits / c0982ac5

Fix Llava model (#594)

Unverified commit c0982ac5, authored Jul 06, 2024 by Mingyi, committed by GitHub on Jul 06, 2024.
Parent: dc1b8bcf

Showing 5 changed files with 18 additions and 13 deletions.
- python/sglang/backend/runtime_endpoint.py (+3, -4)
- python/sglang/lang/interpreter.py (+1, -1)
- python/sglang/srt/managers/controller/infer_batch.py (+2, -2)
- python/sglang/srt/model_config.py (+6, -0)
- python/sglang/srt/models/gemma2.py (+6, -6)
python/sglang/backend/runtime_endpoint.py

```diff
-import json
-from typing import Callable, List, Optional, Union
+from typing import List, Optional

 import numpy as np
 import requests

 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglArgument, SglSamplingParams
-from sglang.utils import encode_image_base64, find_printable_text, http_request
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import find_printable_text, http_request


 class RuntimeEndpoint(BaseBackend):
     ...
```
python/sglang/lang/interpreter.py

```diff
@@ -523,9 +523,9 @@ class StreamExecutor:
             self, sampling_params=sampling_params
         )

+        self.variables[name] = ""
         self.stream_var_event[name].set()

-        self.variables[name] = ""
         for comp, meta_info in generator:
             self.text_ += comp
             self.variables[name] += comp
```
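The reordering above reads as a publish-before-signal fix: `stream_var_event[name]` tells waiting readers that the streaming variable is ready, so `self.variables[name]` must be initialized before the event is set, or a reader can wake up and hit a missing key. Below is a minimal sketch of that hazard using a plain `threading.Event`; the names are illustrative and this is not sglang's actual executor machinery.

```python
import threading

variables = {}
ready = threading.Event()

def reader(name):
    ready.wait()            # wakes as soon as .set() is called
    print(variables[name])  # KeyError if .set() ran before the assignment

t = threading.Thread(target=reader, args=("answer",))
t.start()

# Fixed ordering, mirroring the hunk: publish the value, then signal.
variables["answer"] = ""
ready.set()
t.join()
```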
python/sglang/srt/managers/controller/infer_batch.py

```diff
@@ -3,7 +3,7 @@
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import List, Union

 import numpy as np
 import torch
@@ -31,7 +31,7 @@ class BaseFinishReason:
 class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: int | List[int]):
+    def __init__(self, matched: Union[int, List[int]]):
         super().__init__()
         self.matched = matched
```
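The `int | List[int]` spelling is PEP 604 syntax: the annotation is evaluated when the function is defined, and `type.__or__` only exists on Python 3.10+. Assuming the module does not defer annotation evaluation with `from __future__ import annotations`, the old signature raises `TypeError` at import time on Python 3.8/3.9, while `Union[int, List[int]]` behaves identically on every supported version. A quick illustration:

```python
from typing import List, Union

# Works on Python 3.8+: typing.Union does not rely on type.__or__.
def new_style(matched: Union[int, List[int]]) -> None: ...

# On Python 3.9 and earlier, the line below raises TypeError the moment the
# function is defined, because `int | List[int]` is evaluated eagerly:
# def old_style(matched: int | List[int]) -> None: ...
```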
python/sglang/srt/model_config.py

```diff
@@ -115,6 +115,12 @@ def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
+    class_name = config.architectures[0]
+    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
+        # We support non-hf version of llava models, so we do not want to
+        # read the wrong values from the unused default text_config.
+        return config
+
     if hasattr(config, "text_config"):
         # The code operates under the assumption that text_config should have
         # `num_attention_heads` (among others). Assert here to fail early
```
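Per the in-diff comment, non-HF Llava checkpoints keep their real model dimensions on the top-level config, so descending into an unused default `text_config` would return wrong values (e.g. `num_attention_heads`). The guard matches names such as `LlavaLlamaForCausalLM` but not the HF-native `LlavaForConditionalGeneration`, which does not end with `ForCausalLM`. A small illustration of the name check (this helper is hypothetical, not sglang code):

```python
def is_non_hf_llava(architecture: str) -> bool:
    # Matches non-HF names like "LlavaLlamaForCausalLM" or
    # "LlavaMistralForCausalLM", but not the HF-native
    # "LlavaForConditionalGeneration".
    return architecture.startswith("Llava") and architecture.endswith("ForCausalLM")

assert is_non_hf_llava("LlavaLlamaForCausalLM")
assert not is_non_hf_llava("LlavaForConditionalGeneration")
```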
python/sglang/srt/models/gemma2.py

```diff
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
-from transformers import Gemma2Config
+from transformers import PretrainedConfig

 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -131,7 +131,7 @@ class Gemma2Attention(nn.Module):
     def __init__(
         self,
         layer_idx: int,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -222,7 +222,7 @@ class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
         layer_idx: int,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):
     def __init__(
         self,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -369,7 +369,7 @@ class Gemma2ForCausalLM(nn.Module):
     def __init__(
         self,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
```
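Widening the annotations from `Gemma2Config` to `PretrainedConfig` also drops the module-level `from transformers import Gemma2Config`, presumably to avoid a hard dependency on a transformers release recent enough to ship `Gemma2Config`. The layers only read attributes off the config object, so the base-class annotation is sufficient. A sketch of that duck-typing assumption (not sglang code; the attribute access relies on `PretrainedConfig` storing unrecognized keyword arguments as attributes):

```python
from transformers import PretrainedConfig

def head_count(config: PretrainedConfig) -> int:
    # Any config exposing the expected attribute works, regardless of its
    # concrete class, so a plain PretrainedConfig stands in for Gemma2Config.
    return config.num_attention_heads

cfg = PretrainedConfig(num_attention_heads=16)
assert head_count(cfg) == 16
```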