vllm, commit 189ae231 (unverified)
Authored May 04, 2023 by Woosuk Kwon; committed via GitHub on May 04, 2023.
Use dtype from model config & Add Dolly V2 (#63)
Parent: e548c148
Showing 2 changed files with 33 additions and 7 deletions:
- cacheflow/master/server.py (+5, -1)
- cacheflow/models/model_utils.py (+28, -6)
cacheflow/master/server.py

@@ -214,7 +214,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true',
                         help='use dummy values for model weights')
     # NOTE(woosuk): FlashAttention does not support float32.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
+    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+                        help=('data type for model weights and activations. '
+                              'The "default" option will use FP16 precision '
+                              'for FP32 and FP16 models, and BF16 precision '
+                              'for BF16 models.'))
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true',
                         help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1,
                         help='number of pipeline stages')
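To see what the new flag does in isolation, here is a minimal, runnable sketch that reproduces just the added argument; the rest of add_server_arguments is elided:

import argparse

# Reproduce only the new --dtype argument; everything else in
# add_server_arguments is omitted for brevity.
parser = argparse.ArgumentParser()
parser.add_argument('--dtype', type=str, default='default',
                    choices=['default', 'half', 'bfloat16'],
                    help='data type for model weights and activations')

print(parser.parse_args([]).dtype)                       # 'default'
print(parser.parse_args(['--dtype', 'bfloat16']).dtype)  # 'bfloat16'
# parser.parse_args(['--dtype', 'float32']) exits with an argparse error:
# 'float32' is not a valid choice (FlashAttention does not support it).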
cacheflow/models/model_utils.py

-from typing import Union, Optional
+from typing import Optional

 import torch
 import torch.nn as nn
 from transformers import AutoConfig
+from transformers import PretrainedConfig

 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
 from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
@@ -22,6 +23,7 @@ _MODELS = {
     'opt': OPTForCausalLM,
     'stablelm': GPTNeoXForCausalLM,
     'pythia': GPTNeoXForCausalLM,
+    'dolly-v2': GPTNeoXForCausalLM,
 }

 _MEMORY_ANALYZERS = {
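The new 'dolly-v2' entry takes effect because get_model and get_memory_analyzer (below) resolve a model by checking whether a registry key occurs as a substring of the requested model name. A minimal sketch of that lookup, using strings in place of the real model classes and a Hugging Face-style id as the example input:

# Stand-in registry: strings replace the real model classes.
_MODELS = {
    'opt': 'OPTForCausalLM',
    'stablelm': 'GPTNeoXForCausalLM',
    'pythia': 'GPTNeoXForCausalLM',
    'dolly-v2': 'GPTNeoXForCausalLM',
}

def resolve(model_name: str) -> str:
    # Same substring match as in get_model below.
    for key, model_class in _MODELS.items():
        if key in model_name:
            return model_class
    raise ValueError(f'Unsupported model name: {model_name}')

print(resolve('databricks/dolly-v2-12b'))  # GPTNeoXForCausalLM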
@@ -30,19 +32,38 @@ _MEMORY_ANALYZERS = {
     'opt': OPTMemoryAnalyzer,
     'stablelm': GPTNeoXMemoryAnalyzer,
     'pythia': GPTNeoXMemoryAnalyzer,
+    'dolly-v2': GPTNeoXMemoryAnalyzer,
 }

+
+def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
+    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    if dtype == 'default':
+        if config_dtype == torch.float32:
+            # Following the common practice, we use float16 for float32 models.
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = config_dtype
+    else:
+        torch_dtype = get_torch_dtype(dtype)
+        if torch_dtype != config_dtype and config_dtype != torch.float32:
+            # TODO(woosuk): Allow using float16 for bfloat16 models and
+            # vice versa. Print a warning message and continue.
+            raise ValueError(
+                f'Cannot use {torch_dtype} for {config_dtype} model.')
+    return torch_dtype
+
+
 def get_model(
     model_name: str,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     cache_dir: Optional[str],
     use_dummy_weights: bool,
     use_np_cache: bool,
 ) -> nn.Module:
-    torch_dtype = get_torch_dtype(dtype)
-    torch.set_default_dtype(torch_dtype)
     config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
+    torch.set_default_dtype(torch_dtype)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             if use_dummy_weights:
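The resolution rules in _get_dtype can be exercised without loading a model, since only config.torch_dtype matters. A self-contained sketch, with a stand-in for cacheflow's get_torch_dtype (assumed to map 'half' to torch.float16 and 'bfloat16' to torch.bfloat16) and SimpleNamespace objects in place of real configs:

import torch
from types import SimpleNamespace

def get_torch_dtype(dtype: str) -> torch.dtype:
    # Stand-in for cacheflow's get_torch_dtype (assumed mapping).
    return {'half': torch.float16, 'bfloat16': torch.bfloat16}[dtype]

def _get_dtype(config, dtype: str) -> torch.dtype:
    config_dtype = getattr(config, 'torch_dtype', torch.float32)
    if dtype == 'default':
        # FP32 checkpoints are served in FP16; otherwise keep the config dtype.
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    torch_dtype = get_torch_dtype(dtype)
    if torch_dtype != config_dtype and config_dtype != torch.float32:
        raise ValueError(f'Cannot use {torch_dtype} for {config_dtype} model.')
    return torch_dtype

fp32_cfg = SimpleNamespace(torch_dtype=torch.float32)   # checkpoint saved in FP32
bf16_cfg = SimpleNamespace(torch_dtype=torch.bfloat16)  # checkpoint saved in BF16

print(_get_dtype(fp32_cfg, 'default'))   # torch.float16
print(_get_dtype(bf16_cfg, 'default'))   # torch.bfloat16
print(_get_dtype(fp32_cfg, 'bfloat16'))  # torch.bfloat16 (FP32 may be downcast)
# _get_dtype(bf16_cfg, 'half') raises ValueError: an explicit dtype that
# conflicts with a non-FP32 checkpoint is rejected for now (see the TODO).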
@@ -66,12 +87,13 @@ def get_model(
 def get_memory_analyzer(
     model_name: str,
     block_size: int,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     gpu_memory: int,
     cpu_memory: int,
     tensor_parallel_size: int = 1,
 ) -> CacheFlowMemoryAnalyzer:
-    torch_dtype = get_torch_dtype(dtype)
     config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
         if model_class in model_name:
             return memory_analyzer(
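With both call sites switched over, callers now pass the raw CLI string through unchanged. A hypothetical call site, with the model id and arguments chosen purely for illustration:

# Hypothetical usage: dtype is now the CLI string ('default', 'half',
# or 'bfloat16'), no longer a torch.dtype.
model = get_model(
    model_name='databricks/dolly-v2-12b',  # resolves via the 'dolly-v2' key
    dtype='default',   # FP16 for FP32/FP16 checkpoints, BF16 for BF16 ones
    cache_dir=None,
    use_dummy_weights=False,
    use_np_cache=False,
)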