Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53250530
"vllm/executor/mp_distributed_executor.py" did not exist on "eb6d3c264d0cd8e44dec16bca7947fbe96415ce9"
Commit
53250530
authored
Jun 05, 2025
by
gaoqiong
Browse files
Update w8a8_int8.py
parent
40b94473
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
371 additions
and
371 deletions
+371
-371
vllm/model_executor/layers/quantization/w8a8_int8.py
vllm/model_executor/layers/quantization/w8a8_int8.py
+371
-371
No files found.
vllm/model_executor/layers/quantization/w8a8_int8.py
View file @
53250530
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
import
torch
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
ChannelQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
per_token_group_quant_int8
,
per_token_group_quant_int8
,
per_token_quant_int8
)
per_token_quant_int8
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
import
os
import
os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
# Module-level handle on the Triton config cache. `apply()` reads its
# `triton_json_dict` to pick a tuned kernel config for each GEMM shape.
# NOTE(review): presumably W8a8GetCacheJSON() hands back a shared singleton
# (instances of the linear method also call it) — confirm in vllm.utils.
W8A8_TRITONJSON = W8a8GetCacheJSON()
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
scales
=
scale_a
*
scale_b
.
T
scales
=
scale_a
*
scale_b
.
T
gemmout
=
torch
.
mm
(
gemmout
=
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))
output
=
(
scales
*
gemmout
).
to
(
out_dtype
)
output
=
(
scales
*
gemmout
).
to
(
out_dtype
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
=
output
+
bias
output
=
output
+
bias
return
output
.
to
(
out_dtype
)
return
output
.
to
(
out_dtype
)
class W8A8Int8Config(QuantizationConfig):
    """Config class for W8A8 Int8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric

    The config carries no options of its own; it only maps layer types to
    the matching quantization method.
    """

    def __init__(self):
        # Run the base initializer so any state the base config owns is set
        # up. NOTE(review): assumes QuantizationConfig.__init__ takes no
        # arguments — confirm against base_config.
        super().__init__()

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this scheme accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (75)."""
        return 75

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization scheme."""
        return "w8a8_int8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return config file names to look for; none are needed."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
        """Build a config instance; ``config`` is accepted but not read."""
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Args:
            layer: The module being quantized.
            prefix: Name prefix of the layer (not used here).

        Returns:
            ``W8A8Int8LinearMethod`` for ``LinearBase`` layers,
            ``W8A8Int8MoEMethod`` for ``FusedMoE`` layers, otherwise
            ``None``.
        """
        if isinstance(layer, LinearBase):
            return W8A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W8A8Int8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return names of activations needing scaling; none here."""
        return []
class W8A8Int8LinearMethod(LinearMethodBase):
    """Linear method for W8A8 INT8 layers.

    Activations are quantized per token at call time with
    ``per_token_quant_int8``; the INT8 weight and its float32 scale (one row
    per output channel) are created by ``create_weights``.

    The GEMM backend is selected once, at construction, from the
    ``W8A8_SUPPORT_METHODS`` environment variable:

    - ``1`` (default): ``ops.triton_scaled_mm``
    - ``2``: ``ops.cutlass_scaled_mm``
    - anything else: ``ops.rocblas_scaled_mm``
    """

    def __init__(self, quantization_config: W8A8Int8Config):
        self.quantization_config = quantization_config
        # Handle on the Triton config cache used while post-processing
        # weights.
        self.tritonsingleton = W8a8GetCacheJSON()
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    @staticmethod
    def _bucket_m(m: int) -> int:
        """Map a row count ``m`` to the M value used as a config-cache key.

        Up to 16 the value is used as is; up to 64 it is rounded up to a
        multiple of 4; up to 160 to a multiple of 8; larger values snap to a
        fixed ladder of sizes.

        NOTE(review): the original carried bare ``#256`` / ``#512`` /
        ``#1024`` remarks beside the 200 / 480 / 960 cut-offs; their meaning
        is not evident from this file (possibly earlier thresholds).
        """
        if m <= 16:
            return m
        if m <= 64:
            return (m + 3) & -4  # round up to the nearest multiple of 4
        if m <= 160:
            return (m + 7) & -8  # round up to the nearest multiple of 8
        if m < 200:
            return 160
        if m < 480:
            return 256
        if m < 960:
            return 512
        if m < 2048:
            return 1024
        if m < 4096:
            return 2048
        if m < 6000:
            return 4096
        return 8192

    def _load_triton_configs(self, n: int, k: int,
                             device: torch.device) -> None:
        """Load the cached Triton configs for an (n, k) weight, once."""
        # An ordered tuple is used as the key: a set such as {n, k} cannot
        # tell (n, k) from (k, n) and collapses when n == k, which made the
        # transposed shape look "already loaded".
        shape_key = (n, k)
        cache = self.tritonsingleton
        if shape_key in cache.weight_shapes:
            return
        cache.weight_shapes.append(shape_key)

        json_file = cache.get_w8a8json_name(n, k)
        configs_dict = cache.get_triton_cache(json_file, n, k)
        if not configs_dict:
            return

        cache.triton_json_dict.update(configs_dict)
        for key, value in configs_dict.items():
            # Keys start with the M value: "<m>_...".
            m = int(key.split('_')[0])
            ops.triton_int8_gemm_helper(m=m,
                                        n=n,
                                        k=k,
                                        per_token_act_quant=True,
                                        per_out_channel_weight_quant=True,
                                        use_bias=False,
                                        device=device,
                                        best_config=value)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Finalize ``layer.weight`` / ``layer.weight_scale`` after loading.

        For the Triton strategy the cached kernel configs for this weight
        shape are loaded; for the other strategies the weight storage is
        re-laid-out. In every case the weight is then replaced by its
        transpose and both tensors are re-wrapped as plain, frozen
        ``Parameter`` objects.
        """
        n = layer.weight.shape[0]
        k = layer.weight.shape[1]

        if self.w8a8_strategy == 1:
            self._load_triton_configs(n, k, layer.weight.device)
        else:
            # NOTE(review): this code was recovered from a copy that had lost
            # its indentation; the `else` is assumed to pair with the
            # strategy check above — confirm against the repository.
            weight_data = layer.weight.data
            layer.weight.data = weight_data.T.contiguous().reshape(n, -1)

        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data,
                                       requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the INT8 ``weight`` and float32 ``weight_scale``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Number of weight columns.
            output_partition_sizes: Row counts of the fused outputs; their
                sum is the number of weight rows.
            input_size: Not used by this method.
            output_size: Not used by this method.
            params_dtype: Not used by this method.
            **extra_weight_attrs: Only ``weight_loader`` is read.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes
        total_out = sum(output_partition_sizes)

        weight = ModelWeightParameter(
            data=torch.empty(total_out,
                             input_size_per_partition,
                             dtype=torch.int8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((total_out, 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        """Quantize ``x`` per token and run the scaled INT8 GEMM.

        Args:
            layer: Module holding ``weight`` and ``weight_scale``.
            x: Input activations; the output uses ``x.dtype``.
            bias: Optional bias forwarded to the kernel.

        Returns:
            The result of the backend selected by ``self.w8a8_strategy``.
        """
        x_q, x_scale = per_token_quant_int8(x)

        if self.w8a8_strategy == 1:
            m = x_q.shape[0]
            k = x_q.shape[1]
            n = layer.weight.shape[1]

            configs = W8A8_TRITONJSON.triton_json_dict
            best_config = None
            # Only shapes that have a cached entry for M == 1 are treated as
            # tuned. A missing bucket falls back to None (the same value the
            # kernel gets when the cache is empty) instead of raising
            # KeyError.
            if configs and f"1_{n}_{k}" in configs:
                best_config = configs.get(f"{self._bucket_m(m)}_{n}_{k}")

            return ops.triton_scaled_mm(x_q,
                                        layer.weight,
                                        scale_a=x_scale,
                                        scale_b=layer.weight_scale,
                                        out_dtype=x.dtype,
                                        bias=bias,
                                        best_config=best_config)
        elif self.w8a8_strategy == 2:
            return ops.cutlass_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
        else:
            return ops.rocblas_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
class W8A8Int8MoEMethod:
    """Fused-MoE method for INT8 weights and activations.

    Handles INT8 checkpoints that carry a static weight scale together with
    a dynamic or static activation scale, as well as quantized FP16/BF16
    checkpoints using dynamic activation scaling. In the latter case the
    weight scaling factor is initialized once the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __new__(cls, *args, **kwargs):
        # Once the class carries `_initialized`, allocate normally.
        if hasattr(cls, "_initialized"):
            return super().__new__(cls)

        # Otherwise build, at instantiation time, a same-named class that
        # derives from FusedMoEMethodBase and reuses every attribute of this
        # class, then create and initialize an instance of it by hand.
        namespace = {"__init__": cls.__init__}
        namespace.update((attr, member)
                         for attr, member in cls.__dict__.items()
                         if attr != "__dict__")
        derived = type(cls.__name__, (FusedMoEMethodBase,), namespace)

        instance = super(derived, derived).__new__(derived)
        instance.__init__(*args, **kwargs)
        return instance

    def __init__(self, quant_config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register expert weights, weight scales and input-scale slots.

        ``w13_weight`` / ``w2_weight`` are uninitialized INT8 tensors; the
        matching ``*_weight_scale`` tensors are float32 ones with a trailing
        dimension of 1. ``w13_input_scale`` and ``w2_input_scale`` are
        registered as ``None``. ``params_dtype`` is not used.
        """
        # The value is not used below; the call itself is kept as in the
        # original.
        tp_size = get_tensor_model_parallel_world_size()

        # WEIGHTS
        int8_storage = {
            "w13_weight":
            torch.empty(num_experts,
                        2 * intermediate_size,
                        hidden_size,
                        dtype=torch.int8),
            "w2_weight":
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size,
                        dtype=torch.int8),
        }
        for param_name, storage in int8_storage.items():
            param = torch.nn.Parameter(storage, requires_grad=False)
            layer.register_parameter(param_name, param)
            # The weights receive the attrs *before* "quant_method" is added.
            set_weight_attrs(param, extra_weight_attrs)

        scale_storage = {
            "w13_weight_scale":
            torch.ones(num_experts,
                       2 * intermediate_size,
                       1,
                       dtype=torch.float32),
            "w2_weight_scale":
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
        }
        scale_params = {
            param_name: torch.nn.Parameter(storage, requires_grad=False)
            for param_name, storage in scale_storage.items()
        }
        for param_name, param in scale_params.items():
            layer.register_parameter(param_name, param)

        # Only the scales are tagged with the channel-wise quant method.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        for param in scale_params.values():
            set_weight_attrs(param, extra_weight_attrs)

        # No input scales are stored; the slots are registered as None.
        for slot in ("w13_input_scale", "w2_input_scale"):
            layer.register_parameter(slot, None)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap weights and scales as frozen ``Parameter`` objects."""
        for attr in ("w13_weight", "w2_weight"):
            setattr(layer, attr,
                    Parameter(getattr(layer, attr), requires_grad=False))
        for attr in ("w13_weight_scale", "w2_weight_scale"):
            setattr(layer, attr,
                    Parameter(getattr(layer, attr).data, requires_grad=False))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        use_fused_gate: Optional[bool] = False,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run the fused INT8 expert kernels.

        Routing arguments are forwarded to ``FusedMoE.select_experts``; the
        resulting weights and ids, together with the layer's INT8 weights
        and scales, are passed to ``fused_experts`` with ``inplace=True``,
        ``use_int8_w8a8=True`` and ``per_channel_quant=True``.
        """
        # Imported at call time, as in the original.
        from vllm.model_executor.layers.fused_moe import fused_experts

        # Expert selection
        routing_weights, expert_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            use_fused_gate=use_fused_gate,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=routing_weights,
            topk_ids=expert_ids,
            inplace=True,
            use_int8_w8a8=True,
            per_channel_quant=True,
            activation=activation,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=use_nn_moe,
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment