sglang · Commit badf3fa0 (unverified)
Authored Jun 27, 2024 by Lianmin Zheng; committed by GitHub on Jun 27, 2024
Expose dtype argument (#569)
Parent: 945aa9be
Showing 3 changed files, with 39 additions and 21 deletions:

  python/sglang/srt/managers/controller/model_runner.py  +7 -5
  python/sglang/srt/managers/controller/tp_worker.py     +1 -1
  python/sglang/srt/server_args.py                       +31 -15
python/sglang/srt/managers/controller/model_runner.py
@@ -6,7 +6,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type, Any
+from typing import List, Optional, Type
 
 import numpy as np
 import torch
@@ -119,7 +119,7 @@ class InputMetadata:
             head_dim,
             1,
             pos_encoding_mode="NONE",
-            data_type="float16",
+            data_type=self.token_to_kv_pool.kv_data[0].dtype
         )
 
     def init_extend_args(self):
@@ -287,10 +287,11 @@ class ModelRunner:
             tokenizer=None,
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
-            dtype=torch.float16,
+            dtype=self.server_args.dtype,
             seed=42,
             skip_tokenizer_init=True,
         )
+        self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
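Note: the runner previously pinned dtype=torch.float16 when building vLLM's ModelConfig; it now forwards the user-supplied string and caches the concrete torch.dtype that ModelConfig resolves, including the "auto" case. Below is a minimal sketch of the "auto" rule as worded in the new --dtype help text further down; the function and mapping are illustrative, not vLLM's implementation.

import torch

# Illustrative mapping from --dtype strings to torch dtypes (assumption, not vLLM code).
_STR_TO_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float": torch.float32,
    "float32": torch.float32,
}

def resolve_dtype(dtype_arg: str, checkpoint_dtype: torch.dtype) -> torch.dtype:
    """Map the --dtype string to a torch.dtype, honoring the "auto" rule."""
    if dtype_arg == "auto":
        # FP16 for FP32 and FP16 checkpoints, BF16 for BF16 checkpoints.
        return torch.bfloat16 if checkpoint_dtype == torch.bfloat16 else torch.float16
    return _STR_TO_DTYPE[dtype_arg]

assert resolve_dtype("auto", torch.float32) is torch.float16
assert resolve_dtype("auto", torch.bfloat16) is torch.bfloat16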
@@ -307,6 +308,7 @@ class ModelRunner:
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
@@ -316,7 +318,7 @@ class ModelRunner:
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
+        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
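Note: cell_size is the per-token KV-cache footprint. The trailing "* 2" that assumed 2-byte elements becomes torch._utils._element_size(self.dtype) (a private torch helper), so the memory estimate now tracks the configured dtype. A sketch of the same computation using only public API; the shape numbers are made up for illustration:

import torch

def kv_cell_size(head_num: int, head_dim: int, layer_num: int, dtype: torch.dtype) -> int:
    """Bytes per token slot in the KV cache: one K and one V vector per layer."""
    # Public-API stand-in for torch._utils._element_size(dtype): an empty
    # tensor's element_size() gives the byte width of one element.
    element_size = torch.tensor([], dtype=dtype).element_size()
    return head_num * head_dim * layer_num * 2 * element_size

print(kv_cell_size(32, 128, 32, torch.float16))  # 524288 (2-byte elements)
print(kv_cell_size(32, 128, 32, torch.float32))  # 1048576 (4-byte elements)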
@@ -337,7 +339,7 @@ class ModelRunner:
         )
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
-            dtype=torch.float16,
+            dtype=self.dtype,
             head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
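Note: the pool itself now allocates in the resolved dtype rather than hardcoded float16. A standalone illustration of what that changes for the allocation; the shapes are invented and this is not TokenToKVPool's actual layout:

import torch

max_tokens, head_num, head_dim = 4096, 32, 128

# One layer's K and V storage; the dtype now follows the server's --dtype choice.
kv = torch.empty(2, max_tokens, head_num, head_dim, dtype=torch.bfloat16)
print(kv.numel() * kv.element_size() / 2**20, "MiB per layer")  # 64.0 MiB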
python/sglang/srt/managers/controller/tp_worker.py
@@ -120,7 +120,7 @@ class ModelTpServer:
             f"[gpu_id={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"context_len={self.model_config.context_len}, "
+            f"context_len={self.model_config.context_len}"
         )
         if self.tp_rank == 0:
             logger.info(
python/sglang/srt/server_args.py
@@ -11,12 +11,13 @@ class ServerArgs:
     # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    load_format: str = "auto"
     tokenizer_mode: str = "auto"
-    chat_template: Optional[str] = None
+    load_format: str = "auto"
+    dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
+    chat_template: Optional[str] = None
 
     # Port
     host: str = "127.0.0.1"
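Note: in the dataclass, dtype slots in next to load_format with an "auto" default, and chat_template moves down to match the new CLI ordering. A quick, hypothetical use of the updated dataclass; the model path is a placeholder and the remaining fields keep their defaults:

from sglang.srt.server_args import ServerArgs

# "auto" defers the final torch dtype to the checkpoint's config at load time.
args = ServerArgs(model_path="my-org/my-model")
print(args.dtype)  # "auto"

# Or pin it explicitly:
args = ServerArgs(model_path="my-org/my-model", dtype="bfloat16")
print(args.dtype)  # "bfloat16"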
@@ -107,6 +108,15 @@ class ServerArgs:
             default=[],
             help="The additional ports specified for the server.",
         )
+        parser.add_argument(
+            "--tokenizer-mode",
+            type=str,
+            default=ServerArgs.tokenizer_mode,
+            choices=["auto", "slow"],
+            help="Tokenizer mode. 'auto' will use the fast "
+            "tokenizer if available, and 'slow' will "
+            "always use the slow tokenizer.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -124,20 +134,20 @@ class ServerArgs:
             "which is mainly for profiling.",
         )
         parser.add_argument(
-            "--tokenizer-mode",
+            "--dtype",
             type=str,
-            default=ServerArgs.tokenizer_mode,
-            choices=["auto", "slow"],
-            help="Tokenizer mode. 'auto' will use the fast "
-            "tokenizer if available, and 'slow' will "
-            "always use the slow tokenizer.",
-        )
-        parser.add_argument(
-            "--chat-template",
-            type=str,
-            default=ServerArgs.chat_template,
-            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server",
-        )
+            default=ServerArgs.dtype,
+            choices=[
+                "auto", "half", "float16", "bfloat16", "float", "float32"],
+            help='Data type for model weights and activations.\n\n'
+            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+            'BF16 precision for BF16 models.\n'
+            '* "half" for FP16. Recommended for AWQ quantization.\n'
+            '* "float16" is the same as "half".\n'
+            '* "bfloat16" for a balance between precision and range.\n'
+            '* "float" is shorthand for FP32 precision.\n'
+            '* "float32" for FP32 precision.'
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -155,6 +165,12 @@ class ServerArgs:
             default=ServerArgs.quantization,
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--chat-template",
+            type=str,
+            default=ServerArgs.chat_template,
+            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
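Note: with the flag registered, the dtype can be chosen at launch time. A hypothetical round trip through the parser, assuming these parser.add_argument calls live in ServerArgs.add_cli_args as elsewhere in this file; the model path is a placeholder:

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)  # registers --dtype alongside the existing flags

args = parser.parse_args(["--model-path", "my-org/my-model", "--dtype", "bfloat16"])
print(args.dtype)  # "bfloat16"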