sglang · Commit c0982ac5 (unverified)

Fix Llava model (#594)

Authored Jul 06, 2024 by Mingyi; committed by GitHub on Jul 06, 2024
Parent: dc1b8bcf

Showing 5 changed files with 18 additions and 13 deletions (+18, -13):
python/sglang/backend/runtime_endpoint.py             +3 -4
python/sglang/lang/interpreter.py                     +1 -1
python/sglang/srt/managers/controller/infer_batch.py  +2 -2
python/sglang/srt/model_config.py                     +6 -0
python/sglang/srt/models/gemma2.py                    +6 -6
python/sglang/backend/runtime_endpoint.py (view file @ c0982ac5)

```diff
 import json
-from typing import Callable, List, Optional, Union
+from typing import List, Optional

 import numpy as np
-import requests

 from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglArgument, SglSamplingParams
-from sglang.utils import encode_image_base64, find_printable_text, http_request
+from sglang.lang.ir import SglSamplingParams
+from sglang.utils import find_printable_text, http_request


 class RuntimeEndpoint(BaseBackend):
     ...
```
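This change is pure import pruning: Callable, Union, requests, SglArgument, and encode_image_base64 all appear to be names the module no longer references. For context, a minimal, self-contained sketch of the kind of check that catches such leftovers (a rough heuristic that ignores scoping, re-exports, and `__all__`; `unused_imports` is a hypothetical helper, not part of sglang):

```python
import ast
from typing import List


def unused_imports(source: str) -> List[str]:
    """Report top-level imported names that never appear again in `source`."""
    tree = ast.parse(source)
    imported = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # `import a.b` binds `a`; `import x as y` binds `y`
                imported.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                imported.add(alias.asname or alias.name)
    used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
    return sorted(imported - used)


sample = "import json\nfrom typing import Callable, List\nx: List[int] = json.loads('[1]')\n"
print(unused_imports(sample))  # -> ['Callable']
```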
python/sglang/lang/interpreter.py (view file @ c0982ac5)

```diff
@@ -523,9 +523,9 @@ class StreamExecutor:
                 self, sampling_params=sampling_params
             )
-            self.variables[name] = ""
             self.stream_var_event[name].set()
+            self.variables[name] = ""
             for comp, meta_info in generator:
                 self.text_ += comp
                 self.variables[name] += comp
```
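The only change here moves `self.variables[name] = ""` below the `stream_var_event[name].set()` call. The commit message does not spell out the rationale, but the pattern at play is a producer signalling a waiting consumer through a `threading.Event`, where the order of state writes relative to `set()` determines what the consumer can observe. A generic sketch of that pattern, with names mirroring the diff but none of sglang's actual consumer logic:

```python
import threading

variables = {}
stream_var_event = {"answer": threading.Event()}


def producer():
    stream_var_event["answer"].set()  # signal: streaming for this var began
    variables["answer"] = ""          # then initialize the accumulator
    for chunk in ("Hello", ", ", "world"):
        variables["answer"] += chunk


def consumer():
    stream_var_event["answer"].wait()   # block until the producer signals
    print(variables.get("answer", ""))  # .get(): the key may not exist yet


t = threading.Thread(target=producer)
t.start()
consumer()
t.join()
```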
python/sglang/srt/managers/controller/infer_batch.py (view file @ c0982ac5)

```diff
@@ -3,7 +3,7 @@
 import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import List, Union

 import numpy as np
 import torch
@@ -31,7 +31,7 @@ class BaseFinishReason:
 class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: int | List[int]):
+    def __init__(self, matched: Union[int, List[int]]):
         super().__init__()
         self.matched = matched
```
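The likely motivation here is Python version compatibility: the PEP 604 form `int | List[int]` is evaluated when the `def` statement executes, and the `|` operator on types only exists on Python >= 3.10, so the old annotation raises TypeError at import time on 3.8/3.9. `Union` (or a `from __future__ import annotations` at the top of the file) works everywhere. A quick illustration:

```python
from typing import List, Union


def ok(matched: Union[int, List[int]]) -> None:  # fine on Python 3.8+
    ...

# def broken(matched: int | List[int]) -> None:  # TypeError on Python < 3.10
#     ...
```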
python/sglang/srt/model_config.py (view file @ c0982ac5)

```diff
@@ -115,6 +115,12 @@ def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
     """
+    class_name = config.architectures[0]
+    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
+        # We support non-hf version of llava models, so we do not want to
+        # read the wrong values from the unused default text_config.
+        return config
+
     if hasattr(config, "text_config"):
         # The code operates under the assumption that text_config should have
         # `num_attention_heads` (among others). Assert here to fail early
```
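The new guard means that for sglang's non-HF Llava classes (architectures named `Llava...ForCausalLM`), the text-model hyperparameters are read from the top-level config rather than from a `text_config` attribute that only holds unused defaults. A minimal sketch of that dispatch; `DummyLlavaConfig` is hypothetical, and the branches after the guard paraphrase the truncated part of the diff:

```python
class DummyLlavaConfig:
    # Hypothetical stand-in exposing only the attributes the guard inspects.
    architectures = ["LlavaLlamaForCausalLM"]
    text_config = object()  # present, but holds stale defaults


def get_hf_text_config(config):
    class_name = config.architectures[0]
    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
        return config              # non-HF Llava: top-level config is the LLM config
    if hasattr(config, "text_config"):
        return config.text_config  # HF multimodal: use the nested text config
    return config                  # pure text model: no-op


cfg = DummyLlavaConfig()
assert get_hf_text_config(cfg) is cfg  # text_config is deliberately bypassed
```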
python/sglang/srt/models/gemma2.py (view file @ c0982ac5)

```diff
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union

 import torch
 from torch import nn
-from transformers import Gemma2Config
+from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -131,7 +131,7 @@ class Gemma2Attention(nn.Module):
     def __init__(
         self,
         layer_idx: int,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -222,7 +222,7 @@ class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
         layer_idx: int,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):
     def __init__(
         self,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -369,7 +369,7 @@ class Gemma2ForCausalLM(nn.Module):
     def __init__(
         self,
-        config: Gemma2Config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
```
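Swapping the `Gemma2Config` annotations for `PretrainedConfig` plausibly removes a hard dependency on a transformers release new enough to ship `Gemma2Config`: the annotations only need attribute access, and every HF config subclasses `PretrainedConfig`. A small sketch of that duck-typing, assuming the config exposes the standard `hidden_size` and `num_attention_heads` fields (Gemma-2 configs do):

```python
from transformers import PretrainedConfig


def head_dim(config: PretrainedConfig) -> int:
    # Any HF config subclass works; the base-class annotation imposes no
    # requirement on which model-specific config classes exist.
    return config.hidden_size // config.num_attention_heads


# PretrainedConfig stores unknown kwargs as attributes, handy for a demo:
print(head_dim(PretrainedConfig(hidden_size=2048, num_attention_heads=16)))  # 128
```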