Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1c77f16e
Commit
1c77f16e
authored
Dec 03, 2024
by
gaoqiong
Browse files
增加w8a8的triton调度支持
parent
6ebda263
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
176 additions
and
17 deletions
+176
-17
README.md
README.md
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-3
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+28
-3
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+32
-5
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+29
-5
vllm/utils.py
vllm/utils.py
+84
-0
No files found.
README.md
View file @
1c77f16e
...
@@ -9,7 +9,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
...
@@ -9,7 +9,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
## 支持模型结构列表
## 支持模型结构列表
| 结构 | 模型 | FP16/BF16 | AWQ | GPTQ |
| 结构 | 模型 | FP16/BF16 | AWQ | GPTQ |
| :------: | :------: | :------: | :------: |
| :------: | :------: | :------: | :------: |
:------: |
| LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,deepseek | Yes | Yes | Yes |
| LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,deepseek | Yes | Yes | Yes |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5 | Yes | Yes | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5 | Yes | Yes | Yes |
...
...
vllm/_custom_ops.py
View file @
1c77f16e
...
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
...
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
try
:
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_ops
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq model.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer gptq or awq
or w8a8
model.
\n
"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -706,8 +706,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
...
@@ -706,8 +706,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
# torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
# torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
# return out
# return out
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
return
quant_ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
rocblas_scaled_mm
(
a
:
torch
.
Tensor
,
def
rocblas_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
1c77f16e
...
@@ -4,12 +4,12 @@ import torch
...
@@ -4,12 +4,12 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
,
W8a8GetCacheJSON
# Input scaling factors are no longer optional in _scaled_mm starting
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
).
cuda
()
if
is_hip
()
else
None
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
).
cuda
()
if
is_hip
()
else
None
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
def
cutlass_fp8_supported
()
->
bool
:
def
cutlass_fp8_supported
()
->
bool
:
# cutlass is not supported on Rocm
# cutlass is not supported on Rocm
...
@@ -200,12 +200,37 @@ def apply_int8_linear(
...
@@ -200,12 +200,37 @@ def apply_int8_linear(
x_q
,
x_scale
,
_
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
x_q
,
x_scale
,
_
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
if
w8a8_strategy
==
1
:
if
w8a8_strategy
==
1
:
m
=
x_q
.
shape
[
0
]
k
=
x_q
.
shape
[
1
]
n
=
weight
.
shape
[
1
]
if
f
"
{
m
}
_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
[
0
]:
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
0
][
f
"
{
m
}
_
{
n
}
_
{
k
}
"
]
#print("json files:",best_config)
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
[
0
]:
if
m
<
64
:
m_
=
32
elif
m
<
128
:
m_
=
64
elif
m
<
256
:
m_
=
128
elif
m
<
512
:
m_
=
256
elif
m
<
1024
:
m_
=
512
else
:
m_
=
1024
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
0
][
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
print
(
"config not found!"
)
return
ops
.
triton_scaled_mm
(
x_q
,
return
ops
.
triton_scaled_mm
(
x_q
,
weight
,
weight
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
bias
=
bias
,
best_config
=
best_config
)
elif
w8a8_strategy
==
2
:
elif
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
return
ops
.
cutlass_scaled_mm
(
x_q
,
weight
,
weight
,
...
...
vllm/model_executor/models/llama.py
View file @
1c77f16e
...
@@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
,
W8a8GetCacheJSON
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
...
@@ -424,6 +424,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -424,6 +424,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
quant_config
,
quant_config
,
lora_config
=
lora_config
,
lora_config
=
lora_config
,
prefix
=
"model"
)
prefix
=
"model"
)
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
...
@@ -459,6 +461,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -459,6 +461,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
def
forward
(
def
forward
(
self
,
self
,
...
@@ -648,6 +651,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -648,6 +651,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
#当为triton支持推理的时候不能进行处理
if
self
.
quant_method
==
"compressed_tensors"
:
if
self
.
quant_method
==
"compressed_tensors"
:
lay_key_words
=
[
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
"self_attn.qkv_proj.weight"
,
...
@@ -656,14 +660,37 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -656,14 +660,37 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"mlp.down_proj.weight"
,
"mlp.down_proj.weight"
,
]
]
combined_words
=
"|"
.
join
(
lay_key_words
)
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
for
layername
,
weight
in
params_dict
.
items
():
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
weight_data
=
params_dict
[
layername
]
k
=
weight_data
.
shape
[
0
]
n
=
weight_data
.
shape
[
0
]
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
k
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
# If this function is called, it should always initialize KV cache scale
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# factors (or else raise an exception). Thus, handled exceptions should
...
...
vllm/model_executor/models/qwen.py
View file @
1c77f16e
...
@@ -48,7 +48,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -48,7 +48,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
,
W8a8GetCacheJSON
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -904,6 +904,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
...
@@ -904,6 +904,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
def
_get_image_input_type
(
def
_get_image_input_type
(
self
,
self
,
...
@@ -1100,11 +1101,34 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
...
@@ -1100,11 +1101,34 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
"mlp.c_proj.weight"
,
"mlp.c_proj.weight"
,
]
]
combined_words
=
"|"
.
join
(
lay_key_words
)
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
for
layername
,
weight
in
params_dict
.
items
():
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
weight_data
=
params_dict
[
layername
]
k
=
weight_data
.
shape
[
0
]
n
=
weight_data
.
shape
[
0
]
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
k
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
vllm/utils.py
View file @
1c77f16e
...
@@ -16,6 +16,7 @@ import threading
...
@@ -16,6 +16,7 @@ import threading
import
uuid
import
uuid
import
warnings
import
warnings
import
weakref
import
weakref
import
json
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
...
@@ -1331,3 +1332,86 @@ class AtomicCounter:
...
@@ -1331,3 +1332,86 @@ class AtomicCounter:
@
property
@
property
def
value
(
self
):
def
value
(
self
):
return
self
.
_value
return
self
.
_value
class W8a8GetCacheJSON:
    """Singleton that loads Triton W8A8 GEMM configs from JSON cache files.

    The cache directory is read once from the ``TRITON_JSON_DIR`` environment
    variable (default ``./cache``).  ``triton_json_dict`` is a list that
    callers append flattened config dictionaries to; this class only creates
    it empty.

    A cache file is a two-level JSON object.  It is flattened into a single
    dictionary whose keys are ``"<inner key>_<outer key>"`` and whose values
    are the seven integer fields listed in ``_CONFIG_FIELDS``.
    """

    _instance = None

    # Fields copied (and coerced to int) from every cache entry, in the order
    # they appear in the flattened config dictionaries.
    _CONFIG_FIELDS = (
        'SPLIT_K',
        'BLOCK_SIZE_M',
        'BLOCK_SIZE_N',
        'BLOCK_SIZE_K',
        'GROUP_SIZE_M',
        'num_stages',
        'num_warps',
    )

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating and initializing it once."""
        if cls._instance is None:
            # object.__new__ rejects extra arguments when __new__ is
            # overridden, so they are deliberately not forwarded.
            instance = super().__new__(cls)
            instance._initialize()
            # Publish only after initialization succeeded, so a failure does
            # not leave a half-built singleton behind.
            cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Set the cache directory and the empty list of loaded configs."""
        self.triton_json_dir = os.getenv('TRITON_JSON_DIR', './cache')
        self.triton_json_dict = []

    def getspec_config(self, configs_dict, M, N, K):
        """Return the config stored under ``"M_N_K"``, or None if absent."""
        return configs_dict.get(f"{M}_{N}_{K}")

    @staticmethod
    def _read_cache_file(file_path):
        """Load and return the JSON document stored at ``file_path``."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    @classmethod
    def _flatten_cache(cls, cachedata):
        """Flatten a two-level cache into ``{"<sub_key>_<key>": config}``.

        Raises:
            KeyError: if an entry lacks one of ``_CONFIG_FIELDS``.
            ValueError: if a field value cannot be converted to int.
        """
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_dict[f"{sub_key}_{key}"] = {
                    field: int(sub_value[field])
                    for field in cls._CONFIG_FIELDS
                }
        return configs_dict

    def get_triton_cache_tune(self, file_path, n, k):
        """Load the cache used while tuning, creating it when missing.

        If ``file_path`` does not exist, its parent directory is created as
        needed and an empty JSON object is written to the file.

        Args:
            file_path: path of the JSON cache file.
            n: unused; kept for interface compatibility.
            k: unused; kept for interface compatibility.

        Returns:
            The flattened config dictionary (empty for a new file).
        """
        if os.path.exists(file_path):
            cachedata = self._read_cache_file(file_path)
        else:
            folder_path = os.path.dirname(file_path)
            # A bare file name has no directory part; os.makedirs('') would
            # raise FileNotFoundError, so only create a real directory.
            if folder_path:
                os.makedirs(folder_path, exist_ok=True)
            cachedata = {}
            # Write the empty data to the new JSON file.
            with open(file_path, 'w', encoding='utf-8') as file:
                json.dump(cachedata, file)
        return self._flatten_cache(cachedata)

    def get_triton_cache(self, file_path, n, k):
        """Load the cache outside of tuning, without creating anything.

        Args:
            file_path: path of the JSON cache file.
            n: unused; kept for interface compatibility.
            k: unused; kept for interface compatibility.

        Returns:
            The flattened config dictionary, or None when ``file_path`` does
            not exist.  Callers must check for None before merging the result
            into another dictionary.
        """
        if not os.path.exists(file_path):
            return None
        return self._flatten_cache(self._read_cache_file(file_path))

    def get_w8a8json_name(self, n, k):
        """Return the cache file path for the given ``n`` and ``k``."""
        return self.triton_json_dir + f"/W8A8_{n}_{k}_DCUK100AI.json"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment