norm / vllm · Commits

Commit 189ae231 (unverified)
Use dtype from model config & Add Dolly V2 (#63)
Authored May 04, 2023 by Woosuk Kwon; committed by GitHub on May 04, 2023.
Parent: e548c148

Showing 2 changed files with 33 additions and 7 deletions:
  cacheflow/master/server.py       +5  -1
  cacheflow/models/model_utils.py  +28 -6
cacheflow/master/server.py
...
@@ -214,7 +214,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     # NOTE(woosuk): FlashAttention does not support float32.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
+    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+                        help=('data type for model weights and activations. '
+                              'The "default" option will use FP16 precision '
+                              'for FP32 and FP16 models, and BF16 precision '
+                              'for BF16 models.'))
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
...
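For orientation, here is a minimal sketch (not part of the commit) of how the reworked flag parses. The argument name, default, and choices come straight from the diff; the invocations are illustrative.

import argparse

# 'default' is now the default choice; the concrete torch dtype is resolved
# later from the model config (see _get_dtype in model_utils.py below).
parser = argparse.ArgumentParser()
parser.add_argument('--dtype', type=str, default='default',
                    choices=['default', 'half', 'bfloat16'])

assert parser.parse_args([]).dtype == 'default'
assert parser.parse_args(['--dtype', 'bfloat16']).dtype == 'bfloat16'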
cacheflow/models/model_utils.py
-from typing import Union, Optional
+from typing import Optional
 
 import torch
 import torch.nn as nn
 from transformers import AutoConfig
+from transformers import PretrainedConfig
 
 from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
 from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer
...
@@ -22,6 +23,7 @@ _MODELS = {
     'opt': OPTForCausalLM,
     'stablelm': GPTNeoXForCausalLM,
     'pythia': GPTNeoXForCausalLM,
+    'dolly-v2': GPTNeoXForCausalLM,
 }
 
 _MEMORY_ANALYZERS = {
...
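The new 'dolly-v2' entry works because model lookup is a substring match against the Hugging Face model name (the get_model loop below checks `if model_class_name in model_name`). A self-contained sketch of that dispatch, with strings standing in for the real model classes, which this diff does not show being imported:

# Substring dispatch as used by get_model: the first _MODELS key contained
# in the model name wins.
_MODELS = {
    'opt': 'OPTForCausalLM',
    'stablelm': 'GPTNeoXForCausalLM',
    'pythia': 'GPTNeoXForCausalLM',
    'dolly-v2': 'GPTNeoXForCausalLM',
}

def resolve_model_class(model_name: str) -> str:
    for key, model_class in _MODELS.items():
        if key in model_name:
            return model_class
    raise ValueError(f'Unsupported model name: {model_name}')

# Dolly V2 checkpoints reuse the GPT-NeoX architecture.
assert resolve_model_class('databricks/dolly-v2-12b') == 'GPTNeoXForCausalLM'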
@@ -30,19 +32,38 @@ _MEMORY_ANALYZERS = {
     'opt': OPTMemoryAnalyzer,
     'stablelm': GPTNeoXMemoryAnalyzer,
     'pythia': GPTNeoXMemoryAnalyzer,
+    'dolly-v2': GPTNeoXMemoryAnalyzer,
 }
 
 
+def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
+    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    if dtype == 'default':
+        if config_dtype == torch.float32:
+            # Following the common practice, we use float16 for float32 models.
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = config_dtype
+    else:
+        torch_dtype = get_torch_dtype(dtype)
+        if torch_dtype != config_dtype and config_dtype != torch.float32:
+            # TODO(woosuk): Allow using float16 for bfloat16 models and
+            # vice versa. Print a warning message and continue.
+            raise ValueError(
+                f'Cannot use {torch_dtype} for {config_dtype} model.')
+    return torch_dtype
+
+
 def get_model(
     model_name: str,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     cache_dir: Optional[str],
     use_dummy_weights: bool,
     use_np_cache: bool,
 ) -> nn.Module:
-    torch_dtype = get_torch_dtype(dtype)
-    torch.set_default_dtype(torch_dtype)
     config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
+    torch.set_default_dtype(torch_dtype)
     for model_class_name, model_class in _MODELS.items():
         if model_class_name in model_name:
             if use_dummy_weights:
...
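The resolution rule in _get_dtype is easiest to see in isolation. Below is a self-contained restatement of it; one assumption is that get_torch_dtype, defined elsewhere in cacheflow and not shown in this diff, maps 'half' to torch.float16 and 'bfloat16' to torch.bfloat16, so a plain dict stands in for it here.

import torch

# Stand-in for cacheflow's get_torch_dtype helper (assumed mapping).
_STR_TO_DTYPE = {'half': torch.float16, 'bfloat16': torch.bfloat16}

def resolve_dtype(config_dtype: torch.dtype, dtype: str) -> torch.dtype:
    """Restates the _get_dtype rule from the diff above."""
    if dtype == 'default':
        # FP32 checkpoints are served in FP16; other dtypes are kept as-is.
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    torch_dtype = _STR_TO_DTYPE[dtype]
    # An explicit dtype that conflicts with a non-FP32 checkpoint is rejected.
    if torch_dtype != config_dtype and config_dtype != torch.float32:
        raise ValueError(f'Cannot use {torch_dtype} for {config_dtype} model.')
    return torch_dtype

assert resolve_dtype(torch.float32, 'default') == torch.float16
assert resolve_dtype(torch.bfloat16, 'default') == torch.bfloat16
assert resolve_dtype(torch.float32, 'half') == torch.float16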
@@ -66,12 +87,13 @@ def get_model(
 def get_memory_analyzer(
     model_name: str,
     block_size: int,
-    dtype: Union[torch.dtype, str],
+    dtype: str,
     gpu_memory: int,
     cpu_memory: int,
     tensor_parallel_size: int = 1,
 ) -> CacheFlowMemoryAnalyzer:
-    torch_dtype = get_torch_dtype(dtype)
     config = AutoConfig.from_pretrained(model_name)
+    torch_dtype = _get_dtype(config, dtype)
     for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
        if model_class in model_name:
            return memory_analyzer(
...
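Design-wise, both entry points (get_model and get_memory_analyzer) now derive the dtype from the checkpoint's own config through the shared _get_dtype helper instead of converting the user-supplied string independently, so the server and the memory analyzer can no longer disagree about precision. The dtype parameter is narrowed from Union[torch.dtype, str] to str accordingly, which is why the Union import is dropped at the top of the file.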