Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
40b94473
Commit
40b94473
authored
Jun 05, 2025
by
gaoqiong
Browse files
修改deepseek block-int8 权重处理流程,增加per-channel bestconfig配置以及首次triton warmup代码
parent
a68aef25
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
116 additions
and
57 deletions
+116
-57
vllm/model_executor/layers/quantization/blockwise_int8.py
vllm/model_executor/layers/quantization/blockwise_int8.py
+26
-0
vllm/model_executor/layers/quantization/w8a8_int8.py
vllm/model_executor/layers/quantization/w8a8_int8.py
+90
-13
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+0
-44
No files found.
vllm/model_executor/layers/quantization/blockwise_int8.py
View file @
40b94473
...
@@ -23,6 +23,10 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -23,6 +23,10 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
apply_w8a8_block_int8_linear
)
apply_w8a8_block_int8_linear
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
W8a8GetCacheJSON
import
os
from
vllm
import
_custom_ops
as
ops
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
@@ -128,6 +132,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
...
@@ -128,6 +132,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
BlockInt8Config
):
def
__init__
(
self
,
quant_config
:
BlockInt8Config
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
assert
self
.
quant_config
.
weight_block_size
is
not
None
assert
self
.
quant_config
.
weight_block_size
is
not
None
assert
self
.
quant_config
.
is_checkpoint_int8_serialized
assert
self
.
quant_config
.
is_checkpoint_int8_serialized
...
@@ -219,6 +224,27 @@ class BlockInt8LinearMethod(LinearMethodBase):
...
@@ -219,6 +224,27 @@ class BlockInt8LinearMethod(LinearMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Block quant doesn't need to process weights after loading
# Block quant doesn't need to process weights after loading
# Use torch Parameter to avoid cuda graph capturing issue
# Use torch Parameter to avoid cuda graph capturing issue
n
=
layer
.
weight
.
shape
[
0
]
k
=
layer
.
weight
.
shape
[
1
]
block_n
=
self
.
quant_config
.
weight_block_size
[
0
]
block_k
=
self
.
quant_config
.
weight_block_size
[
1
]
block_size
=
[
block_n
,
block_k
]
#print("layer.weight.device:",layer.weight.device)
if
{
n
,
k
}
not
in
self
.
tritonsingleton
.
weight_shapes
:
self
.
tritonsingleton
.
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_blockint8json_name
(
n
,
k
,
block_n
,
block_k
)
configs_dict
=
self
.
tritonsingleton
.
get_blockint8_triton_cache
(
json_file
,
n
,
k
,
block_n
,
block_k
)
if
configs_dict
:
self
.
tritonsingleton
.
triton_json_dict
.
update
(
configs_dict
)
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
#ops.triton_blockint8_gemm_helper(m=m,n=n,k=k,block_size=block_size,use_bias=False,out_dtype=torch.bfloat16,device=layer.weight.device,best_config=value)
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale_inv
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale_inv
.
data
,
requires_grad
=
False
layer
.
weight_scale_inv
.
data
,
requires_grad
=
False
...
...
vllm/model_executor/layers/quantization/w8a8_int8.py
View file @
40b94473
...
@@ -17,6 +17,12 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
...
@@ -17,6 +17,12 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8
,
per_token_group_quant_int8
,
per_token_quant_int8
)
per_token_quant_int8
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
import
os
from
vllm
import
_custom_ops
as
ops
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
...
@@ -84,8 +90,30 @@ class W8A8Int8LinearMethod(LinearMethodBase):
...
@@ -84,8 +90,30 @@ class W8A8Int8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quantization_config
:
W8A8Int8Config
):
def
__init__
(
self
,
quantization_config
:
W8A8Int8Config
):
self
.
quantization_config
=
quantization_config
self
.
quantization_config
=
quantization_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
n
=
layer
.
weight
.
shape
[
0
]
k
=
layer
.
weight
.
shape
[
1
]
if
self
.
w8a8_strategy
==
1
:
if
{
n
,
k
}
not
in
self
.
tritonsingleton
.
weight_shapes
:
self
.
tritonsingleton
.
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
self
.
tritonsingleton
.
triton_json_dict
.
update
(
configs_dict
)
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
else
:
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
layer
.
weight
.
data
=
_weight
layer
.
weight
=
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
...
@@ -128,18 +156,67 @@ class W8A8Int8LinearMethod(LinearMethodBase):
...
@@ -128,18 +156,67 @@ class W8A8Int8LinearMethod(LinearMethodBase):
):
):
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
# return int8_scaled_mm(
if
self
.
w8a8_strategy
==
1
:
#
x_q
, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
m
=
x_q
.
shape
[
0
]
# )
k
=
x_q
.
shape
[
1
]
#return baseline_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, x.dtype, bias)
n
=
layer
.
weight
.
shape
[
1
]
if
len
(
W8A8_TRITONJSON
.
triton_json_dict
)
==
0
:
best_config
=
None
best_config
=
None
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
if
m
<=
16
:
m_
=
m
elif
m
<=
64
:
m_
=
(
m
+
3
)
&
-
4
#取值到最近的4的倍数
elif
m
<=
160
:
m_
=
(
m
+
7
)
&
-
8
elif
m
<
200
:
#256
m_
=
160
elif
m
<
480
:
#512
m_
=
256
elif
m
<
960
:
#1024
m_
=
512
elif
m
<
2048
:
m_
=
1024
elif
m
<
4096
:
m_
=
2048
elif
m
<
6000
:
m_
=
4096
else
:
m_
=
8192
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
if
best_config
==
None
:
print
(
"m:{},n:{},k:{}"
.
format
(
m
,
n
,
k
))
print
(
"config not found!"
)
return
ops
.
triton_scaled_mm
(
x_q
,
return
ops
.
triton_scaled_mm
(
x_q
,
layer
.
weight
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
best_config
=
best_config
)
bias
=
bias
,
best_config
=
best_config
)
elif
self
.
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
else
:
return
ops
.
rocblas_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
class
W8A8Int8MoEMethod
:
class
W8A8Int8MoEMethod
:
"""MoE method for INT8.
"""MoE method for INT8.
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
40b94473
...
@@ -53,7 +53,6 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -53,7 +53,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
W8a8GetCacheJSON
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
...
@@ -704,7 +703,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -704,7 +703,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
os
.
environ
[
'LM_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
self
.
use_w4a16_moe_sz
=
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
self
.
use_w4a16_moe_sz
=
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
@@ -928,48 +926,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -928,48 +926,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
sz_tensor
=
self
.
restore_qzeros_tensor
(
qzeros
,
scales
)
sz_tensor
=
self
.
restore_qzeros_tensor
(
qzeros
,
scales
)
scales
.
data
=
sz_tensor
scales
.
data
=
sz_tensor
if
hasattr
(
self
.
config
,
"quantization_config"
)
and
self
.
config
.
quantization_config
[
"quant_method"
]
==
"blockwise_int8"
:
lay_key_words
=
[
"self_attn.q_a_proj.weight"
,
"self_attn.q_b_proj.weight"
,
"self_attn.kv_b_proj.weight"
,
"self_attn.kv_a_proj_with_mqa.weight"
,
"self_attn.o_proj.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj.weight"
,
"mlp.shared_experts.gate_up_proj.weight"
,
"mlp.shared_experts.down_proj.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
matched_key_words
=
set
()
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
n
=
weight_data
.
shape
[
0
]
if
len
(
matched_key_words
)
<
9
and
matches
[
0
]
not
in
matched_key_words
:
matched_key_words
.
add
(
matches
[
0
])
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
#print("n:{},k:{}".format(n,k))
json_file
=
self
.
tritonsingleton
.
get_blockint8json_name
(
n
,
k
,
128
,
128
)
configs_dict
=
self
.
tritonsingleton
.
get_blockint8_triton_cache
(
json_file
,
n
,
k
,
128
,
128
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
self
.
tritonsingleton
.
triton_json_list
.
append
(
all_json
)
#print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0]))
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
# ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return
loaded_params
return
loaded_params
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment