Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca4ec0ce
Commit
ca4ec0ce
authored
Mar 25, 2025
by
lizhigong
Browse files
Merge remote-tracking branch 'origin/v0.7.2-dev' into v0.7.2_zero_overhead
parents
0be169ad
ae0ed592
Changes
90
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
613 additions
and
22 deletions
+613
-22
vllm/model_executor/layers/quantization/utils/int8_utils.py
vllm/model_executor/layers/quantization/utils/int8_utils.py
+553
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+6
-4
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+34
-2
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+6
-6
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+3
-1
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+3
-1
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+3
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+4
-4
vllm/utils.py
vllm/utils.py
+0
-2
No files found.
vllm/model_executor/layers/quantization/utils/int8_utils.py
0 → 100755
View file @
ca4ec0ce
# SPDX-License-Identifier: Apache-2.0
import
functools
import
json
import
logging
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
triton
import
triton.language
as
tl
# from sglang.srt.utils import get_device_name
from
vllm.platforms
import
current_platform
logger
=
logging
.
getLogger
(
__name__
)
@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    BLOCK: tl.constexpr,
):
    """Quantize one row of the input to int8 using a single per-row scale.

    Each program instance handles the row selected by ``tl.program_id(0)``.
    It loads up to ``N`` elements of that row, computes
    ``scale = absmax / 127`` (``absmax`` is floored at ``1e-10``), stores the
    rounded, scaled values as int8 and writes the scale to
    ``scale_ptr + row``.
    """
    row = tl.program_id(0)

    offsets = tl.arange(0, BLOCK)
    in_bounds = offsets < N

    row_vals = tl.load(
        x_ptr + row * stride_x + offsets, mask=in_bounds, other=0.0
    ).to(tl.float32)

    # Floor the absolute maximum so that an all-zero row never divides by zero.
    row_absmax = tl.maximum(tl.max(tl.abs(row_vals)), 1e-10)
    row_scale = row_absmax / 127

    scaled = row_vals * (127 / row_absmax)
    quantized = tl.extra.cuda.libdevice.round(scaled).to(tl.int8)

    tl.store(xq_ptr + row * stride_xq + offsets, quantized, mask=in_bounds)
    tl.store(scale_ptr + row, row_scale)
def per_token_quant_int8(x):
    """Quantize ``x`` to int8 with one scale per row of the last dimension.

    Args:
        x: Contiguous input tensor. ``x.stride(-2)`` is read, so ``x`` needs
            at least two dimensions.

    Returns:
        A tuple ``(x_q, scales)``: ``x_q`` has the shape of ``x`` with dtype
        ``torch.int8``; ``scales`` has shape ``x.shape[:-1] + (1,)`` with
        dtype ``torch.float32``.

    Raises:
        AssertionError: If ``x`` is not contiguous.
    """
    hidden = x.shape[-1]
    rows = x.numel() // hidden

    quantized = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scale_shape = x.shape[:-1] + (1,)
    scales = torch.empty(scale_shape, device=x.device, dtype=torch.float32)

    # One program covers a whole row, so the tile must span all `hidden`
    # columns.
    block = triton.next_power_of_2(hidden)
    # Heuristic: one warp per 256 columns, kept within [1, 8].
    warps = max(block // 256, 1)
    if warps > 8:
        warps = 8

    assert x.is_contiguous()
    launch_grid = (rows,)
    _per_token_quant_int8[launch_grid](
        x,
        quantized,
        scales,
        stride_x=x.stride(-2),
        stride_xq=quantized.stride(-2),
        N=hidden,
        BLOCK=block,
        num_warps=warps,
        num_stages=1,
    )
    return quantized, scales
@triton.jit
def _per_token_group_quant_int8(
    # Input, quantized-output and scale-output pointers
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Element distance between consecutive groups (input and output)
    y_stride,
    # Number of valid columns in one group
    N,
    # Lower bound applied to the absolute maximum (avoids dividing by zero)
    eps,
    # Value range of the integer output type
    int8_min,
    int8_max,
    # Compile-time tile width
    BLOCK: tl.constexpr,
):
    """Quantize one group of values to the integer type of ``y_q_ptr``.

    The program id selects the group. The group's scale is
    ``max(absmax, eps) / int8_max``; the scaled values are clamped to
    ``[int8_min, int8_max]`` and cast to the output element type, and the
    scale is written to ``y_s_ptr + group``.
    """
    group = tl.program_id(0)

    # Base addresses of this group's input, output and scale slots.
    src = y_ptr + group * y_stride
    dst = y_q_ptr + group * y_stride
    scale_dst = y_s_ptr + group

    lane = tl.arange(0, BLOCK)  # N <= BLOCK
    valid = lane < N

    vals = tl.load(src + lane, mask=valid, other=0.0).to(tl.float32)

    group_absmax = tl.maximum(tl.max(tl.abs(vals)), eps)
    group_scale = group_absmax / int8_max
    # NOTE(review): the cast below is a plain dtype conversion; no explicit
    # rounding call is made here (unlike `_per_token_quant_int8`) -- confirm
    # that this is intended.
    quantized = tl.clamp(vals / group_scale, int8_min, int8_max).to(
        y_q_ptr.dtype.element_ty
    )

    tl.store(dst + lane, quantized, mask=valid)
    tl.store(scale_dst, group_scale)
def per_token_group_quant_int8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Perform per-token-group quantization of the input tensor `x`.

    The last dimension of `x` is split into consecutive groups of
    `group_size` elements; each group is quantized with its own scale.

    Args:
        x: The contiguous input tensor.
        group_size: The number of elements that share one scale.
        eps: Lower bound for a group's absolute maximum, which avoids
            dividing by zero.
        dtype: The dtype of the quantized output. The default is
            `torch.int8`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor (same shape
        as `x`) and the float32 scales of shape
        `x.shape[:-1] + (x.shape[-1] // group_size,)`.

    Raises:
        AssertionError: If the last dimension of `x` is not a multiple of
            `group_size`, or if `x` is not contiguous.
    """
    last_dim = x.shape[-1]
    assert (
        last_dim % group_size == 0
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    limits = torch.iinfo(dtype)

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    scale_shape = x.shape[:-1] + (last_dim // group_size,)
    x_s = torch.empty(scale_shape, device=x.device, dtype=torch.float32)

    # One program per group; the tile has to cover a whole group.
    num_groups = x.numel() // group_size
    block = triton.next_power_of_2(group_size)
    # Heuristic: one warp per 256 columns, kept within [1, 8].
    num_warps = min(max(block // 256, 1), 8)

    _per_token_group_quant_int8[(num_groups,)](
        x,
        x_q,
        x_s,
        group_size,  # distance between consecutive groups
        group_size,  # number of valid columns per group
        eps,
        int8_min=limits.min,
        int8_max=limits.max,
        BLOCK=block,
        num_warps=num_warps,
        num_stages=1,
    )
    return x_q, x_s
@triton.jit
def _w8a8_block_int8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization,
    and store the result in output tensor `C`.

    Each program computes one `BLOCK_SIZE_M x BLOCK_SIZE_N` tile of `C`.
    For every K tile the product `tl.dot(a, b)` is converted to float32 and
    multiplied by the per-row scale read from `As` and the per-column scale
    read from `Bs` before being added to a float32 accumulator.

    A single scale index (`k_start // group_k`) is used for a whole K tile.
    """
    # Map the linear program id to a (pid_m, pid_n) tile; program ids are
    # assigned in groups of `GROUP_SIZE_M` rows of tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M rows of tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column offsets are wrapped with `% M` / `% N`, so the loads below
    # need no row/column mask; the final store masks out-of-range positions.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `a` tiles are (BLOCK_SIZE_M, BLOCK_SIZE_K); `b` tiles are
    # (BLOCK_SIZE_K, BLOCK_SIZE_N).
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one entry per row of the tile for `As`, and one entry
    # per column (indexed by its `group_n` block) for `Bs`.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the K tail of the last tile; masked elements load as 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # Index of the scale group that contains the first column of this
        # K tile.
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        # Apply the scales to this tile's partial product, then accumulate.
        accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
        # Advance both tile pointers to the next K tile.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Convert the accumulator to the element type of `C`; any type other
    # than bfloat16/float16 goes through float32.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Output offsets are NOT wrapped; the mask drops rows/columns outside
    # the (M, N) output.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_int8_configs(
    N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
    """Load a tuned configuration table for the W8A8 block INT8 kernel.

    The table is read from a JSON file in the ``configs`` directory next to
    this module. The file name is built from ``N``, ``K``, the current
    device name (spaces replaced by underscores) and the block shape.

    Returns:
        A dictionary whose keys are the JSON keys converted to ``int`` and
        whose values are the associated configurations, or ``None`` when no
        such file exists (a warning is logged in that case).
    """
    device_name = current_platform.get_device_name().replace(" ", "_")
    json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n},{block_k}].json"  # noqa: E501

    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_file_path = os.path.join(module_dir, "configs", json_file_name)

    if not os.path.exists(config_file_path):
        # No tuned table for this shape/device: the caller falls back to its
        # defaults.
        logger.warning(
            (
                "Using default W8A8 Block INT8 kernel config. Performance might "
                "be sub-optimal! Config file not found at %s"
            ),
            config_file_path,
        )
        return None

    with open(config_file_path) as config_file:
        logger.info(
            "Using configuration from %s for W8A8 Block INT8 kernel.",
            config_file_path,
        )
        raw_table = json.load(config_file)
        # JSON object keys are strings; convert them to integers.
        return {int(key): val for key, val in raw_table.items()}
def w8a8_block_int8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Matrix multiplication with block-wise quantization.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.

    Args:
        A: The input tensor, e.g., activation. Must be contiguous.
        B: The input tensor, e.g., weight. Must be 2-dim and contiguous,
            with the same last dimension as `A`.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B` (2-dim).
        block_size: The block size for per-block quantization. It should
            be 2-dim, e.g., [128, 128]. `block_size[1]` must be a positive
            multiple of the kernel's K tile width (128).
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul, of shape
        `A.shape[:-1] + (B.shape[0],)`.

    Raises:
        ValueError: If `block_size[1]` is not a positive multiple of the
            kernel's K tile width.
        AssertionError: If the shapes of the inputs and scales are
            inconsistent.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    # Every launch configuration below uses this K tile width. The kernel
    # reads ONE scale per K tile (index `k_start // group_k`), so a tile must
    # never span two scale groups; otherwise the result is silently wrong.
    kernel_block_k = 128
    if block_k <= 0 or block_k % kernel_block_k != 0:
        raise ValueError(
            f"block_size[1] ({block_k}) must be a positive multiple of the "
            f"kernel K tile width ({kernel_block_k})"
        )

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    # Hard-coded launch configurations selected by the number of rows M.
    # (The JSON lookup in `get_w8a8_block_int8_configs` is not used here.)
    if M <= 64:
        tile_m, tile_n, group_m, num_warps = 16, 64, 2, 4
    elif M < 128:
        tile_m, tile_n, group_m, num_warps = 32, 64, 2, 4
    elif M <= 256:
        tile_m, tile_n, group_m, num_warps = 64, 64, 2, 4
    else:
        tile_m, tile_n, group_m, num_warps = 64, 128, 8, 8

    config = {
        "BLOCK_SIZE_M": tile_m,
        "BLOCK_SIZE_N": tile_n,
        "BLOCK_SIZE_K": kernel_block_k,
        "GROUP_SIZE_M": group_m,
        "num_warps": num_warps,
        # NOTE(review): num_stages=0 is kept from the original; presumably
        # chosen for the Triton backend this fork targets -- confirm.
        "num_stages": 0,
    }

    def grid(META):
        # One program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"])
            * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    _w8a8_block_int8_matmul[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        # B is stored as (N, K); the kernel walks it as (K, N).
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        # Bs is stored as (n_tiles, k_tiles); strides are passed K first.
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """Reference block-wise quantized matmul written with plain torch ops.

    `A` (shape `(..., K)`) and `B` (shape `(N, K)`) are converted to float32.
    For every `(block_n, block_k)` tile of `B`, the partial product with the
    matching columns of `A` is multiplied by the corresponding entries of
    `As` and `Bs` and accumulated in float32.

    Returns:
        A tensor of shape `A.shape[:-1] + (N,)` in `output_dtype`.

    Raises:
        AssertionError: If the shapes of the inputs and scales are
            inconsistent.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N,)

    # Collapse all leading dimensions into a single row dimension.
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])

    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C = torch.zeros((M, N), dtype=torch.float32, device=A.device)

    for j in range(n_tiles):
        n_lo = j * block_n
        n_hi = min(n_lo + block_n, N)
        # View into C: in-place accumulation below writes through to C.
        out_tile = C[:, n_lo:n_hi]
        for i in range(k_tiles):
            k_lo = i * block_k
            k_hi = min(k_lo + block_k, K)
            partial = torch.matmul(A[:, k_lo:k_hi], B[n_lo:n_hi, k_lo:k_hi].t())
            # Per-row activation scale times the scalar scale of this B tile.
            tile_scale = As[:, i : i + 1] * Bs[j][i]
            out_tile += partial * tile_scale

    return C.reshape(origin_C_shape).to(output_dtype)
def apply_w8a8_block_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply a linear layer whose weight is block-wise int8 quantized.

    The input is quantized on the fly per token group (group size
    `block_size[1]`), multiplied with `weight` by
    `w8a8_block_int8_matmul`, and `bias` is added when given.

    Args:
        input: Activation tensor; its last dimension is the reduction
            dimension.
        weight: Quantized weight; `weight.shape[0]` is the output feature
            dimension.
        block_size: The `[block_n, block_k]` quantization block shape.
        weight_scale: The block-wise scales of `weight`.
        input_scale: Must be `None`; pre-computed input scales are not
            supported.
        bias: Optional bias added to the matmul result.

    Returns:
        A tensor of shape `input.shape[:-1] + (weight.shape[0],)` with the
        dtype of `input`.

    Raises:
        AssertionError: If `input_scale` is not `None`.
    """
    assert input_scale is None

    # Flatten all leading dimensions so the kernels see a 2D matrix.
    flat_input = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    group_size = block_size[1]
    q_input, x_scale = per_token_group_quant_int8(flat_input, group_size)
    result = w8a8_block_int8_matmul(
        q_input,
        weight,
        x_scale,
        weight_scale,
        block_size,
        output_dtype=input.dtype,
    )

    result = result if bias is None else result + bias
    return result.to(dtype=input.dtype).view(*output_shape)
def input_to_int8(
    x: torch.Tensor, dtype: torch.dtype = torch.int8
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` with a single tensor-wide scale.

    The scale maps the largest absolute value of `x` (floored at `1e-12`)
    onto the maximum of `dtype`. Scaled values are clamped to the range of
    `dtype` and then converted with a plain dtype cast.

    Returns:
        A tuple `(x_q, inv_scale)`: the contiguous quantized tensor and the
        float32 reciprocal of the scale, i.e. the factor that maps quantized
        values back to the input range.
    """
    limits = torch.iinfo(dtype)

    lowest, highest = x.aminmax()
    abs_max = torch.maximum(lowest.abs(), highest.abs()).clamp(min=1e-12)

    scale = limits.max / abs_max
    saturated = (x * scale).clamp(min=limits.min, max=limits.max)

    quantized = saturated.to(dtype).contiguous()
    inv_scale = scale.float().reciprocal()
    return quantized, inv_scale
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> torch.Tensor:
    """Conduct block-wise dequantization.

    Every `(block_n, block_k)` tile of the 2-dim tensor `x_q_block` is
    multiplied by its scale `x_s[j][i]`, where `j` indexes tiles along
    dim 0 and `i` indexes tiles along dim 1. Tiles at the border may be
    smaller than the block size.

    Args:
        x_q_block: The block-wise quantized 2-dim tensor. It is never
            modified, even when it is already float32.
        x_s: The block-wise scales, of shape `(n_tiles, k_tiles)`.
        block_size: The `[block_n, block_k]` block shape.

    Returns:
        A new float32 tensor with the shape of `x_q_block`.

    Raises:
        AssertionError: If the shape of `x_s` does not match the number of
            tiles.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    # `copy=True` is required: without it `.to` returns the input itself
    # when it is already float32, and the in-place scaling below would
    # silently overwrite the caller's tensor.
    x_dq_block = x_q_block.to(torch.float32, copy=True)

    for j in range(n_tiles):
        row_lo = j * block_n
        row_hi = min(row_lo + block_n, n)
        for i in range(k_tiles):
            col_lo = i * block_k
            col_hi = min(col_lo + block_k, k)
            x_dq_block[row_lo:row_hi, col_lo:col_hi] *= x_s[j][i]

    return x_dq_block
vllm/model_executor/model_loader/utils.py
View file @
ca4ec0ce
...
...
@@ -80,7 +80,8 @@ def get_model_architecture(
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures
=
[
'QWenLMHeadModel'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
# 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration'
support_nn_architectures
=
[
'QWenLMHeadModel'
,
'Qwen2MoeForCausalLM'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
,
'DeepseekV2ForCausalLM'
,
...
...
@@ -99,10 +100,11 @@ def get_model_architecture(
os
.
environ
[
'GEMM_PAD'
]
=
'0'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
os
.
environ
[
'FA_PAD'
]
=
'0'
# awq相关配置
try
:
if
os
.
getenv
(
'AWQ_
PAD'
)
==
'0'
or
((
torch
.
cuda
.
isCurrentDeviceEco
(
torch
.
cuda
.
current_device
()))
and
os
.
getenv
(
'AWQ_PAD
'
)
==
None
)
:
os
.
environ
[
'AWQ_
PAD
'
]
=
'
0
'
else
:
if
os
.
getenv
(
'AWQ_
MOE_SZ
'
)
==
None
:
os
.
environ
[
'AWQ_
MOE_SZ
'
]
=
'
1
'
if
os
.
getenv
(
'AWQ_PAD'
)
==
None
and
(
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
)
:
os
.
environ
[
'AWQ_PAD'
]
=
'1'
except
Exception
as
e
:
if
os
.
getenv
(
'AWQ_PAD'
)
!=
'0'
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
ca4ec0ce
...
...
@@ -59,7 +59,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
block_dequant
as
int8_block_dequant
,
)
class
DeepseekV2MLP
(
nn
.
Module
):
...
...
@@ -674,7 +676,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self
.
quant_method
=
quant_config
.
get_name
()
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
self
.
use_w4a16_moe_sz
=
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
parallel_config
=
vllm_config
.
parallel_config
...
...
@@ -740,7 +742,26 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
dtype
=
dtype
,
device
=
device
),
})
def
restore_qzeros_tensor
(
self
,
qzeros
,
qscales
):
low_bits
=
qzeros
&
0x0F
high_bits
=
qzeros
>>
4
zeors_tensor
=
torch
.
stack
([
low_bits
,
high_bits
],
dim
=
2
).
view
(
qzeros
.
shape
[
0
],
-
1
,
qzeros
.
shape
[
-
1
])
zeors_int16
=
zeors_tensor
.
to
(
torch
.
int16
)
assert
zeors_int16
.
shape
==
qscales
.
shape
uint16_tensor1
=
zeors_int16
.
view
(
torch
.
uint16
)
uint16_tensor2
=
qscales
.
view
(
torch
.
uint16
)
uint32_tensor1
=
uint16_tensor1
.
to
(
torch
.
int32
)
<<
16
uint32_tensor2
=
uint16_tensor2
.
to
(
torch
.
int32
)
result_tensor
=
uint32_tensor1
+
uint32_tensor2
result_tensor
=
result_tensor
.
view
(
torch
.
uint32
)
result_tensor
=
result_tensor
.
transpose
(
1
,
2
).
contiguous
()
return
result_tensor
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
...
...
@@ -883,6 +904,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
"mlp.shared_experts.down_proj.qweight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
# moe_gather_sz
moe_key_words
=
[
"mlp.experts.w13_qweight"
,
"mlp.experts.w2_qweight"
]
moe_combined_words
=
"|"
.
join
(
moe_key_words
)
for
layername
in
loaded_params
:
weight
=
params_dict
[
layername
]
...
...
@@ -916,6 +940,14 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
if
self
.
use_w4a16_moe_sz
:
matches_moe
=
re
.
findall
(
moe_combined_words
,
layername
)
# sz.shape == s.shape.T
if
matches_moe
:
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
sz_tensor
=
self
.
restore_qzeros_tensor
(
qzeros
,
scales
)
scales
.
data
=
sz_tensor
return
loaded_params
...
...
vllm/model_executor/models/gpt_neox.py
View file @
ca4ec0ce
...
...
@@ -370,7 +370,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
matched_key_words
=
set
()
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
and
"scale"
not
in
layername
:
...
...
@@ -384,20 +385,19 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
#k=weight_data.shape[1]
#print("n:{},k:{}".format(n,k))
elif
len
(
matched_key_words
)
<
4
and
matches
[
0
]
not
in
matched_key_words
:
matched_key_words
.
add
(
matches
[
0
])
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
#("weight_shapes:",weight_shapes)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#print("self.tritonsingleton.triton_json_dict:",self.tritonsingleton.triton_json_dict)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
...
...
vllm/model_executor/models/llama.py
View file @
ca4ec0ce
...
...
@@ -562,6 +562,7 @@ class LlamaModel(nn.Module):
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
matched_key_words
=
set
()
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
...
...
@@ -579,7 +580,8 @@ class LlamaModel(nn.Module):
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
elif
len
(
matched_key_words
)
<
4
and
matches
[
0
]
not
in
matched_key_words
:
matched_key_words
.
add
(
matches
[
0
])
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
...
...
vllm/model_executor/models/qwen.py
View file @
ca4ec0ce
...
...
@@ -1185,6 +1185,7 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
matched_key_words
=
set
()
for
layername
in
loaded_params
:
weight
=
params_dict
[
layername
]
...
...
@@ -1199,7 +1200,8 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
elif
len
(
matched_key_words
)
<
4
and
matches
[
0
]
not
in
matched_key_words
:
matched_key_words
.
add
(
matches
[
0
])
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
...
...
vllm/model_executor/models/qwen2.py
View file @
ca4ec0ce
...
...
@@ -539,6 +539,7 @@ class Qwen2Model(nn.Module):
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
matched_key_words
=
set
()
for
layername
in
loaded_params
:
weight
=
params_dict
[
layername
]
...
...
@@ -553,7 +554,8 @@ class Qwen2Model(nn.Module):
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
elif
len
(
matched_key_words
)
<
4
and
matches
[
0
]
not
in
matched_key_words
:
matched_key_words
.
add
(
matches
[
0
])
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
...
...
vllm/platforms/rocm.py
View file @
ca4ec0ce
...
...
@@ -72,7 +72,7 @@ class RocmPlatform(Platform):
supported_quantization
:
list
[
str
]
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"quark"
,
"moe_wna16"
"fbgemm_fp8"
,
"gguf"
,
"quark"
,
"moe_wna16"
,
"blockwise_int8"
]
@
classmethod
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
ca4ec0ce
...
...
@@ -84,10 +84,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_config
=
copy
.
deepcopy
(
vllm_config
)
draft_worker_config
.
model_config
=
speculative_config
.
draft_model_config
draft_worker_config
.
quant_config
=
VllmConfig
.
_get_quantization_config
(
draft_worker_config
.
model_config
,
vllm_config
.
load_config
,
)
#
draft_worker_config.quant_config = VllmConfig._get_quantization_config(
#
draft_worker_config.model_config,
#
vllm_config.load_config,
#
)
speculative_config
.
draft_parallel_config
.
worker_cls
=
\
draft_worker_config
.
parallel_config
.
sd_worker_cls
draft_worker_config
.
parallel_config
=
speculative_config
.
draft_parallel_config
# noqa
...
...
vllm/utils.py
View file @
ca4ec0ce
...
...
@@ -1534,8 +1534,6 @@ class W8a8GetCacheJSON:
def
_initialize
(
self
):
current_folder_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
json_folder_path
=
current_folder_path
+
'/../lmslim/configs/w8a8'
if
not
os
.
path
.
exists
(
json_folder_path
):
json_folder_path
=
current_folder_path
+
'/model_executor/layers/quantization/configs/w8a8'
self
.
triton_json_dir
=
(
os
.
getenv
(
'TRITON_JSON_DIR'
,
json_folder_path
))
self
.
triton_json_dict
=
[]
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment