Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
27b78c73
Unverified
Commit
27b78c73
authored
Jan 29, 2025
by
Jinzhen Lin
Committed by
GitHub
Jan 29, 2025
Browse files
[Kernel] add triton fused moe kernel for gptq/awq (#12185)
parent
b02fd288
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
874 additions
and
55 deletions
+874
-55
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+91
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+354
-53
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+5
-2
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+424
-0
No files found.
tests/kernels/test_moe.py
View file @
27b78c73
...
...
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe
as
iterative_moe
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
marlin_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
quantize_weights
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
...
...
@@ -55,6 +57,95 @@ def test_fused_moe(
rtol
=
0
)
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                        dtype: torch.dtype, group_size: int, has_zp: bool,
                        weight_bits: int):
    """Check the quantized (W4A16/W8A16) fused MoE path against torch_moe.

    Random expert weights are quantized per expert with ``quantize_weights``.
    The packed tensors are fed to ``fused_moe`` while the first value returned
    by ``quantize_weights`` (presumably the dequantized weight) is used as the
    reference weight for ``torch_moe``. Both outputs must agree within an
    absolute tolerance of 2e-2.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    # pack_factor is the number of quantized values stored per uint8.
    if weight_bits == 4:
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
    else:
        # Fail with a clear message instead of an unbound-name error below.
        raise ValueError(f"unsupported weight_bits: {weight_bits}")

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w2_qweight = torch.empty((e, k, n // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size),
                            device="cuda",
                            dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size),
                            device="cuda",
                            dtype=dtype)
    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
                            device="cuda",
                            dtype=torch.uint8)
    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
                            device="cuda",
                            dtype=torch.uint8)

    # The first `e` iterations handle w1, the remaining `e` handle w2.
    for i in range(e * 2):
        expert_id = i % e
        if i // e == 0:
            w, w_ref, w_qweight, w_scales, w_qzeros = \
                w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
        else:
            w, w_ref, w_qweight, w_scales, w_qzeros = \
                w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
        weight, qweight, scales, qzeros = quantize_weights(
            w[expert_id].T, quant_type, group_size, has_zp, False)
        weight = weight.T
        qweight = qweight.T.contiguous().to(torch.uint8)
        scales = scales.T
        if has_zp:
            qzeros = qzeros.T.contiguous().to(torch.uint8)
        if weight_bits == 4:
            # Pack two 4-bit values into one byte: odd index -> high nibble,
            # even index -> low nibble (along dim 1 for qweight, dim 0 for
            # qzeros).
            qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
            if has_zp:
                qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

        w_ref[expert_id] = weight
        w_qweight[expert_id] = qweight
        w_scales[expert_id] = scales
        if has_zp:
            w_qzeros[expert_id] = qzeros

    triton_output = fused_moe(a,
                              w1_qweight,
                              w2_qweight,
                              score,
                              topk,
                              renormalize=False,
                              use_int4_w4a16=weight_bits == 4,
                              use_int8_w8a16=weight_bits == 8,
                              w1_scale=w1_scales,
                              w2_scale=w2_scales,
                              w1_zp=w1_qzeros if has_zp else None,
                              w2_zp=w2_qzeros if has_zp else None,
                              block_shape=[0, group_size])
    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
torch
.
inference_mode
()
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
27b78c73
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/quantization/__init__.py
View file @
27b78c73
...
...
@@ -26,7 +26,8 @@ QUANTIZATION_METHODS: List[str] = [
"experts_int8"
,
"neuron_quant"
,
"ipex"
,
"quark"
"quark"
,
"moe_wna16"
]
# The customized quantization methods which will be added to this dict.
...
...
@@ -94,6 +95,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from
.ipex_quant
import
IPEXConfig
from
.marlin
import
MarlinConfig
from
.modelopt
import
ModelOptFp8Config
from
.moe_wna16
import
MoeWNA16Config
from
.neuron_quant
import
NeuronQuantConfig
from
.qqq
import
QQQConfig
from
.tpu_int8
import
Int8TpuConfig
...
...
@@ -121,7 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
"ipex"
:
IPEXConfig
,
"quark"
:
QuarkConfig
"quark"
:
QuarkConfig
,
"moe_wna16"
:
MoeWNA16Config
,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
0 → 100644
View file @
27b78c73
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.awq
import
(
AWQConfig
,
AWQLinearMethod
)
from
vllm.model_executor.layers.quantization.awq_marlin
import
(
AWQMarlinConfig
,
AWQMarlinLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.gptq
import
(
GPTQConfig
,
GPTQLinearMethod
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
,
GPTQMarlinLinearMethod
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
class MoeWNA16Config(QuantizationConfig):
    """Config class for MOE WNA16 (W8A16/W4A16) quantization.

    The config wraps a GPTQ or AWQ checkpoint config. FusedMoE layers are
    handled by ``MoeWNA16Method``; every other layer is delegated to the
    matching GPTQ/AWQ (or Marlin) linear method built from ``full_config``.
    """

    def __init__(self, linear_quant_method: str, weight_bits: int,
                 group_size: int, has_zp: bool, lm_head_quantized: bool,
                 modules_to_not_convert: Optional[List[str]],
                 full_config: Dict[str, Any]) -> None:
        """Store the quantization settings and pick the linear backend.

        Args:
            linear_quant_method: ``"gptq"`` or ``"awq"``.
            weight_bits: Bit width of the quantized weights.
            group_size: Quantization group size.
            has_zp: Whether the checkpoint carries zero points.
            lm_head_quantized: Whether the LM head is quantized.
            modules_to_not_convert: Module name fragments to leave
                unquantized; ``None`` is treated as an empty list.
            full_config: The original checkpoint quantization config.

        Raises:
            ValueError: If ``linear_quant_method`` is neither ``"gptq"`` nor
                ``"awq"``, or if AWQ is requested on a device below the AWQ
                minimum capability.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.has_zp = has_zp
        # Number of quantized values packed into one uint8.
        self.bit8_pack_factor = 8 // self.weight_bits
        self.lm_head_quantized = lm_head_quantized
        self.linear_quant_method = linear_quant_method
        self.full_config = full_config
        self.use_marlin = False
        if self.linear_quant_method == "gptq":
            self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
                full_config)
        elif self.linear_quant_method == "awq":
            capability_tuple = current_platform.get_device_capability()
            # -1 makes the comparison below fail when no capability is known.
            device_capability = (-1 if capability_tuple is None else
                                 capability_tuple.to_int())
            awq_min_capability = AWQConfig.get_min_capability()
            if device_capability < awq_min_capability:
                raise ValueError(
                    "The quantization method moe_wna16 + awq is not supported "
                    "for the current GPU. "
                    f"Minimum capability: {awq_min_capability}. "
                    f"Current capability: {device_capability}.")
            self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(
                full_config)
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        if modules_to_not_convert is None:
            self.modules_to_not_convert = []
        else:
            self.modules_to_not_convert = modules_to_not_convert

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "moe_wna16"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method supports."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (as an int) required."""
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names searched for in a checkpoint."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
        """Build the config from a GPTQ or AWQ checkpoint config dict.

        Raises:
            ValueError: If ``quant_method`` is neither ``"gptq"`` nor
                ``"awq"``.
        """
        linear_quant_method = cls.get_from_keys(config, ["quant_method"])
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        if linear_quant_method == "gptq":
            # Symmetric GPTQ checkpoints carry no usable zero points.
            has_zp = not cls.get_from_keys(config, ["sym"])
            modules_to_not_convert = []
        elif linear_quant_method == "awq":
            has_zp = cls.get_from_keys(config, ["zero_point"])
            # The key is optional: __init__ maps None to an empty list, so a
            # checkpoint without it must not fail here.
            modules_to_not_convert = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None)
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        return cls(linear_quant_method, weight_bits, group_size, has_zp,
                   lm_head_quantized, modules_to_not_convert, config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Return this method's name when the user asked for it explicitly.

        The override only applies when the checkpoint config is compatible
        and ``user_quant`` is ``"moe_wna16"``; otherwise ``None`` is returned.
        """
        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
        if can_convert and user_quant == "moe_wna16":
            return cls.get_name()
        return None

    @classmethod
    def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
        """Return whether a checkpoint quantization config can be used.

        GPTQ is accepted with 4 or 8 bits and a falsy ``desc_act``. AWQ is
        accepted with 4 bits when the device capability meets the AWQ
        minimum.
        """
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        desc_act = quant_config.get("desc_act")

        capability_tuple = current_platform.get_device_capability()
        device_capability = (-1 if capability_tuple is None else
                             capability_tuple.to_int())
        awq_min_capability = AWQConfig.get_min_capability()

        gptq_compatible = quant_method == "gptq" and \
            not desc_act and num_bits in [4, 8]
        awq_compatible = quant_method == "awq" and num_bits == 4 and \
            device_capability >= awq_min_capability

        return gptq_compatible or awq_compatible

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method for ``layer``.

        Skipped layers get ``UnquantizedLinearMethod``, FusedMoE layers get
        ``MoeWNA16Method``, and everything else gets the GPTQ/AWQ linear
        method (Marlin variant when ``use_marlin`` is set).

        Raises:
            ValueError: If ``linear_quant_method`` is neither ``"gptq"`` nor
                ``"awq"``.
        """
        if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
            return UnquantizedLinearMethod()
        if isinstance(layer, FusedMoE):
            return MoeWNA16Method(self)

        if self.linear_quant_method == "gptq":
            if self.use_marlin:
                return GPTQMarlinLinearMethod(
                    GPTQMarlinConfig.from_config(self.full_config))
            return GPTQLinearMethod(GPTQConfig.from_config(self.full_config))
        if self.linear_quant_method == "awq":
            if self.use_marlin:
                return AWQMarlinLinearMethod(
                    AWQMarlinConfig.from_config(self.full_config))
            return AWQLinearMethod(AWQConfig.from_config(self.full_config))
        raise ValueError("moe_wna16 only support gptq and awq.")
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
    """Tell whether the layer named ``prefix`` should stay unquantized.

    A layer is skipped as soon as one entry of ``modules_to_not_convert``
    occurs as a substring of ``prefix``.
    """
    for skipped_name in modules_to_not_convert:
        if skipped_name in prefix:
            return True
    return False
class MoeWNA16Method(FusedMoEMethodBase):
    """Linear method for MOE WNA16 (W8A16/W4A16) quantization.

    Args:
        quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
    """

    def __init__(self, quant_config: MoeWNA16Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the packed weight, scale and zero-point parameters.

        Also stores ``quant_config``, the effective ``group_size`` and
        ``group_size_div_factor`` on ``layer`` and wraps the incoming
        ``weight_loader`` with the checkpoint-format converter returned by
        ``get_weight_loader``.
        """
        layer.quant_config = self.quant_config
        bit8_pack_factor = self.quant_config.bit8_pack_factor
        group_size = self.quant_config.group_size
        group_size_div_factor = 1

        # Make intermediate_size and hidden_size divisible by group_size:
        # halve the group size until both are, and remember the factor so
        # that the loaded scales/qzeros can be repeated accordingly later.
        while intermediate_size_per_partition % group_size or \
                hidden_size % group_size:
            group_size = group_size // 2
            group_size_div_factor *= 2
            assert group_size >= 32
        layer.group_size = group_size
        layer.group_size_div_factor = group_size_div_factor

        strategy = FusedMoeWeightScaleSupported.GROUP.value
        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": False
        })

        assert 'weight_loader' in extra_weight_attrs
        weight_loader = extra_weight_attrs['weight_loader']
        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
            layer, weight_loader)
        extra_weight_attrs['weight_loader'] = wrapped_weight_loader

        def _register(name: str, data: torch.Tensor) -> None:
            # Wrap `data` as a frozen parameter and attach the loader attrs.
            param = torch.nn.Parameter(data, requires_grad=False)
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Fused gate_up_proj (column parallel)
        _register(
            "w13_qweight",
            torch.empty(num_experts,
                        2 * intermediate_size_per_partition,
                        hidden_size // bit8_pack_factor,
                        dtype=torch.uint8))

        # down_proj (row parallel)
        _register(
            "w2_qweight",
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size_per_partition // bit8_pack_factor,
                        dtype=torch.uint8))

        _register(
            "w13_scales",
            torch.zeros(num_experts,
                        2 * intermediate_size_per_partition,
                        hidden_size // group_size,
                        dtype=params_dtype))

        _register(
            "w2_scales",
            torch.zeros(num_experts,
                        hidden_size,
                        intermediate_size_per_partition // group_size,
                        dtype=params_dtype))

        if self.quant_config.has_zp:
            _register(
                "w13_qzeros",
                torch.zeros(num_experts,
                            2 * intermediate_size_per_partition //
                            bit8_pack_factor,
                            hidden_size // group_size,
                            dtype=torch.uint8))

            _register(
                "w2_qzeros",
                torch.zeros(num_experts,
                            hidden_size // bit8_pack_factor,
                            intermediate_size_per_partition // group_size,
                            dtype=torch.uint8))

        if self.quant_config.linear_quant_method == "gptq":
            # some param are unused, but we need to init them in order to
            # load weights
            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
            if not self.quant_config.has_zp:
                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
            for key in invalid_param_keys:
                _register(key, torch.empty((0, ), dtype=torch.int32))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route ``x`` with ``FusedMoE.select_experts`` and run the experts.

        The packed weights, scales and (when ``has_zp``) zero points stored
        on ``layer`` are passed to ``fused_experts`` with ``inplace=True``.
        """
        # Imported here rather than at module level (kept as in the
        # original; presumably to avoid an import cycle).
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        weight_bits = self.quant_config.weight_bits
        has_zp = self.quant_config.has_zp

        return fused_experts(x,
                             layer.w13_qweight,
                             layer.w2_qweight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_int4_w4a16=weight_bits == 4,
                             use_int8_w8a16=weight_bits == 8,
                             w1_scale=layer.w13_scales,
                             w2_scale=layer.w2_scales,
                             w1_zp=layer.w13_qzeros if has_zp else None,
                             w2_zp=layer.w2_qzeros if has_zp else None,
                             block_shape=[0, layer.group_size])

    @staticmethod
    def get_weight_loader(layer, weight_loader):
        """Wrap ``weight_loader`` with GPTQ/AWQ checkpoint conversion.

        The returned loader converts each loaded tensor to the layout of the
        parameters created in ``create_weights``, writes ``w13_qzeros`` and
        ``w2_qzeros`` into ``param.data`` itself, and forwards every other
        tensor to ``weight_loader``.
        """

        def convert_awq_tensor(tensor, tensor_type):
            # convert awq qweight/qzeros to a standard format (assume int4)
            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
            # qzeros: (k // group_size, n // pack_factor_bit32) ->
            #         (n // pack_factor_bit8, k // group_size)
            # pack_factor_bit32 = 32 // weight_bits
            # pack_factor_bit8 = 8 // weight_bits

            # 0. suppose origin shape (a, b), dtype int32
            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
            size0 = tensor.size(0)
            tensor = tensor.view(torch.uint8)

            # 2. unpack to uint4 (only when weight_bits == 4)
            #    shape (a, 4 * b) -> (a, 4 * b, 2)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF

            # 3. change order, see
            # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
            # shape -> (a, 4 * b * pack_factor_bit8)
            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
            tensor = tensor.view(size0, -1)

            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
            tensor = tensor.T.contiguous()

            # 5. repack (only when weight_bits == 4)
            # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
            # qzeros shape -> (4 * b, a)
            if tensor_type == "qweight":
                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
            elif tensor_type == "qzeros":
                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
            return tensor

        def convert_gptq_int4_qzeros(tensor):
            # Unpack each byte into two nibbles, add 1 to every nibble and
            # repack (low nibble first).
            tensor = tensor.view(torch.uint8)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF
            tensor = tensor + 1
            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
            return tensor

        def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                    loaded_weight: torch.Tensor,
                                    weight_name: str, shard_id: str,
                                    expert_id: int):
            # g_idx is never used; qzeros are ignored without zero points.
            if "g_idx" in weight_name:
                return
            if not layer.quant_config.has_zp and "qzeros" in weight_name:
                return

            device = get_tp_group().device
            tp_rank = get_tensor_model_parallel_rank()
            loaded_weight = loaded_weight.to(device)
            shard_size = layer.intermediate_size_per_partition

            # convert gptq and awq weight to a standard format
            if layer.quant_config.linear_quant_method == "awq":
                assert layer.quant_config.weight_bits == 4
                if "weight" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight,
                                                       "qweight")
                elif "zeros" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
                else:
                    loaded_weight = loaded_weight.T
            elif layer.quant_config.linear_quant_method == "gptq":
                assert layer.quant_config.weight_bits in [4, 8]
                if "weight" in weight_name:
                    loaded_weight = loaded_weight.T.contiguous().view(
                        torch.uint8)
                elif "zeros" in weight_name:
                    # add 1 to gptq qzeros to align with awq
                    loaded_weight = loaded_weight.view(torch.uint8)
                    if layer.quant_config.weight_bits == 4:
                        loaded_weight = convert_gptq_int4_qzeros(
                            loaded_weight).T
                    else:
                        loaded_weight = loaded_weight.T + 1
                else:
                    loaded_weight = loaded_weight.T

            # Repeat the qzeros/scales to fit the reduced group size. The
            # parentheses matter: without them `and` binds tighter than `or`
            # and scales are repeated even when the factor is 1.
            if layer.group_size_div_factor > 1 and (
                    "qzeros" in weight_name or "scales" in weight_name):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1)

            if "w13_qzeros" in weight_name:
                tensor = loaded_weight.view(layer.tp_size, -1,
                                            loaded_weight.size(1))[tp_rank]
                # w13_qzeros has 2 * shard_size // bit8_pack_factor rows
                # (see create_weights); w1 fills the first half, the other
                # shard the second half.
                split = shard_size // layer.quant_config.bit8_pack_factor
                if shard_id == "w1":
                    param.data[expert_id, :split] = tensor
                else:
                    param.data[expert_id, split:] = tensor
            elif "w2_qzeros" in weight_name:
                param.data[expert_id] = loaded_weight.view(
                    loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
            else:
                weight_loader(param, loaded_weight, weight_name, shard_id,
                              expert_id)

        return moe_wna16_weight_loader
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment