sglang / Commits / 37963394

Commit 37963394 (unverified), authored Sep 15, 2024 by Ying Sheng, committed by GitHub on Sep 15, 2024
Parent commit: 899cf5c4

[Feature] Support LoRA path renaming and add LoRA serving benchmarks (#1433)

Showing 6 changed files with 594 additions and 62 deletions (+594, -62)
  benchmark/lora/launch_server.py          +53   -0
  benchmark/lora/lora_bench.py             +485  -0
  examples/runtime/lora.py                 +37   -0
  python/sglang/srt/lora/lora_manager.py   +6    -6
  python/sglang/srt/server_args.py         +13   -1
  scripts/playground/lora/test_lora.py     +0    -55
benchmark/lora/launch_server.py (new file, mode 100644)
import argparse
import os

NUM_LORAS = 128
LORA_PATH = {
    "base": "mistralai/Mistral-7B-Instruct-v0.3",
    "lora": "/home/ying/test_lora",
}


def launch_server(args):
    base_path = LORA_PATH["base"]
    lora_path = LORA_PATH["lora"]
    max_loras_per_batch = 4
    if args.base_only:
        cmd = f"python -m sglang.launch_server --model {base_path} "
    else:
        cmd = f"python -m sglang.launch_server --model {base_path} --lora-paths "
        for i in range(NUM_LORAS):
            lora_name = f"lora{i}"
            cmd += f"{lora_name}={lora_path} "
    cmd += f"--disable-radix --disable-cuda-graph "
    cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
    cmd += f"--max-running-requests {args.max_running_requests} "
    print(cmd)
    os.system(cmd)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-loras",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--base-only",
        action="store_true",
    )
    parser.add_argument(
        "--max-loras-per-batch",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--max-running-requests",
        type=int,
        default=8,
    )
    args = parser.parse_args()
    launch_server(args)
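Note: the loop above registers the same adapter checkpoint under 128 distinct names, which is exactly what the name=path renaming syntax introduced by this commit allows. A minimal sketch (not part of the commit) of the --lora-paths value the launcher builds when --base-only is not set; the checkpoint path is the script's hard-coded placeholder:

# Sketch only: reproduce the name=path list assembled by launch_server() above.
NUM_LORAS = 128
lora_ckpt = "/home/ying/test_lora"  # placeholder path taken from LORA_PATH above
lora_args = " ".join(f"lora{i}={lora_ckpt}" for i in range(NUM_LORAS))
print(lora_args[:60] + " ...")  # lora0=/home/ying/test_lora lora1=/home/ying/test_lora ...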
benchmark/lora/lora_bench.py (new file, mode 100644)
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import
argparse
import
asyncio
import
json
import
os
import
random
import
resource
import
sys
import
time
import
traceback
import
warnings
from
argparse
import
ArgumentParser
from
dataclasses
import
dataclass
,
field
from
datetime
import
datetime
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
aiohttp
import
numpy
as
np
import
requests
from
launch_server
import
LORA_PATH
,
NUM_LORAS
from
tqdm.asyncio
import
tqdm
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
PreTrainedTokenizerBase
,
PreTrainedTokenizerFast
,
)
from
sglang.bench_serving
import
(
AIOHTTP_TIMEOUT
,
SHAREGPT_URL
,
BenchmarkMetrics
,
RequestFuncInput
,
RequestFuncOutput
,
calculate_metrics
,
check_chat_template
,
get_model
,
get_request
,
get_tokenizer
,
parse_request_rate_range
,
remove_prefix
,
sample_random_requests
,
)
global args


# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    # assert api_url.endswith(
    #     "completions"
    # ), "OpenAI Completions API URL must end with 'completions'."

    prompt = request_func_input.prompt

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        # payload = {
        #     "model": request_func_input.model,
        #     "prompt": prompt,
        #     "temperature": 0.0,
        #     "best_of": 1,
        #     "max_tokens": request_func_input.output_len,
        #     "stream": not args.disable_stream,
        #     "ignore_eos": not args.disable_ignore_eos,
        #     **request_func_input.extra_request_body,
        # }
        # headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
        if args.base_only:
            payload = {
                "text": prompt,
                "sampling_params": {"max_new_tokens": request_func_input.output_len},
            }
        else:
            payload = {
                "text": prompt,
                "sampling_params": {"max_new_tokens": request_func_input.output_len},
                "lora_path": f"lora{random.randint(0, NUM_LORAS - 1)}",
            }
        headers = {"Authorization": ""}

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["text"]:
                                # if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                # generated_text += data["choices"][0]["text"]
                                generated_text += data["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = request_func_input.output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_openai_completions,
}
async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
    extra_request_body: Dict[str, Any],
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        extra_request_body=extra_request_body,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            extra_request_body=extra_request_body,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Backend:", backend))
    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10}".format(
            "Total generated tokens (retokenized):", metrics.total_output_retokenized
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print(
        "{s:{c}^{n}}".format(
            s="Time per Output Token (excl. 1st token)", n=50, c="-"
        )
    )
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    if (
        metrics.median_ttft_ms is not None
        and metrics.mean_itl_ms is not None
        and metrics.output_throughput is not None
    ):
        result = {
            "backend": args.backend,
            "request_rate": request_rate,
            "total_input_tokens": metrics.total_input,
            "total_output_tokens": metrics.total_output,
            "total_output_tokens_retokenized": metrics.total_output_retokenized,
            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
            "median_ttft_ms": metrics.median_ttft_ms,
            "median_itl_ms": metrics.median_itl_ms,
            "output_throughput": metrics.output_throughput,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            "duration": benchmark_duration,
            "completed": metrics.completed,
        }
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
        print("-" * 30)

    # Determine output file name
    if args.output_file:
        output_file_name = args.output_file
    else:
        now = datetime.now().strftime("%m%d")
        output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"

    # Append results to a JSONL file
    with open(output_file_name, "a") as file:
        file.write(json.dumps(result) + "\n")

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "total_output_tokens_retokenized": metrics.total_output_retokenized,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
    }
    return result
def run_benchmark(args_: argparse.Namespace):
    global args
    args = args_

    # Set global environments
    set_ulimit()
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Set url
    if args.port is None:
        args.port = {
            "sglang": 30000,
        }.get(args.backend, 30000)

    # api_url = (
    #     f"{args.base_url}/v1/completions"
    #     if args.base_url
    #     else f"http://{args.host}:{args.port}/v1/completions"
    # )
    api_url = (
        f"{args.base_url}/generate"
        if args.base_url
        else f"http://{args.host}:{args.port}/generate"
    )

    print(f"{args}\n")

    # Read dataset
    backend = args.backend
    model_id = args.model = LORA_PATH["base"]
    tokenizer_id = args.model
    tokenizer = get_tokenizer(tokenizer_id)
    input_requests = sample_random_requests(
        input_len=args.random_input_len,
        output_len=args.random_output_len,
        num_prompts=args.num_prompts,
        range_ratio=args.random_range_ratio,
        tokenizer=tokenizer,
        dataset_path="",
    )

    return asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=args.request_rate,
            disable_tqdm=False,
            extra_request_body={},
        )
    )
def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Failed to set RLIMIT_NOFILE: {e}")
if __name__ == "__main__":
    parser = ArgumentParser(description="Benchmark the online lora serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        default="sglang",
        help="Must specify a backend, depending on the LLM Inference Engine.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=50,
        help="Number of prompts to process. Default is 50.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help="Number of output tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument(
        "--base-only",
        action="store_true",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    args = parser.parse_args()
    run_benchmark(args)
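For reference, one way to drive this benchmark without the CLI is to build the argparse.Namespace by hand and call run_benchmark. This is only a sketch: it assumes the server from launch_server.py is already listening on the default port 30000, that benchmark/lora/ is the working directory so the file imports as lora_bench, and that the Mistral tokenizer is available locally.

import argparse

from lora_bench import run_benchmark  # assumes benchmark/lora/ is the working directory

# Same defaults as the CLI flags defined above.
bench_args = argparse.Namespace(
    backend="sglang", base_url=None, host="0.0.0.0", port=None,
    num_prompts=50, random_input_len=1024, random_output_len=128,
    random_range_ratio=0.0, request_rate=float("inf"),
    base_only=False, output_file=None, seed=1,
)
result = run_benchmark(bench_args)
print(result["request_throughput"], result["median_ttft_ms"])

This is equivalent to running `python lora_bench.py` with no arguments.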
examples/runtime/lora.py (new file, mode 100644)
# launch server
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora /home/ying/test_lora_1 /home/ying/test_lora_2 lora3=/home/ying/test_lora_3 lora4=/home/ying/test_lora_4 --disable-radix --disable-cuda-graph --max-loras-per-batch 4

# send requests
# lora_path[i] specifies the LoRA adapter used for text[i], so make sure they have the same length
# use None to specify a base-only prompt, e.g. "lora_path": [None, "/home/ying/test_lora"]
import json

import requests

url = "http://127.0.0.1:30000"
json_data = {
    "text": [
        "prompt 1",
        "prompt 2",
        "prompt 3",
        "prompt 4",
        "prompt 5",
        "prompt 6",
        "prompt 7",
    ],
    "sampling_params": {"max_new_tokens": 32},
    "lora_path": [
        "/home/ying/test_lora",
        "/home/ying/test_lora_1",
        "/home/ying/test_lora_2",
        "lora3",
        "lora4",
        "/home/ying/test_lora",
        "/home/ying/test_lora_1",
    ],
}
response = requests.post(
    url + "/generate",
    json=json_data,
)
print(json.dumps(response.json()))
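A small variation, sketched here for completeness (the prompt strings are placeholders): passing None in lora_path routes that prompt to the base model, while a renamed adapter such as lora3 is referenced by its name from the launch command above.

import requests

url = "http://127.0.0.1:30000"
mixed = {
    "text": ["a base-model prompt", "an adapter prompt"],
    "sampling_params": {"max_new_tokens": 32},
    "lora_path": [None, "lora3"],  # None -> base model, "lora3" -> adapter renamed at launch
}
print(requests.post(url + "/generate", json=mixed).json())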
python/sglang/srt/lora/lora_manager.py

@@ -96,10 +96,10 @@ class LoRAManager:
         # get configs and target modules
         self.configs = {}
         self.origin_target_modules = set()
-        for path in self.lora_paths:
-            self.configs[path] = LoRAConfig(path)
+        for name, path in self.lora_paths.items():
+            self.configs[name] = LoRAConfig(path)
             self.origin_target_modules = set(self.origin_target_modules) | set(
-                self.configs[path].target_modules
+                self.configs[name].target_modules
             )
         self.target_modules = set(
             [
@@ -114,11 +114,11 @@ class LoRAManager:
         # load all weights to cpu
         self.loras = []
         self.lora_id = {}
-        for path in self.lora_paths:
-            self.lora_id[path] = len(self.loras)
+        for name in self.lora_paths.keys():
+            self.lora_id[name] = len(self.loras)
             self.loras.append(
                 LoRAAdapter(
-                    path, self.configs[path], self.base_hf_config, self.load_config
+                    name, self.configs[name], self.base_hf_config, self.load_config
                 )
             )
             self.loras[-1].initialize_weights()
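The net effect of the two hunks above is that lora_paths is now a dict mapping an adapter name to its checkpoint path, so per-adapter bookkeeping is keyed by name rather than by path, and several names can share one checkpoint. A toy sketch of that keying change (plain dicts standing in for the real LoRAManager and LoRAConfig):

# Hypothetical stand-ins, illustrating only the name-keyed bookkeeping.
lora_paths = {
    "lora0": "/home/ying/test_lora",
    "lora1": "/home/ying/test_lora",  # two names may point at the same checkpoint
}
configs = {name: {"path": path} for name, path in lora_paths.items()}
lora_id = {name: idx for idx, name in enumerate(lora_paths.keys())}
print(lora_id)  # {'lora0': 0, 'lora1': 1}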
python/sglang/srt/server_args.py

@@ -24,6 +24,17 @@ from typing import List, Optional, Union
 logger = logging.getLogger(__name__)


+class LoRAPathAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, {})
+        for lora_path in values:
+            if "=" in lora_path:
+                name, path = lora_path.split("=", 1)
+                getattr(namespace, self.dest)[name] = path
+            else:
+                getattr(namespace, self.dest)[lora_path] = lora_path
+
+
 @dataclasses.dataclass
 class ServerArgs:
     # Model and tokenizer
@@ -532,7 +543,8 @@ class ServerArgs:
             type=str,
             nargs="*",
             default=None,
-            help="The list of LoRA adapters.",
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
         )
         parser.add_argument(
             "--max-loras-per-batch",
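To see what the new --lora-paths syntax accepts, here is a standalone sketch that reuses the parsing logic from the hunk above (outside of sglang): plain paths are registered under their own path as the name, while name=path entries are split on the first "=".

import argparse


class LoRAPathAction(argparse.Action):
    # Same parsing logic as the action added to server_args.py above.
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path


parser = argparse.ArgumentParser()
parser.add_argument("--lora-paths", type=str, nargs="*", default=None, action=LoRAPathAction)
args = parser.parse_args(["--lora-paths", "/home/ying/test_lora", "lora3=/home/ying/test_lora_3"])
print(args.lora_paths)
# {'/home/ying/test_lora': '/home/ying/test_lora', 'lora3': '/home/ying/test_lora_3'}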
scripts/playground/lora/test_lora.py (deleted, mode 100644 → 0)
import json

import openai
import requests
import sglang as sgl

lora_path = "/home/ying/test_lora"
prompt_file = "/home/ying/test_prompt/dialogue_choice_prompts.json"
server_url = "http://127.0.0.1:30000"

client = openai.Client(base_url=server_url + "/v1", api_key="EMPTY")

# @sgl.function
# def generate(s, prompt):
#     s += prompt
#     s += sgl.gen("ans")

# sgl.set_default_backend(sgl.RuntimeEndpoint(server_url))


def generate(prompt, lora_path):
    json_data = {
        "text": prompt,
        "sampling_params": {},
        "return_logprob": False,
        "logprob_start_len": None,
        "top_logprobs_num": None,
        "lora_path": lora_path,
    }
    response = requests.post(
        server_url + "/generate",
        json=json_data,
    )
    return json.dumps(response.json())


with open(prompt_file, "r") as f:
    samples = json.load(f)

for sample in samples[:1]:
    assert sample[0]["role"] == "user"
    prompt = sample[0]["content"]
    assert sample[1]["role"] == "assistant"
    ref = sample[1]["content"]
    state = generate(prompt, lora_path)
    print("================================")
    print(ref)
    print("--------------------------------")
    # print(state["ans"])
    print(state)
    print()