Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
ee88a7e5
Unverified
Commit
ee88a7e5
authored
Apr 08, 2023
by
Woosuk Kwon
Committed by
GitHub
Apr 08, 2023
Browse files
Add an option to use dummy model weights (#33)
parent
c267b1a0
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
36 additions
and
8 deletions
+36
-8
benchmark/benchmark_latency.py
benchmark/benchmark_latency.py
+1
-0
cacheflow/http_frontend/fastapi_frontend.py
cacheflow/http_frontend/fastapi_frontend.py
+1
-0
cacheflow/master/server.py
cacheflow/master/server.py
+3
-0
cacheflow/models/llama.py
cacheflow/models/llama.py
+4
-0
cacheflow/models/model_utils.py
cacheflow/models/model_utils.py
+17
-6
cacheflow/models/opt.py
cacheflow/models/opt.py
+4
-0
cacheflow/worker/controller.py
cacheflow/worker/controller.py
+2
-0
cacheflow/worker/worker.py
cacheflow/worker/worker.py
+3
-2
simple_server.py
simple_server.py
+1
-0
No files found.
benchmark/benchmark_latency.py
View file @
ee88a7e5
...
...
@@ -29,6 +29,7 @@ def main(args: argparse.Namespace):
server
=
Server
(
model
=
args
.
model
,
model_path
=
args
.
model_path
,
use_dummy_weights
=
args
.
use_dummy_weights
,
pipeline_parallel_size
=
args
.
pipeline_parallel_size
,
tensor_parallel_size
=
args
.
tensor_parallel_size
,
block_size
=
args
.
block_size
,
...
...
cacheflow/http_frontend/fastapi_frontend.py
View file @
ee88a7e5
...
...
@@ -47,6 +47,7 @@ class FastAPIFrontend:
self
.
server
=
remote_server_class
.
remote
(
model
=
model
,
model_path
=
model_path
,
use_dummy_weights
=
False
,
pipeline_parallel_size
=
pipeline_parallel_size
,
tensor_parallel_size
=
tensor_parallel_size
,
block_size
=
block_size
,
...
...
cacheflow/master/server.py
View file @
ee88a7e5
...
...
@@ -16,6 +16,7 @@ class Server:
self
,
model
:
str
,
model_path
:
str
,
use_dummy_weights
:
bool
,
pipeline_parallel_size
:
int
,
tensor_parallel_size
:
int
,
block_size
:
int
,
...
...
@@ -66,6 +67,7 @@ class Server:
dtype
=
dtype
,
seed
=
seed
,
model_path
=
model_path
,
use_dummy_weights
=
use_dummy_weights
,
max_num_batched_tokens
=
max_num_batched_tokens
,
)
self
.
controllers
.
append
(
controller
)
...
...
@@ -179,4 +181,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
0
,
help
=
'random seed'
)
parser
.
add_argument
(
'--swap-space'
,
type
=
int
,
default
=
20
,
help
=
'CPU swap space size (GiB) per GPU'
)
parser
.
add_argument
(
'--max-num-batched-tokens'
,
type
=
int
,
default
=
2560
,
help
=
'maximum number of batched tokens'
)
parser
.
add_argument
(
'--use-dummy-weights'
,
action
=
'store_true'
,
help
=
'use dummy values for model weights'
)
return
parser
cacheflow/models/llama.py
View file @
ee88a7e5
...
...
@@ -286,3 +286,7 @@ class LlamaForCausalLM(nn.Module):
np
.
save
(
f
,
param
.
cpu
().
detach
().
numpy
())
return
path
def
initialize_dummy_weights
(
self
)
->
None
:
for
param
in
self
.
state_dict
().
values
():
param
.
data
.
uniform_
(
-
0.1
,
0.1
)
cacheflow/models/model_utils.py
View file @
ee88a7e5
...
...
@@ -28,18 +28,29 @@ def get_model(
model_name
:
str
,
dtype
:
Union
[
torch
.
dtype
,
str
],
path
:
str
,
use_dummy_weights
:
bool
,
)
->
nn
.
Module
:
torch_dtype
=
get_torch_dtype
(
dtype
)
torch
.
set_default_dtype
(
torch_dtype
)
config
=
AutoConfig
.
from_pretrained
(
model_name
)
for
model_class_name
,
model_class
in
_MODELS
.
items
():
if
model_class_name
in
model_name
:
# Download model weights if it's not cached.
weights_dir
=
model_class
.
get_weights
(
model_name
,
path
=
path
)
# Create a model instance.
model
=
model_class
(
config
)
# Load the weights from the cached or downloaded files.
model
.
load_weights
(
weights_dir
)
if
use_dummy_weights
:
# Create a model instance.
# The weights will be initialized as empty tensors.
model
=
model_class
(
config
)
model
=
model
.
cuda
()
# NOTE(woosuk): For precise performance evaluation, we assign
# random values to the weights.
model
.
initialize_dummy_weights
()
else
:
# Download model weights if it's not cached.
weights_dir
=
model_class
.
get_weights
(
model_name
,
path
=
path
)
# Create a model instance.
model
=
model_class
(
config
)
# Load the weights from the cached or downloaded files.
model
.
load_weights
(
weights_dir
)
model
=
model
.
cuda
()
return
model
.
eval
(),
torch_dtype
raise
ValueError
(
f
'Unsupported model name:
{
model_name
}
'
)
...
...
cacheflow/models/opt.py
View file @
ee88a7e5
...
...
@@ -324,3 +324,7 @@ class OPTForCausalLM(nn.Module):
np
.
save
(
f
,
param
.
cpu
().
detach
().
numpy
())
return
path
def
initialize_dummy_weights
(
self
)
->
None
:
for
param
in
self
.
state_dict
().
values
():
param
.
data
.
uniform_
(
-
0.1
,
0.1
)
cacheflow/worker/controller.py
View file @
ee88a7e5
...
...
@@ -27,6 +27,7 @@ class Controller:
dtype
:
str
,
seed
:
int
,
model_path
:
str
,
use_dummy_weights
:
bool
,
max_num_batched_tokens
:
int
,
)
->
None
:
self
.
stage_id
=
stage_id
...
...
@@ -58,6 +59,7 @@ class Controller:
tensor_parallel_size
=
tensor_parallel_size
,
pipeline_parallel_size
=
pipeline_parallel_size
,
model_path
=
model_path
,
use_dummy_weights
=
use_dummy_weights
,
max_num_batched_tokens
=
max_num_batched_tokens
,
)
self
.
workers
.
append
(
worker
)
...
...
cacheflow/worker/worker.py
View file @
ee88a7e5
...
...
@@ -29,6 +29,7 @@ class Worker:
rank
:
int
,
world_size
:
int
,
model_path
:
str
,
use_dummy_weights
:
bool
,
max_num_batched_tokens
:
int
,
tensor_parallel_size
:
int
=
1
,
pipeline_parallel_size
:
int
=
1
,
...
...
@@ -43,8 +44,8 @@ class Worker:
set_random_seed
(
seed
)
# Initialize the model.
self
.
model
,
self
.
dtype
=
get_model
(
model_name
,
dtype
=
dtype
,
path
=
model_path
)
self
.
model
=
self
.
model
.
cuda
(
)
self
.
model
,
self
.
dtype
=
get_model
(
model_name
,
dtype
=
dtype
,
path
=
model_path
,
use_dummy_weights
=
use_dummy_weights
)
tensor_model_parallel_world_size
=
(
get_tensor_model_parallel_world_size
())
initialize_all_reduce_launcher
(
...
...
simple_server.py
View file @
ee88a7e5
...
...
@@ -22,6 +22,7 @@ def main(args: argparse.Namespace):
server
=
Server
(
model
=
args
.
model
,
model_path
=
args
.
model_path
,
use_dummy_weights
=
args
.
use_dummy_weights
,
pipeline_parallel_size
=
args
.
pipeline_parallel_size
,
tensor_parallel_size
=
args
.
tensor_parallel_size
,
block_size
=
args
.
block_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment