Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
217ee621
Commit
217ee621
authored
Dec 05, 2024
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.6.2-dev' into v0.6.2-dev
parents
f0021a4d
3f78216a
Changes
68
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
172 additions
and
11 deletions
+172
-11
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+33
-5
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+30
-5
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+5
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+5
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+5
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+4
-0
vllm/triton_utils/custom_cache_manager.py
vllm/triton_utils/custom_cache_manager.py
+1
-1
vllm/utils.py
vllm/utils.py
+89
-0
No files found.
vllm/model_executor/models/llama.py
View file @
217ee621
...
@@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
,
W8a8GetCacheJSON
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
...
@@ -424,6 +424,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -424,6 +424,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
quant_config
,
quant_config
,
lora_config
=
lora_config
,
lora_config
=
lora_config
,
prefix
=
"model"
)
prefix
=
"model"
)
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
...
@@ -459,6 +461,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -459,6 +461,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
def
forward
(
def
forward
(
self
,
self
,
...
@@ -648,6 +651,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -648,6 +651,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
#当为triton支持推理的时候不能进行处理
if
self
.
quant_method
==
"compressed_tensors"
:
if
self
.
quant_method
==
"compressed_tensors"
:
lay_key_words
=
[
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.qkv_proj.weight"
,
...
@@ -656,15 +660,39 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -656,15 +660,39 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"mlp.down_proj.weight"
,
"mlp.down_proj.weight"
,
]
]
combined_words
=
"|"
.
join
(
lay_key_words
)
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
for
layername
,
weight
in
params_dict
.
items
():
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
weight_data
=
params_dict
[
layername
]
k
=
weight_data
.
shape
[
0
]
n
=
weight_data
.
shape
[
0
]
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
k
,
-
1
)
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
# If this function is called, it should always initialize KV cache scale
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
# make sure to leave KV cache scale factors in a known good (dummy) state
...
...
vllm/model_executor/models/qwen.py
View file @
217ee621
...
@@ -48,7 +48,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -48,7 +48,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
,
W8a8GetCacheJSON
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -904,6 +904,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
...
@@ -904,6 +904,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
def
_get_image_input_type
(
def
_get_image_input_type
(
self
,
self
,
...
@@ -1100,11 +1101,35 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
...
@@ -1100,11 +1101,35 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
"mlp.c_proj.weight"
,
"mlp.c_proj.weight"
,
]
]
combined_words
=
"|"
.
join
(
lay_key_words
)
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
for
layername
,
weight
in
params_dict
.
items
():
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
weight_data
=
params_dict
[
layername
]
k
=
weight_data
.
shape
[
0
]
n
=
weight_data
.
shape
[
0
]
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
k
,
-
1
)
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
vllm/platforms/cpu.py
View file @
217ee621
import
psutil
import
torch
import
torch
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
Platform
,
PlatformEnum
...
@@ -10,6 +11,10 @@ class CpuPlatform(Platform):
...
@@ -10,6 +11,10 @@ class CpuPlatform(Platform):
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
return
"cpu"
return
"cpu"
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
return
psutil
.
virtual_memory
().
total
@
classmethod
@
classmethod
def
inference_mode
(
cls
):
def
inference_mode
(
cls
):
return
torch
.
no_grad
()
return
torch
.
no_grad
()
vllm/platforms/interface.py
View file @
217ee621
...
@@ -83,6 +83,11 @@ class Platform:
...
@@ -83,6 +83,11 @@ class Platform:
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
raise
NotImplementedError
raise
NotImplementedError
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
"""Get the total memory of a device in bytes."""
raise
NotImplementedError
@
classmethod
@
classmethod
def
inference_mode
(
cls
):
def
inference_mode
(
cls
):
"""A device-specific wrapper of `torch.inference_mode`.
"""A device-specific wrapper of `torch.inference_mode`.
...
...
vllm/platforms/rocm.py
View file @
217ee621
...
@@ -29,3 +29,8 @@ class RocmPlatform(Platform):
...
@@ -29,3 +29,8 @@ class RocmPlatform(Platform):
@
lru_cache
(
maxsize
=
8
)
@
lru_cache
(
maxsize
=
8
)
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
return
torch
.
cuda
.
get_device_name
(
device_id
)
return
torch
.
cuda
.
get_device_name
(
device_id
)
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
device_props
=
torch
.
cuda
.
get_device_properties
(
device_id
)
return
device_props
.
total_memory
vllm/platforms/tpu.py
View file @
217ee621
...
@@ -10,6 +10,10 @@ class TpuPlatform(Platform):
...
@@ -10,6 +10,10 @@ class TpuPlatform(Platform):
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
raise
NotImplementedError
raise
NotImplementedError
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
raise
NotImplementedError
@
classmethod
@
classmethod
def
inference_mode
(
cls
):
def
inference_mode
(
cls
):
return
torch
.
no_grad
()
return
torch
.
no_grad
()
vllm/triton_utils/custom_cache_manager.py
View file @
217ee621
...
@@ -45,7 +45,7 @@ class CustomCacheManager(FileCacheManager):
...
@@ -45,7 +45,7 @@ class CustomCacheManager(FileCacheManager):
self
.
cache_dir
=
os
.
getenv
(
"TRITON_CACHE_DIR"
,
self
.
cache_dir
=
os
.
getenv
(
"TRITON_CACHE_DIR"
,
""
).
strip
()
or
default_cache_dir
()
""
).
strip
()
or
default_cache_dir
()
if
self
.
cache_dir
:
if
self
.
cache_dir
:
self
.
cache_dir
=
f
"
{
self
.
cache_dir
}
_
{
os
.
getpid
()
}
"
#
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self
.
cache_dir
=
os
.
path
.
join
(
self
.
cache_dir
,
self
.
key
)
self
.
cache_dir
=
os
.
path
.
join
(
self
.
cache_dir
,
self
.
key
)
self
.
lock_path
=
os
.
path
.
join
(
self
.
cache_dir
,
"lock"
)
self
.
lock_path
=
os
.
path
.
join
(
self
.
cache_dir
,
"lock"
)
os
.
makedirs
(
self
.
cache_dir
,
exist_ok
=
True
)
os
.
makedirs
(
self
.
cache_dir
,
exist_ok
=
True
)
...
...
vllm/utils.py
View file @
217ee621
...
@@ -16,6 +16,7 @@ import threading
...
@@ -16,6 +16,7 @@ import threading
import
uuid
import
uuid
import
warnings
import
warnings
import
weakref
import
weakref
import
json
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
...
@@ -119,6 +120,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
...
@@ -119,6 +120,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL
:
str
=
"FLASH_ATTN"
STR_FLASH_ATTN_VAL
:
str
=
"FLASH_ATTN"
STR_INVALID_VAL
:
str
=
"INVALID"
STR_INVALID_VAL
:
str
=
"INVALID"
GB_bytes
=
1_000_000_000
"""The number of bytes in one gigabyte (GB)."""
GiB_bytes
=
1
<<
30
GiB_bytes
=
1
<<
30
"""The number of bytes in one gibibyte (GiB)."""
"""The number of bytes in one gibibyte (GiB)."""
...
@@ -1331,3 +1335,88 @@ class AtomicCounter:
...
@@ -1331,3 +1335,88 @@ class AtomicCounter:
@
property
@
property
def
value
(
self
):
def
value
(
self
):
return
self
.
_value
return
self
.
_value
class
W8a8GetCacheJSON
:
_instance
=
None
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
cls
.
_instance
=
super
(
W8a8GetCacheJSON
,
cls
).
__new__
(
cls
,
*
args
,
**
kwargs
)
cls
.
_instance
.
_initialize
()
return
cls
.
_instance
def
_initialize
(
self
):
self
.
triton_json_dir
=
(
os
.
getenv
(
'TRITON_JSON_DIR'
,
'./cache'
))
self
.
triton_json_dict
=
[]
def
getspec_config
(
self
,
configs_dict
,
M
,
N
,
K
):
if
f
"
{
M
}
_
{
N
}
_
{
K
}
"
in
configs_dict
:
return
configs_dict
[
f
"
{
M
}
_
{
N
}
_
{
K
}
"
]
else
:
return
None
def
get_triton_cache_tune
(
self
,
file_path
,
n
,
k
):
#tuning的时候使用,当文件不存在时候,则创建文件夹
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
#try:
with
open
(
cache_json_file
,
'r'
)
as
file
:
cachedata
=
json
.
load
(
file
)
else
:
folder_path
=
os
.
path
.
dirname
(
file_path
)
os
.
makedirs
(
folder_path
,
exist_ok
=
True
)
cachedata
=
{}
# 写入空数据到新的JSON文件
with
open
(
file_path
,
'w'
)
as
file
:
json
.
dump
(
cachedata
,
file
)
#把所有的cache解析成key:config的形式:[M_N_K]:[config]
configs_dict
=
{}
for
key
,
value
in
cachedata
.
items
():
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_value
=
{
'SPLIT_K'
:
int
(
sub_value
[
"SPLIT_K"
]),
'BLOCK_SIZE_M'
:
int
(
sub_value
[
"BLOCK_SIZE_M"
]),
'BLOCK_SIZE_N'
:
int
(
sub_value
[
"BLOCK_SIZE_N"
]),
'BLOCK_SIZE_K'
:
int
(
sub_value
[
"BLOCK_SIZE_K"
]),
'GROUP_SIZE_M'
:
int
(
sub_value
[
"GROUP_SIZE_M"
]),
'num_stages'
:
int
(
sub_value
[
'num_stages'
]),
'num_warps'
:
int
(
sub_value
[
'num_warps'
])
}
configs_dict
[
configs_key
]
=
configs_value
return
configs_dict
def
get_triton_cache
(
self
,
file_path
,
n
,
k
):
#在非tuning的时候使用,当文件不存在则直接返回none
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
#try:
with
open
(
cache_json_file
,
'r'
)
as
file
:
cachedata
=
json
.
load
(
file
)
else
:
return
None
#把所有的cache解析成key:config的形式:[M_N_K]:[config]
configs_dict
=
{}
for
key
,
value
in
cachedata
.
items
():
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_value
=
{
'SPLIT_K'
:
int
(
sub_value
[
"SPLIT_K"
]),
'BLOCK_SIZE_M'
:
int
(
sub_value
[
"BLOCK_SIZE_M"
]),
'BLOCK_SIZE_N'
:
int
(
sub_value
[
"BLOCK_SIZE_N"
]),
'BLOCK_SIZE_K'
:
int
(
sub_value
[
"BLOCK_SIZE_K"
]),
'GROUP_SIZE_M'
:
int
(
sub_value
[
"GROUP_SIZE_M"
]),
'num_stages'
:
int
(
sub_value
[
'num_stages'
]),
'num_warps'
:
int
(
sub_value
[
'num_warps'
])
}
configs_dict
[
configs_key
]
=
configs_value
return
configs_dict
def
get_w8a8json_name
(
self
,
n
,
k
):
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
return
self
.
triton_json_dir
+
f
"/W8A8_
{
n
}
_
{
k
}
_DCU
{
device_name
}
.json"
\ No newline at end of file
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment