Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d7db129a
Commit
d7db129a
authored
Feb 06, 2026
by
zhuwenwen
Browse files
修复awq/w4a16的triton支持,以及fuse_moe的接口对齐,以及awq_moe_marlin推理的相关bug,并解决awq/w4a16的精度问题
parent
bc387d5a
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
356 additions
and
369 deletions
+356
-369
vllm/_custom_ops.py
vllm/_custom_ops.py
+5
-9
vllm/model_executor/layers/fused_moe/fused_marlin_moe_w4a16.py
...model_executor/layers/fused_moe/fused_marlin_moe_w4a16.py
+254
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+54
-248
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+8
-7
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+30
-53
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+1
-0
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+0
-45
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+4
-4
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+0
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+0
-1
No files found.
vllm/_custom_ops.py
View file @
d7db129a
...
@@ -1511,15 +1511,11 @@ def awq_marlin_moe_repack(
...
@@ -1511,15 +1511,11 @@ def awq_marlin_moe_repack(
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_experts
=
b_q_weight
.
shape
[
0
]
num_experts
=
b_q_weight
.
shape
[
0
]
assert
size_k
%
16
==
0
assert
size_k
%
16
==
0
output
=
torch
.
empty
(
output
=
torch
.
empty
((
num_experts
,
size_k
//
16
,
size_n
*
(
num_bits
//
2
)),
(
num_experts
,
size_k
//
16
,
size_n
*
(
num_bits
//
2
)),
device
=
b_q_weight
.
device
,
device
=
b_q_weight
.
device
,
dtype
=
b_q_weight
.
dtype
)
dtype
=
b_q_weight
.
dtype
,
output
[
e
]
=
op
.
awq_marlin_repack
(
b_q_weight
[
e
],
size_k
,
)
size_n
,
num_bits
)
for
e
in
range
(
num_experts
):
output
[
e
]
=
torch
.
ops
.
_C
.
awq_marlin_repack
(
b_q_weight
[
e
],
size_k
,
size_n
,
num_bits
,
is_a_8bit
)
return
output
return
output
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe_w4a16.py
0 → 100644
View file @
d7db129a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for GPTQ."""
import
functools
from
typing
import
Optional
import
torch
try
:
import
lightop
except
Exception
:
print
(
"INFO: Please install lightop if you want to infer awq of marlin.
\n
"
)
import
vllm.envs
as
envs
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
moe_align_block_size
,
try_get_optimal_moe_config
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
marlin_make_workspace_new
,
maybe_warn_marlin_atomic_add
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
direct_register_custom_op
from
vllm.model_executor.layers.fused_moe.fused_moe
import
get_moe_cache
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
if
has_zp
:
return
scalar_types
.
uint4
if
num_bits
==
4
else
scalar_types
.
uint8
else
:
return
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
def
fused_marlin_moe_int4
(
hidden_states
:
torch
.
Tensor
,
# 32, 7168
w1
:
torch
.
Tensor
,
# 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2
:
torch
.
Tensor
,
# 256, 256, 7168
w1_scale_zero
:
torch
.
Tensor
,
w2_scale_zero
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
4
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# quant_type = ScalarType.from_id(quant_type_id)
# assert quant_type in [
# scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
# scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f
# ]
# bit4_scalar_types = [
# scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f
# ]
# num_bits = 4 if quant_type in bit4_scalar_types else 8
# Check constraints.
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
1
]
*
16
,
"Hidden size mismatch w1"
assert
hidden_states
.
shape
[
1
]
==
w2
.
shape
[
2
]
//
(
num_bits
//
2
),
"Hidden size mismatch w2"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
# assert num_bits in [4]
assert
num_bits
in
[
4
]
num_tokens
,
K
=
hidden_states
.
shape
# 32, 7168
E
=
w1
.
shape
[
0
]
# 256
N
=
w2
.
shape
[
1
]
*
16
# 256
topk
=
topk_ids
.
shape
[
1
]
# 8
#暂时固定为16384
#CHUNK_SIZE = 16384
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
if
workspace
is
None
:
sms
=
torch
.
cuda
.
get_device_properties
(
device
=
'cuda'
).
multi_processor_count
workspace
=
torch
.
zeros
(
sms
*
3
,
dtype
=
torch
.
int
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
scalar_type1
=
get_scalar_type
(
num_bits
,
w1_zeros
is
not
None
)
scalar_type2
=
get_scalar_type
(
num_bits
,
w2_zeros
is
not
None
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk
,
N
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
intermediate_cache13
=
get_moe_cache
(
topk
,
N
,
K
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
else
:
intermediate_cache13
=
torch
.
empty
(
(
M
*
topk
*
max
(
2
*
N
,
K
),
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk
*
2
*
N
]
intermediate_cache1
=
intermediate_cache1
.
view
(
-
1
,
2
*
N
)
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk
*
K
]
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
use_atomic_add
=
hidden_states
.
dtype
==
torch
.
half
or
\
torch
.
cuda
.
get_device_capability
(
hidden_states
.
device
)[
0
]
>=
9
if
inplace
:
out_hidden_states
=
hidden_states
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
for
chunk
in
range
((
num_tokens
//
CHUNK_SIZE
)
+
1
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
size
()
if
tokens_in_chunk
==
0
:
break
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
*
topk
,
:]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
*
topk
,
:]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
*
topk
,
:]
M
=
tokens_in_chunk
# Select block_size_m
for
block_size_m
in
[
16
,
32
,
48
,
64
,
80
]:
if
M
*
topk
/
E
/
block_size_m
<
0.9
:
break
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
curr_topk_ids
,
block_size_m
,
global_num_experts
,
expert_map
)
intermediate_cache1
=
lightop
.
moe_marlin_w4a16
(
curr_hidden_states
,
intermediate_cache1
,
w1
,
w1_scale_zero
,
g_idx1
,
sort_indices1
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
curr_topk_weights
,
block_size_m
,
topk
,
False
,
expert_map
is
not
None
,
M
,
2
*
N
,
K
,
is_k_full
,
use_atomic_add
,
True
,
False
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
)
intermediate_cache3
=
lightop
.
moe_marlin_w4a16
(
intermediate_cache2
,
intermediate_cache3
,
w2
,
w2_scale_zero
,
g_idx2
,
sort_indices2
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
curr_topk_weights
,
block_size_m
,
1
,
True
,
expert_map
is
not
None
,
M
*
topk
,
K
,
N
,
is_k_full
,
use_atomic_add
,
True
,
False
).
view
(
-
1
,
topk
,
K
)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
def
fused_marlin_moe_int4_fake
(
hidden_states
:
torch
.
Tensor
,
# 32, 7168
w1
:
torch
.
Tensor
,
# 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2
:
torch
.
Tensor
,
# 256, 256, 7168
w1_scale_zero
:
torch
.
Tensor
,
w2_scale_zero
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
4
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
op_name
=
"fused_marlin_moe_int4"
,
op_func
=
fused_marlin_moe_int4
,
mutates_args
=
[],
fake_impl
=
fused_marlin_moe_int4_fake
,
)
\ No newline at end of file
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
d7db129a
...
@@ -109,162 +109,6 @@ def write_zeros_to_output(
...
@@ -109,162 +109,6 @@ def write_zeros_to_output(
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
@
triton
.
jit
def
fused_moe_kernel_awq
(
# Pointers to matrices
a_ptr
,
# [4, 7168]
b_ptr
,
# [256, 512, 3584]
c_ptr
,
# (8, 8, 512)
b_scale_ptr
,
# (256, 512, 56)
b_zp_ptr
,
# (256, 256, 56)
topk_weights_ptr
,
sorted_token_ids_ptr
,
# [0, 1, 2, 3, 4]
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
EM
,
# pading后的总索引长度
num_valid_tokens
,
# 有效索引的上限
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
#1
stride_bn
,
stride_cm
,
stride_cn
,
stride_bse
,
stride_bsk
,
#1
stride_bsn
,
stride_bze
,
stride_bzk
,
stride_bzn
,
block_k_diviable
:
tl
.
constexpr
,
group_size
:
tl
.
constexpr
,
# 128
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
has_zp
:
tl
.
constexpr
,
use_int4_w4a16
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
pid
=
tl
.
program_id
(
axis
=
0
)
num_pid_m
=
tl
.
cdiv
(
EM
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
group_id
=
pid
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
num_pid_m
-
first_pid_m
,
GROUP_SIZE_M
)
pid_m
=
first_pid_m
+
((
pid
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
# [block_m]
token_mask
=
offs_token
<
num_valid_tokens
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
# [block_n]
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
# 0, 1, 2, ...... , 127 # # [block_k]
offs_k2
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
//
2
)
# 0, 1, 2, ...... , 127 # # [block_k]
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
# [block_m, block_k]
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
use_int4_w4a16
:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
# b_ptrs = b_ptr + off_experts * stride_be + \
# (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_bn
[:,
None
]
*
stride_bn
+
(
offs_k2
[
None
,
:])
*
stride_bk
# tl.device_print("stride_bn",stride_bsn)>1
# tl.device_print("stride_bk",stride_bk)=1
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
# 0, 4
elif
use_int8_w8a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
if
not
has_zp
and
use_int4_w4a16
:
b_zp_num
=
8
if
not
has_zp
and
use_int8_w8a16
:
b_zp_num
=
128
elif
has_zp
and
use_int4_w4a16
:
b_zp_shifter
=
(
offs_bn
[
None
,
:]
%
2
)
*
4
# 0, 4
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
if
not
block_k_diviable
:
k_mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
k_other
=
0.0
else
:
k_mask
=
None
k_other
=
None
a
=
tl
.
load
(
a_ptrs
,
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
),
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
)
if
use_int4_w4a16
:
b
=
tl
.
interleave
(
b
,
b
)
b
=
b
.
trans
()
b
=
(
b
>>
b_shifter
)
&
0xF
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
\
offs_bn
[
None
,
:]
*
stride_bsk
+
\
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
stride_bsn
qzeros_scles
=
tl
.
load
(
b_scale_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
scales_int16
=
tl
.
cast
(
qzeros_scles
,
tl
.
uint16
)
b_scale
=
tl
.
cast
(
scales_int16
,
tl
.
float16
,
bitcast
=
True
)
# tl.device_print("b_scale dequant",b_scale)
mid
=
qzeros_scles
>>
16
# b_zp = tl.cast(mid,tl.float16,bitcast=False)
b_zp
=
tl
.
cast
(
mid
,
tl
.
float16
)
# b_zp = tl.cast(zeros_int16,tl.float16,bitcast=False)
# tl.device_print("bzp",b_zp)
# We accumulate along the K dimension.
b
=
((
b
-
b_zp
)
*
b_scale
).
to
(
tl
.
float16
)
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
if
use_int4_w4a16
:
b_ptrs
+=
(
BLOCK_SIZE_K
//
2
)
*
stride_bk
else
:
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
@
triton
.
jit
@
triton
.
jit
def
fused_moe_kernel_gptq_awq
(
def
fused_moe_kernel_gptq_awq
(
...
@@ -861,98 +705,58 @@ def invoke_fused_moe_wna16_triton_kernel(
...
@@ -861,98 +705,58 @@ def invoke_fused_moe_wna16_triton_kernel(
triton
.
cdiv
(
EM
,
META
[
"BLOCK_SIZE_M"
])
triton
.
cdiv
(
EM
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
B
.
size
(
1
),
META
[
"BLOCK_SIZE_N"
]),
*
triton
.
cdiv
(
B
.
size
(
1
),
META
[
"BLOCK_SIZE_N"
]),
)
)
config
=
config
.
copy
()
# config = config.copy()
config
.
update
(
# config.update(
get_moe_wna16_block_config
(
# get_moe_wna16_block_config(
config
=
config
,
# config=config,
use_moe_wna16_cuda
=
False
,
# use_moe_wna16_cuda=False,
num_valid_tokens
=
num_tokens
,
# num_valid_tokens=num_tokens,
size_k
=
A
.
size
(
1
),
# size_k=A.size(1),
size_n
=
B
.
size
(
1
),
# size_n=B.size(1),
num_experts
=
B
.
size
(
1
),
# num_experts=B.size(1),
group_size
=
block_shape
[
1
],
# group_size=block_shape[1],
real_top_k
=
top_k
,
# real_top_k=top_k,
block_size_m
=
config
[
"BLOCK_SIZE_M"
],
# block_size_m=config["BLOCK_SIZE_M"],
)
# )
# )
fused_moe_kernel_gptq_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
size
(
1
),
A
.
size
(
1
),
EM
,
num_tokens
,
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
size
(
1
)
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
)
if
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
:
fused_moe_kernel_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
size
(
1
),
A
.
size
(
1
),
EM
,
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
size
(
1
)
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
else
:
fused_moe_kernel_gptq_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
size
(
1
),
A
.
size
(
1
),
EM
,
num_tokens
,
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
size
(
1
)
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
def
invoke_fused_moe_triton_kernel
(
def
invoke_fused_moe_triton_kernel
(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
...
@@ -1086,9 +890,11 @@ def dispatch_fused_moe_kernel(
...
@@ -1086,9 +890,11 @@ def dispatch_fused_moe_kernel(
use_int8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a8
:
bool
,
per_channel_quant
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
list
[
int
]
|
None
=
None
,
block_shape
:
list
[
int
]
|
None
=
None
,
B_bias
:
torch
.
Tensor
|
None
=
None
,
B_bias
:
torch
.
Tensor
|
None
=
None
,
use_nn_moe
:
bool
|
None
=
False
,
)
->
None
:
)
->
None
:
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
...
@@ -2171,8 +1977,8 @@ def fused_experts_impl(
...
@@ -2171,8 +1977,8 @@ def fused_experts_impl(
config
[
"BLOCK_SIZE_M"
],
config
[
"BLOCK_SIZE_M"
],
global_num_experts
,
global_num_experts
,
expert_map
,
expert_map
,
ignore_invalid_experts
=
True
,
ignore_invalid_experts
=
False
if
use_int4_w4a16
else
True
,
num_token
=
curr_hidden_states
.
shape
[
0
]
if
use_int4_w4a16
else
None
#
num_token=curr_hidden_states.shape[0] if use_int4_w4a16 else None
)
)
else
:
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
*
config
[
"BLOCK_SIZE_M"
]
max_num_tokens_padded
=
topk_ids
.
numel
()
*
config
[
"BLOCK_SIZE_M"
]
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
d7db129a
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig
,
FusedMoEConfig
,
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
fused_marlin_moe
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
_w4a16
import
fused_marlin_moe
_int4
from
vllm.model_executor.layers.fused_moe.layer
import
(
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoE
,
FusedMoEMethodBase
,
FusedMoEMethodBase
,
...
@@ -49,6 +49,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -49,6 +49,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
moe_awq_to_marlin_zero_points
,
moe_awq_to_marlin_zero_points
,
verify_marlin_supported
,
verify_marlin_supported
,
verify_marlin_supports_shape
,
verify_marlin_supports_shape
,
awq_marlin_moe_permute_sz
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
is_layer_skipped
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
is_layer_skipped
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
...
@@ -786,18 +787,18 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -786,18 +787,18 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
use_nn_moe
:
bool
|
None
=
False
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
fused_marlin_moe
(
return
fused_marlin_moe
_int4
(
x
,
x
,
layer
.
w13_qweight
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
layer
.
w2_qweight
,
getattr
(
layer
,
"w13_bias"
,
None
),
#
getattr(layer, "w13_bias", None),
getattr
(
layer
,
"w2_bias"
,
None
),
#
getattr(layer, "w2_bias", None),
layer
.
w13_scales
,
layer
.
w13_scales
,
layer
.
w2_scales
,
layer
.
w2_scales
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
input_global_scale1
=
getattr
(
layer
,
"w13_input_global_scale"
,
None
),
#
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2
=
getattr
(
layer
,
"w2_input_global_scale"
,
None
),
#
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
# quant_type_id=self.quant_type.id,
# quant_type_id=self.quant_type.id,
# apply_router_weight_on_input=layer.apply_router_weight_on_input,
# apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts
=
layer
.
global_num_experts
,
global_num_experts
=
layer
.
global_num_experts
,
...
@@ -805,6 +806,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -805,6 +806,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
w1_zeros
=
layer
.
w13_qzeros
,
w1_zeros
=
layer
.
w13_qzeros
,
w2_zeros
=
layer
.
w2_qzeros
,
w2_zeros
=
layer
.
w2_qzeros
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
input_dtype
=
self
.
input_dtype
,
#
input_dtype=self.input_dtype,
num_bits
=
4
,
num_bits
=
4
,
)
)
vllm/model_executor/layers/quantization/awq_triton.py
View file @
d7db129a
...
@@ -114,21 +114,10 @@ def awq_dequantize_kernel(
...
@@ -114,21 +114,10 @@ def awq_dequantize_kernel(
@
triton
.
jit
@
triton
.
jit
def
awq_gemm_kernel
(
def
awq_gemm_kernel
(
a_ptr
,
b_ptr
,
c_ptr
,
zeros_ptr
,
scales_ptr
,
M
,
N
,
K
,
a_ptr
,
group_size
,
BLOCK_SIZE_M
:
tl
.
constexpr
,
b_ptr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
c_ptr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
):
zeros_ptr
,
scales_ptr
,
M
,
N
,
K
,
group_size
,
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
axis
=
0
)
pid
=
tl
.
program_id
(
axis
=
0
)
pid_z
=
tl
.
program_id
(
1
)
pid_z
=
tl
.
program_id
(
1
)
...
@@ -154,17 +143,18 @@ def awq_gemm_kernel(
...
@@ -154,17 +143,18 @@ def awq_gemm_kernel(
# (BLOCK_SIZE_M, BLOCK_SIZE_N))
# (BLOCK_SIZE_M, BLOCK_SIZE_N))
# accumulator = accumulator & 0x0
# accumulator = accumulator & 0x0
# accumulator = accumulator.to(accumulator_dtype)
# accumulator = accumulator.to(accumulator_dtype)
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
accumulator_dtype
)
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
# that will map given indices to the correct order.
reverse_awq_order_tensor
=
(
shifts
=
((
tl
.
arange
(
0
,
2
)
*
16
)[
None
,
:]
+
(
tl
.
arange
(
0
,
2
)
*
4
)[
None
,
:]
+
tl
.
arange
(
0
,
4
)[:,
None
]
(
tl
.
arange
(
0
,
4
)
*
4
)[:,
None
]).
reshape
(
1
,
8
)
).
reshape
(
8
)
# Create the necessary shifts to use to unpack.
# Create the necessary shifts to use to unpack.
shifts
=
reverse_awq_order_tensor
*
4
# shifts = reverse_awq_order_tensor * 4
shifts
=
tl
.
broadcast_to
(
shifts
[
None
,
:],
(
BLOCK_SIZE_K
*
(
BLOCK_SIZE_N
//
8
),
8
))
shifts
=
tl
.
broadcast_to
(
shifts
,
(
BLOCK_SIZE_K
*
(
BLOCK_SIZE_N
//
8
),
8
))
shifts
=
tl
.
reshape
(
shifts
,
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
))
shifts
=
tl
.
reshape
(
shifts
,
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
))
# Offsets and masks.
# Offsets and masks.
...
@@ -307,17 +297,12 @@ def awq_dequantize_triton(
...
@@ -307,17 +297,12 @@ def awq_dequantize_triton(
# qzeros - [K // G, N // 8]
# qzeros - [K // G, N // 8]
# scales - [K // G, N]
# scales - [K // G, N]
# split_k_iters - parallelism along K-dimension, int, power of 2.
# split_k_iters - parallelism along K-dimension, int, power of 2.
def
awq_gemm_triton
(
def
awq_gemm_triton
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
split_k_iters
:
int
,
config
=
None
)
->
torch
.
Tensor
:
block_size_m
:
int
=
32
,
block_size_n
:
int
=
32
,
block_size_k
:
int
=
32
,
config
=
None
,
)
->
torch
.
Tensor
:
M
,
K
=
input
.
shape
M
,
K
=
input
.
shape
N
=
qweight
.
shape
[
1
]
*
8
N
=
qweight
.
shape
[
1
]
*
8
group_size
=
qweight
.
shape
[
0
]
//
qzeros
.
shape
[
0
]
group_size
=
qweight
.
shape
[
0
]
//
qzeros
.
shape
[
0
]
...
@@ -332,8 +317,8 @@ def awq_gemm_triton(
...
@@ -332,8 +317,8 @@ def awq_gemm_triton(
assert
group_size
in
AWQ_TRITON_SUPPORTED_GROUP_SIZES
or
group_size
==
K
assert
group_size
in
AWQ_TRITON_SUPPORTED_GROUP_SIZES
or
group_size
==
K
grid
=
lambda
META
:
(
grid
=
lambda
META
:
(
triton
.
cdiv
(
M
,
META
[
"
BLOCK_SIZE_M
"
])
*
triton
.
cdiv
(
N
,
META
[
"
BLOCK_SIZE_N
"
]),
triton
.
cdiv
(
M
,
META
[
'
BLOCK_SIZE_M
'
])
*
triton
.
cdiv
(
N
,
META
[
'
BLOCK_SIZE_N
'
]),
split_k_iters
,
META
[
'SPLIT_K'
]
,
)
)
if
config
is
None
:
if
config
is
None
:
...
@@ -342,27 +327,19 @@ def awq_gemm_triton(
...
@@ -342,27 +327,19 @@ def awq_gemm_triton(
#print("INFO:this size not found in json.")
#print("INFO:this size not found in json.")
config
=
{
'BLOCK_SIZE_M'
:
128
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
64
,
'GROUP_SIZE_M'
:
8
,
'SPLIT_K'
:
1
}
config
=
{
'BLOCK_SIZE_M'
:
128
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
64
,
'GROUP_SIZE_M'
:
8
,
'SPLIT_K'
:
1
}
result
=
torch
.
zeros
((
split_k_iters
,
M
,
N
),
dtype
=
scales
.
dtype
,
device
=
input
.
device
)
result
=
torch
.
zeros
((
M
,
N
),
dtype
=
scales
.
dtype
,
device
=
input
.
device
)
# A = input, B = qweight, C = result
# A = input, B = qweight, C = result
# A = M x K, B = K x N, C = M x N
# A = M x K, B = K x N, C = M x N
awq_gemm_kernel
[
grid
](
awq_gemm_kernel
[
grid
](
input
,
input
,
qweight
,
qweight
,
result
,
result
,
qzeros
,
qzeros
,
scales
,
scales
,
M
,
M
,
N
,
N
,
K
,
K
,
group_size
,
group_size
,
**
config
)
BLOCK_SIZE_M
=
block_size_m
,
BLOCK_SIZE_N
=
block_size_n
,
BLOCK_SIZE_K
=
block_size_k
,
SPLIT_K
=
split_k_iters
,
**
config
,
)
result
=
result
.
sum
(
0
)
return
result
return
result
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
d7db129a
...
@@ -1853,6 +1853,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1853,6 +1853,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
**
_
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
d7db129a
...
@@ -232,12 +232,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -232,12 +232,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
def
__init__
(
self
,
quant_config
:
MoeWNA16Config
,
moe
:
"FusedMoEConfig"
)
->
None
:
def
__init__
(
self
,
quant_config
:
MoeWNA16Config
,
moe
:
"FusedMoEConfig"
)
->
None
:
super
().
__init__
(
moe
)
super
().
__init__
(
moe
)
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
use_w4a16_moe_sz
=
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
self
.
use_w4a16_cuda
=
0
self
.
use_moe_lmslim
=
0
if
self
.
use_w4a16_moe_sz
:
self
.
use_w4a16_cuda
=
os
.
environ
[
'W4A16_MOE_CUDA'
]
==
'1'
self
.
use_moe_lmslim
=
os
.
environ
[
'W4A16_MOE_LMSLIM'
]
==
"1"
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -391,46 +385,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -391,46 +385,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
layer
.
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
layer
.
activation
==
"silu"
,
"Only SiLU activation is supported."
# TODO @yangql
# if self.use_moe_lmslim:
# return fused_experts_w4a16(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# topk_weights=topk_weights,
# topk_ids=topk_ids,
# inplace=True,
# activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
# use_int4_w4a16=True,
# global_num_experts=global_num_experts,
# expert_map=expert_map,
# w1_scale=layer.w13_scales,
# w2_scale=layer.w2_scales,
# block_shape=[0, layer.group_size])
# if self.use_w4a16_cuda:
# m = topk_ids.shape[0]
# if m <= 512:
# return fused_experts_cuda(x,
# layer.w13_qweight,
# layer.w2_qweight,
# topk_weights,
# topk_ids,
# inplace=True,
# use_fp8_w8a8=False,
# use_int4_w4a16=True,
# use_int8_w8a16=False,
# w1_scale=layer.w13_scales,
# w2_scale=layer.w2_scales,
# w1_zp=None,
# w2_zp=None,
# a1_scale=None,
# a2_scale=None,
# block_shape=[0, layer.group_size],
# expert_map=expert_map)
return
fused_experts
(
return
fused_experts
(
x
,
x
,
layer
.
w13_qweight
,
layer
.
w13_qweight
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
d7db129a
...
@@ -227,8 +227,8 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
...
@@ -227,8 +227,8 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
return
False
return
False
def
check_moe_marlin_supports_layer
(
layer
:
LinearBase
,
group_size
:
int
)
->
bool
:
def
check_moe_marlin_supports_layer
(
layer
:
LinearBase
,
group_size
:
int
)
->
bool
:
if
current_platform
.
is_rocm
():
#
if current_platform.is_rocm():
return
False
#
return False
hidden_size
=
layer
.
hidden_size
hidden_size
=
layer
.
hidden_size
intermediate_size_per_partition
=
layer
.
intermediate_size_per_partition
intermediate_size_per_partition
=
layer
.
intermediate_size_per_partition
# apply_router_weight_on_input is not supported for moe marlin
# apply_router_weight_on_input is not supported for moe marlin
...
@@ -352,7 +352,7 @@ def marlin_permute_scales(
...
@@ -352,7 +352,7 @@ def marlin_permute_scales(
def
marlin_permute_bias
(
s
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
marlin_permute_bias
(
s
:
torch
.
Tensor
)
->
torch
.
Tensor
:
origin_shape
=
s
.
shape
origin_shape
=
s
.
shape
_
,
scale_perm_single
=
get_scale_perms
()
scale_perm_single
=
get_scale_perms
()
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
return
s
.
reshape
(
*
origin_shape
).
contiguous
()
return
s
.
reshape
(
*
origin_shape
).
contiguous
()
...
@@ -385,7 +385,7 @@ def marlin_zero_points(
...
@@ -385,7 +385,7 @@ def marlin_zero_points(
# Permute zero-points in a similar way to scales, but do not use the
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
# "single" permutation, since zero-points are applied on every MMA
# 和 scale 使用一致的重排逻辑,将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起
# 和 scale 使用一致的重排逻辑,将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起
scale_perm
,
_
=
get_scale_perms
()
scale_perm
=
get_scale_perms
()
zp
=
zp
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
zp
=
zp
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
# uint4 混排
# uint4 混排
...
...
vllm/model_executor/model_loader/utils.py
View file @
d7db129a
...
@@ -255,8 +255,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
...
@@ -255,8 +255,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
# awq相关配置
# awq相关配置
try
:
try
:
# if os.getenv('AWQ_MOE_SZ') == None:
# os.environ['AWQ_MOE_SZ'] = '1'
if
os
.
getenv
(
'AWQ_PAD'
)
==
None
and
(
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
):
if
os
.
getenv
(
'AWQ_PAD'
)
==
None
and
(
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
):
os
.
environ
[
'AWQ_PAD'
]
=
'1'
os
.
environ
[
'AWQ_PAD'
]
=
'1'
except
Exception
as
e
:
except
Exception
as
e
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
d7db129a
...
@@ -1210,7 +1210,6 @@ class DeepseekV2ForCausalLM(
...
@@ -1210,7 +1210,6 @@ class DeepseekV2ForCausalLM(
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
self
.
use_w4a16_moe_sz
=
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment