Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
081057de
Commit
081057de
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-ori
parents
7cf5d5c4
ba41cc90
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1946 additions
and
340 deletions
+1946
-340
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+35
-8
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+173
-160
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+16
-11
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+7
-10
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+346
-88
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+2
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+27
-2
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+3
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+12
-8
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+22
-13
vllm/model_executor/layers/quantization/bitblas.py
vllm/model_executor/layers/quantization/bitblas.py
+459
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+4
-4
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+28
-8
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+39
-3
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+7
-3
vllm/model_executor/layers/quantization/gptq_bitblas.py
vllm/model_executor/layers/quantization/gptq_bitblas.py
+438
-0
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+23
-14
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
...r/layers/quantization/kernels/mixed_precision/__init__.py
+4
-1
vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
...or/layers/quantization/kernels/mixed_precision/bitblas.py
+299
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
...or/layers/quantization/kernels/mixed_precision/machete.py
+2
-5
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
081057de
...
@@ -15,7 +15,7 @@ def cutlass_moe_fp8(
...
@@ -15,7 +15,7 @@ def cutlass_moe_fp8(
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
_
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
...
@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
...
@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
out_dtype
:
torch
.
dtype
=
torch
.
half
,
out_dtype
:
torch
.
dtype
=
torch
.
half
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
...
@@ -57,12 +58,19 @@ def cutlass_moe_fp8(
...
@@ -57,12 +58,19 @@ def cutlass_moe_fp8(
quantize the intermediate result between the gemms.
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type.
- out_dtype (torch.Tensor): The output tensor type.
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
every Rank is responsible for a subset of experts. expert_map is a
mapping from global expert-id to local expert-id. When expert_map[i]
is -1, it means that this Rank is not responsible for global
expert-id i.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is 1.
Returns:
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
"""
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
topk_weights
.
shape
==
topk_ids
_
.
shape
,
"topk shape mismatch"
assert
w1_q
.
dtype
==
torch
.
float8_e4m3fn
assert
w1_q
.
dtype
==
torch
.
float8_e4m3fn
assert
w2_q
.
dtype
==
torch
.
float8_e4m3fn
assert
w2_q
.
dtype
==
torch
.
float8_e4m3fn
assert
a
.
shape
[
1
]
==
w1_q
.
shape
[
1
],
"Hidden size mismatch w1"
assert
a
.
shape
[
1
]
==
w1_q
.
shape
[
1
],
"Hidden size mismatch w1"
...
@@ -96,7 +104,13 @@ def cutlass_moe_fp8(
...
@@ -96,7 +104,13 @@ def cutlass_moe_fp8(
k
=
w1_q
.
size
(
1
)
k
=
w1_q
.
size
(
1
)
n
=
w2_q
.
size
(
1
)
n
=
w2_q
.
size
(
1
)
topk
=
topk_ids
.
size
(
1
)
local_topk_ids
=
topk_ids_
if
expert_map
is
not
None
:
"Translate info from expert_map to topk_ids"
local_topk_ids
=
torch
.
where
(
expert_map
[
topk_ids_
]
!=
-
1
,
expert_map
[
topk_ids_
],
-
1
)
topk
=
local_topk_ids
.
size
(
1
)
per_act_token
=
a1_scale
.
numel
()
!=
1
if
a1_scale
is
not
None
else
(
per_act_token
=
a1_scale
.
numel
()
!=
1
if
a1_scale
is
not
None
else
(
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
...
@@ -120,10 +134,23 @@ def cutlass_moe_fp8(
...
@@ -120,10 +134,23 @@ def cutlass_moe_fp8(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
a_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
a_map_initializer
=
torch
.
empty
c_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
c2_initializer
=
torch
.
empty
if
expert_map
is
not
None
:
ops
.
get_cutlass_moe_mm_data
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
# With expert_map each Rank processes only a subset of experts. As
# a result not all of a_map and c2 tensors are filled. We fill it
# zeros for correctness.
a_map_initializer
=
torch
.
zeros
c2_initializer
=
torch
.
zeros
a_map
=
a_map_initializer
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
c_map
=
torch
.
empty
((
local_topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
ops
.
get_cutlass_moe_mm_data
(
local_topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
a_map
,
c_map
,
num_experts
,
n
,
problem_sizes2
,
a_map
,
c_map
,
num_experts
,
n
,
k
)
k
)
...
@@ -131,7 +158,7 @@ def cutlass_moe_fp8(
...
@@ -131,7 +158,7 @@ def cutlass_moe_fp8(
rep_a1_scales
=
a1_scale
[
a_map
]
if
per_act_token
else
a1_scale
rep_a1_scales
=
a1_scale
[
a_map
]
if
per_act_token
else
a1_scale
c1
=
torch
.
empty
((
m
*
topk
,
n
*
2
),
device
=
device
,
dtype
=
out_dtype
)
c1
=
torch
.
empty
((
m
*
topk
,
n
*
2
),
device
=
device
,
dtype
=
out_dtype
)
c2
=
torch
.
empty
((
m
*
topk
,
k
),
device
=
device
,
dtype
=
out_dtype
)
c2
=
c2_initializer
((
m
*
topk
,
k
),
device
=
device
,
dtype
=
out_dtype
)
ops
.
cutlass_moe_mm
(
c1
,
rep_a_q
,
w1_q
,
rep_a1_scales
,
w1_scale
,
ops
.
cutlass_moe_mm
(
c1
,
rep_a_q
,
w1_q
,
rep_a1_scales
,
w1_scale
,
expert_offsets
[:
-
1
],
problem_sizes1
,
ab_strides1
,
expert_offsets
[:
-
1
],
problem_sizes1
,
ab_strides1
,
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
081057de
...
@@ -5,17 +5,16 @@ from typing import Optional
...
@@ -5,17 +5,16 @@ from typing import Optional
import
torch
import
torch
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
if
has_zp
:
if
has_zp
:
assert
num_bits
==
4
return
scalar_types
.
uint4
if
num_bits
==
4
else
scalar_types
.
uint8
return
scalar_types
.
uint4
else
:
else
:
return
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
return
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
...
@@ -27,9 +26,12 @@ def single_marlin_moe(
...
@@ -27,9 +26,12 @@ def single_marlin_moe(
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
w_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
is_k_full
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -62,7 +64,7 @@ def single_marlin_moe(
...
@@ -62,7 +64,7 @@ def single_marlin_moe(
assert
gating_output
.
shape
[
1
]
==
w
.
shape
[
0
],
"Number of experts mismatch"
assert
gating_output
.
shape
[
1
]
==
w
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w
.
is_contiguous
(),
"Expert weights must be contiguous"
assert
w
.
is_contiguous
(),
"Expert weights must be contiguous"
assert
hidden_states
.
dtype
==
torch
.
float16
assert
hidden_states
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
num_bits
in
[
4
,
8
]
assert
num_bits
in
[
4
,
8
]
M
,
K
=
hidden_states
.
shape
M
,
K
=
hidden_states
.
shape
...
@@ -83,39 +85,54 @@ def single_marlin_moe(
...
@@ -83,39 +85,54 @@ def single_marlin_moe(
block_size_m
=
config
[
'BLOCK_SIZE_M'
]
block_size_m
=
config
[
'BLOCK_SIZE_M'
]
sorted_token_ids
,
_
,
_
=
moe_align_block_size
(
topk_ids
,
block_size_m
,
E
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
max_workspace_size
=
(
N
//
64
)
*
16
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
\
workspace
=
torch
.
zeros
(
max_workspace_size
,
moe_align_block_size
(
topk_ids
,
block_size_m
,
E
,
expert_map
)
dtype
=
torch
.
int
,
device
=
hidden_states
.
device
,
if
workspace
is
None
:
requires_grad
=
False
)
max_workspace_size
=
(
max
(
2
*
N
,
K
)
//
64
)
*
\
(
sorted_token_ids
.
size
(
0
)
//
block_size_m
)
has_zero_point
=
w_zeros
is
not
None
device
=
hidden_states
.
device
if
w_zeros
is
None
:
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
w_zeros
=
torch
.
empty
((
0
,
0
),
max_workspace_size
=
min
(
max_workspace_size
,
sms
)
dtype
=
hidden_states
.
dtype
,
workspace
=
torch
.
zeros
(
max_workspace_size
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
int
,
requires_grad
=
False
)
device
=
device
,
requires_grad
=
False
)
if
g_idx
is
None
:
g_idx
=
torch
.
empty
((
0
,
0
),
scalar_type
=
get_scalar_type
(
num_bits
,
w_zeros
is
not
None
)
dtype
=
torch
.
int32
,
intermediate_cache
=
torch
.
empty
(
device
=
hidden_states
.
device
,
(
M
*
topk_ids
.
shape
[
1
],
N
),
requires_grad
=
False
)
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
if
sort_indices
is
None
:
)
sort_indices
=
torch
.
empty
((
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
scalar_type
=
get_scalar_type
(
num_bits
,
has_zero_point
)
intermediate_cache
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
ops
.
moe_wna16_marlin_gemm
(
hidden_states
,
hidden_states
,
w
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales
,
intermediate_cache
,
w_zeros
,
g_idx
,
sort_indices
,
workspace
,
scalar_type
.
id
,
M
,
N
,
K
,
w
,
is_k_full
,
E
,
topk
,
block_size_m
,
True
,
False
)
scales
,
w_zeros
,
g_idx
,
sort_indices
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
topk_weights
,
moe_block_size
=
block_size_m
,
top_k
=
topk
,
mul_topk_weights
=
False
,
is_ep
=
expert_map
is
not
None
,
b_q_type
=
scalar_type
,
size_m
=
M
,
size_n
=
N
,
size_k
=
K
,
is_k_full
=
is_k_full
,
use_atomic_add
=
False
,
use_fp32_reduce
=
True
,
is_zp_float
=
False
)
intermediate_cache
=
intermediate_cache
.
view
(
-
1
,
topk
,
N
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
...
@@ -127,9 +144,12 @@ def single_marlin_moe_fake(
...
@@ -127,9 +144,12 @@ def single_marlin_moe_fake(
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
w_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
is_k_full
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -144,24 +164,26 @@ direct_register_custom_op(
...
@@ -144,24 +164,26 @@ direct_register_custom_op(
)
)
def
fused_marlin_moe
(
def
fused_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
is_k_full
:
bool
=
True
,
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
"""
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
weights, w1 and w2, and top-k gating mechanism.
...
@@ -196,27 +218,12 @@ def fused_marlin_moe(
...
@@ -196,27 +218,12 @@ def fused_marlin_moe(
1
]
==
w1
.
shape
[
1
]
*
16
,
"Hidden size mismatch w1"
1
]
==
w1
.
shape
[
1
]
*
16
,
"Hidden size mismatch w1"
assert
hidden_states
.
shape
[
1
]
==
w2
.
shape
[
2
]
//
(
assert
hidden_states
.
shape
[
1
]
==
w2
.
shape
[
2
]
//
(
num_bits
//
2
),
"Hidden size mismatch w2"
num_bits
//
2
),
"Hidden size mismatch w2"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
==
torch
.
float16
assert
hidden_states
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
num_bits
in
[
4
,
8
]
assert
num_bits
in
[
4
,
8
]
has_no_act_order
=
(
g_idx1
is
None
and
g_idx2
is
None
and
sort_indices1
is
None
and
sort_indices2
is
None
)
has_all_act_order
=
(
g_idx1
is
not
None
and
g_idx2
is
not
None
and
sort_indices1
is
not
None
and
sort_indices2
is
not
None
)
assert
has_no_act_order
or
has_all_act_order
,
(
"g_idx and sorted_indices "
"must be all not None or must be all None"
)
has_no_zp
=
w1_zeros
is
None
and
w2_zeros
is
None
has_all_zp
=
w1_zeros
is
not
None
and
w2_zeros
is
not
None
assert
has_no_zp
or
has_all_zp
,
(
"zero points must be both not None or "
"must be both None"
)
M
,
K
=
hidden_states
.
shape
M
,
K
=
hidden_states
.
shape
E
=
w1
.
shape
[
0
]
E
=
w1
.
shape
[
0
]
N
=
w2
.
shape
[
1
]
*
16
N
=
w2
.
shape
[
1
]
*
16
...
@@ -234,122 +241,128 @@ def fused_marlin_moe(
...
@@ -234,122 +241,128 @@ def fused_marlin_moe(
block_size_m
=
config
[
"BLOCK_SIZE_M"
]
block_size_m
=
config
[
"BLOCK_SIZE_M"
]
sorted_token_ids
,
_
,
_
=
moe_align_block_size
(
topk_ids
,
block_size_m
,
E
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
max_workspace_size
=
(
max
(
2
*
N
,
K
)
//
64
)
*
16
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
\
workspace
=
torch
.
zeros
(
max_workspace_size
,
moe_align_block_size
(
topk_ids
,
block_size_m
,
global_num_experts
,
dtype
=
torch
.
int
,
expert_map
)
device
=
current_platform
.
device_type
,
requires_grad
=
False
)
if
workspace
is
None
:
max_workspace_size
=
(
max
(
2
*
N
,
K
)
//
64
)
*
\
if
has_no_zp
:
(
sorted_token_ids
.
size
(
0
)
//
block_size_m
)
w1_zeros
=
torch
.
empty
((
0
,
0
),
device
=
hidden_states
.
device
dtype
=
hidden_states
.
dtype
,
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
device
=
hidden_states
.
device
,
max_workspace_size
=
min
(
max_workspace_size
,
sms
*
4
)
requires_grad
=
False
)
workspace
=
torch
.
zeros
(
max_workspace_size
,
w2_zeros
=
torch
.
empty
((
0
,
0
),
dtype
=
torch
.
int
,
dtype
=
hidden_states
.
dtype
,
device
=
device
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
requires_grad
=
False
)
scalar_type1
=
get_scalar_type
(
num_bits
,
w1_zeros
is
not
None
)
if
has_no_act_order
:
scalar_type2
=
get_scalar_type
(
num_bits
,
w2_zeros
is
not
None
)
g_idx1
=
torch
.
empty
((
0
,
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
g_idx2
=
torch
.
empty
((
0
,
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
sort_indices1
=
torch
.
empty
((
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
sort_indices2
=
torch
.
empty
((
0
,
0
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
requires_grad
=
False
)
scalar_type1
=
get_scalar_type
(
num_bits
,
has_all_zp
)
scalar_type2
=
get_scalar_type
(
num_bits
,
has_all_zp
)
intermediate_cache2
=
torch
.
empty
(
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
],
N
),
(
M
*
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
dtype
=
hidden_states
.
dtype
,
)
)
intermediate_cache13
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
]
*
max
(
2
*
N
,
K
),
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
2
*
N
]
intermediate_cache1
=
intermediate_cache1
.
view
(
-
1
,
2
*
N
)
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
K
]
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
use_atomic_add
=
hidden_states
.
dtype
==
torch
.
half
or
\
torch
.
cuda
.
get_device_capability
(
hidden_states
.
device
)[
0
]
>=
9
intermediate_cache1
=
torch
.
ops
.
_
moe_
C
.
marlin_gemm
_moe
(
intermediate_cache1
=
ops
.
moe_
wna16_
marlin_gemm
(
hidden_states
,
hidden_states
,
intermediate_cache1
,
w1
,
w1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
w1_scale
,
w1_scale
,
w1_zeros
,
w1_zeros
,
g_idx1
,
g_idx1
,
sort_indices1
,
sort_indices1
,
workspace
,
workspace
,
scalar_type1
.
id
,
sorted_token_ids
,
M
,
expert_ids
,
2
*
N
,
num_tokens_post_padded
,
K
,
topk_weights
,
is_k_full
,
moe_block_size
=
block_size_m
,
E
,
top_k
=
topk
,
topk
,
mul_topk_weights
=
False
,
block_size_m
,
is_ep
=
expert_map
is
not
None
,
True
,
b_q_type
=
scalar_type1
,
False
,
size_m
=
M
,
)
size_n
=
2
*
N
,
size_k
=
K
,
is_k_full
=
is_k_full
,
use_atomic_add
=
use_atomic_add
,
use_fp32_reduce
=
True
,
is_zp_float
=
False
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
intermediate_cache3
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
if
expert_map
is
not
None
:
intermediate_cache3
.
zero_
()
intermediate_cache3
=
ops
.
moe_wna16_marlin_gemm
(
intermediate_cache2
,
intermediate_cache2
,
intermediate_cache3
,
w2
,
w2
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
w2_scale
,
w2_scale
,
w2_zeros
,
w2_zeros
,
g_idx2
,
g_idx2
,
sort_indices2
,
sort_indices2
,
workspace
,
workspace
,
scalar_type2
.
id
,
sorted_token_ids
,
M
,
expert_ids
,
K
,
num_tokens_post_padded
,
N
,
topk_weights
,
is_k_full
,
moe_block_size
=
block_size_m
,
E
,
top_k
=
1
,
topk
,
mul_topk_weights
=
True
,
block_size_m
,
is_ep
=
expert_map
is
not
None
,
False
,
b_q_type
=
scalar_type2
,
True
,
size_m
=
M
*
topk
,
)
size_n
=
K
,
size_k
=
N
,
is_k_full
=
is_k_full
,
use_atomic_add
=
use_atomic_add
,
use_fp32_reduce
=
True
,
is_zp_float
=
False
).
view
(
-
1
,
topk
,
K
)
output
=
hidden_states
if
inplace
else
torch
.
empty_like
(
hidden_states
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
)
dim
=
1
,
out
=
output
)
def
fused_marlin_moe_fake
(
hidden_states
:
torch
.
Tensor
,
def
fused_marlin_moe_fake
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
global_num_experts
:
int
=
-
1
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx1
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
g_idx2
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices1
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
sort_indices2
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
w1_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
is_k_full
:
bool
=
True
,
w2_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
workspace
:
Optional
[
torch
.
Tensor
]
=
None
,
num_bits
:
int
=
8
,
is_k_full
:
bool
=
True
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
081057de
...
@@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
...
@@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
,
from
.rocm_aiter_fused_moe
import
is_rocm_aiter_moe_enabled
rocm_aiter_fused_experts
,
rocm_aiter_topk_softmax
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -792,6 +790,18 @@ def get_default_config(
...
@@ -792,6 +790,18 @@ def get_default_config(
config
=
{
"BLOCK_SIZE_M"
:
32
,
"GROUP_SIZE_M"
:
1
}
config
=
{
"BLOCK_SIZE_M"
:
32
,
"GROUP_SIZE_M"
:
1
}
else
:
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
"GROUP_SIZE_M"
:
1
}
config
=
{
"BLOCK_SIZE_M"
:
64
,
"GROUP_SIZE_M"
:
1
}
elif
is_marlin
:
for
block_size_m
in
[
8
,
16
,
32
,
48
,
64
]:
if
M
*
topk
/
E
/
block_size_m
<
0.9
:
break
return
{
"BLOCK_SIZE_M"
:
block_size_m
}
elif
M
<=
E
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
}
else
:
else
:
config
=
{
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
...
@@ -799,14 +809,7 @@ def get_default_config(
...
@@ -799,14 +809,7 @@ def get_default_config(
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
8
,
}
}
# A heuristic: fused marlin works faster with this config for small M
if
M
<=
E
or
(
is_marlin
and
M
<=
32
):
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
}
if
use_nn_moe
:
if
use_nn_moe
:
config
[
"num_ldmatrixes"
]
=
1
config
[
"num_ldmatrixes"
]
=
1
return
config
return
config
...
@@ -867,6 +870,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
...
@@ -867,6 +870,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
def
dispatch_topk_func
()
->
Callable
[...,
tuple
[
torch
.
Tensor
,
...]]:
def
dispatch_topk_func
()
->
Callable
[...,
tuple
[
torch
.
Tensor
,
...]]:
if
is_rocm_aiter_moe_enabled
():
if
is_rocm_aiter_moe_enabled
():
from
.rocm_aiter_fused_moe
import
rocm_aiter_topk_softmax
return
rocm_aiter_topk_softmax
return
rocm_aiter_topk_softmax
return
vllm_topk_softmax
return
vllm_topk_softmax
...
@@ -1127,6 +1131,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
...
@@ -1127,6 +1131,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def
dispatch_fused_experts_func
(
inplace
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
def
dispatch_fused_experts_func
(
inplace
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
if
is_rocm_aiter_moe_enabled
():
if
is_rocm_aiter_moe_enabled
():
from
.rocm_aiter_fused_moe
import
rocm_aiter_fused_experts
return
rocm_aiter_fused_experts
return
rocm_aiter_fused_experts
if
inplace
:
if
inplace
:
return
torch_vllm_inplace_fused_experts
return
torch_vllm_inplace_fused_experts
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
081057de
...
@@ -128,12 +128,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -128,12 +128,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
super
().
process_weights_after_loading
(
layer
)
super
().
process_weights_after_loading
(
layer
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
self
.
_maybe_pad_weight
(
# Padding the weight for better performance on ROCm
layer
.
w13_weight
.
data
),
layer
.
w13_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w13_weight
.
data
)
requires_grad
=
False
)
layer
.
w2_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w2_weight
.
data
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
self
.
_maybe_pad_weight
(
layer
.
w2_weight
.
data
),
requires_grad
=
False
)
# Lazy import to avoid importing triton.
# Lazy import to avoid importing triton.
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
...
@@ -142,10 +139,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -142,10 +139,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
layer
.
w13_weight
.
data
=
shuffled_w13
requires_grad
=
False
)
layer
.
w2_weight
.
data
=
shuffled_w2
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
if
current_platform
.
is_cpu
():
if
current_platform
.
is_cpu
():
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
X86
:
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
X86
:
...
@@ -443,6 +438,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -443,6 +438,7 @@ class FusedMoE(torch.nn.Module):
if
params_dtype
is
None
:
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
# Note: here we guard against accessing the TP and DP groups when
# Note: here we guard against accessing the TP and DP groups when
# uninitialized (this happens when testing)
# uninitialized (this happens when testing)
...
@@ -493,6 +489,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -493,6 +489,7 @@ class FusedMoE(torch.nn.Module):
self
.
global_num_experts
=
num_experts
self
.
global_num_experts
=
num_experts
assert
intermediate_size
%
self
.
tp_size
==
0
assert
intermediate_size
%
self
.
tp_size
==
0
self
.
hidden_size
=
hidden_size
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
intermediate_size_per_partition
=
intermediate_size
//
self
.
tp_size
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
self
.
renormalize
=
renormalize
self
.
renormalize
=
renormalize
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
from
functools
import
cache
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm
import
envs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
@
cache
def
is_rocm_aiter_moe_enabled
()
->
bool
:
def
is_rocm_aiter_moe_enabled
()
->
bool
:
return
current_platform
.
is_rocm
()
\
return
current_platform
.
is_rocm
()
\
and
envs
.
VLLM_ROCM_USE_AITER_MOE
\
and
envs
.
VLLM_ROCM_USE_AITER_MOE
\
and
envs
.
VLLM_ROCM_USE_AITER
\
and
envs
.
VLLM_ROCM_USE_AITER
def
is_rocm_aiter_block_scaled_moe_enabled
()
->
bool
:
def
rocm_aiter_asm_moe_tkw1_impl
(
return
is_rocm_aiter_moe_enabled
()
and
\
envs
.
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
def
rocm_aiter_fused_experts
(
*
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
fc1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc1_smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc2_smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a16
:
bool
=
False
,
per_tensor_quant_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
activation_str
:
str
=
"silu"
)
->
torch
.
Tensor
:
from
aiter
import
ActivationType
from
aiter.fused_moe_bf16_asm
import
asm_moe_tkw1
activation
=
\
ActivationType
.
Gelu
if
activation_str
==
"gelu"
else
ActivationType
.
Silu
return
asm_moe_tkw1
(
hidden_states
,
w1
,
w2
,
topk_weight
,
topk_ids
,
fc1_scale
=
fc1_scale
,
fc2_scale
=
fc2_scale
,
fc1_smooth_scale
=
fc1_smooth_scale
,
fc2_smooth_scale
=
fc2_smooth_scale
,
a16
=
a16
,
per_tensor_quant_scale
=
per_tensor_quant_scale
,
expert_mask
=
expert_mask
,
activation
=
activation
)
def
rocm_aiter_asm_moe_tkw1_fake
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
fc1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
fc1_smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
fc2_smooth_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a16
:
bool
=
False
,
per_tensor_quant_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwagrs
# Ignore additional keyword arguments
activation_str
:
str
=
"silu"
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor,
                           w2: torch.Tensor, topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor) -> torch.Tensor:
    """Forward the MoE inputs to AITER's ``ck_moe`` and return its result.

    The ``aiter`` import is deferred to call time, so it is only resolved
    when this function actually runs.

    Args:
        hidden_states: Passed through as ``hidden_states``.
        w1: Passed through as ``w1``.
        w2: Passed through as ``w2``.
        topk_weights: Passed through as ``topk_weights``.
        topk_ids: Passed through as ``topk_ids``.

    Returns:
        Whatever ``aiter.ck_moe`` returns for these keyword arguments.
    """
    import aiter

    # All arguments are forwarded by keyword, unchanged.
    moe_kwargs = {
        "hidden_states": hidden_states,
        "w1": w1,
        "w2": w2,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
    }
    return aiter.ck_moe(**moe_kwargs)
import
aiter
as
rocm_aiter
def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
                           w2: torch.Tensor, topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor) -> torch.Tensor:
    """Fake (shape-only) counterpart of ``rocm_aiter_ck_moe_impl``.

    No computation is performed: the result is an uninitialised tensor with
    the same shape, dtype and device as ``hidden_states``. The remaining
    arguments exist only so the signature mirrors the real implementation.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        hidden_states_dtype: torch.dtype,
        expert_mask: Optional[torch.Tensor],
        a1: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a1_scale: torch.Tensor,
        block_shape: List[int],
        smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run AITER's FP8 block-scaled fused MoE kernel.

    The call is done in two steps: ``moe_sorting_ck`` produces the sorted
    token/expert buffers together with the output buffer, then
    ``fmoe_fp8_blockscale_g1u1`` is invoked with that output buffer as its
    first argument, and the buffer is returned.

    Args:
        topk_ids: Selected expert ids; ``shape[1]`` is used as ``topk``.
        topk_weights: Routing weights, forwarded to ``moe_sorting_ck``.
        hidden_states_dtype: Dtype forwarded to ``moe_sorting_ck``.
        expert_mask: Optional mask. When given, its element count is used as
            the expert count passed to ``moe_sorting_ck``; otherwise
            ``w1.shape[0]`` is used. It is annotated ``Optional`` because the
            body explicitly handles ``None``.
            NOTE(review): if the custom-op schema is inferred from these
            annotations, a plain ``torch.Tensor`` annotation would reject the
            ``None`` the caller can pass - confirm against
            ``direct_register_custom_op``.
        a1: Quantised activations, forwarded to the kernel.
        w1: First weight tensor; ``shape[0]`` and ``shape[-1]`` are read.
        w2: Second weight tensor.
        w1_scale: Scales for ``w1``; viewed as ``(w1.shape[0], -1)``.
        w2_scale: Scales for ``w2``; viewed as ``(w1.shape[0], -1)``.
        a1_scale: Scales for ``a1``; transposed and made contiguous.
        block_shape: Unpacked positionally into the kernel call.
        smooth_scale: Optional tensor forwarded as the last kernel argument.

    Returns:
        The output buffer returned by ``moe_sorting_ck`` after the kernel
        call.
    """
    from aiter import fmoe_fp8_blockscale_g1u1
    from aiter.fused_moe_bf16_asm import moe_sorting_ck

    topk = topk_ids.shape[1]
    model_dim = w1.shape[-1]
    local_E = w1.shape[0]
    # Without a mask the expert count is the first dimension of w1; with a
    # mask it is the number of mask entries.
    E = local_E if expert_mask is None else expert_mask.numel()

    (
        sorted_token_ids,
        sorted_weight_buf,
        sorted_expert_ids,
        num_valid_ids,
        out_asm,
    ) = moe_sorting_ck(topk_ids,
                       topk_weights,
                       E,
                       model_dim,
                       hidden_states_dtype,
                       expert_mask=expert_mask)

    fmoe_fp8_blockscale_g1u1(
        out_asm,
        a1,
        w1,
        w2,
        sorted_token_ids,
        sorted_weight_buf,
        sorted_expert_ids,
        num_valid_ids,
        topk,
        # Weight scales are always reshaped with the w1-derived expert
        # count, never the mask-derived one.
        w1_scale.view(local_E, -1),
        w2_scale.view(local_E, -1),
        a1_scale.t().contiguous(),
        *block_shape,
        smooth_scale,
    )

    return out_asm
def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        hidden_states_dtype: torch.dtype,
        expert_mask: Optional[torch.Tensor],
        a1: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a1_scale: torch.Tensor,
        block_shape: List[int],
        smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Fake (shape-only) counterpart of the FP8 block-scaled MoE op.

    No kernel is run. The result is an uninitialised tensor shaped like
    ``a1`` with dtype ``torch.bfloat16``; all other arguments are accepted
    only so the signature mirrors the real implementation.

    Returns:
        ``torch.empty_like(a1, dtype=torch.bfloat16)``.
    """
    # `torch.bf16` does not exist and raised AttributeError; the bfloat16
    # dtype is spelled `torch.bfloat16`.
    # NOTE(review): the real implementation forwards `hidden_states_dtype`
    # when allocating its output - consider returning that dtype here if
    # non-bfloat16 hidden states are expected; confirm against AITER.
    return torch.empty_like(a1, dtype=torch.bfloat16)
def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
                            w1: torch.Tensor,
                            w2: torch.Tensor,
                            topk_weight: torch.Tensor,
                            topk_ids: torch.Tensor,
                            fc1_scale: Optional[torch.Tensor] = None,
                            fc2_scale: Optional[torch.Tensor] = None,
                            fc1_smooth_scale: Optional[torch.Tensor] = None,
                            fc2_smooth_scale: Optional[torch.Tensor] = None,
                            a16: bool = False,
                            activation: str = "silu") -> torch.Tensor:
    """Call AITER's ``asm_moe`` with the activation name mapped to its enum.

    Args:
        hidden_states: Forwarded as ``hidden_states``.
        w1: Forwarded as ``w1``.
        w2: Forwarded as ``w2``.
        topk_weight: Forwarded as ``topk_weight``.
        topk_ids: Forwarded as ``topk_ids``.
        fc1_scale: Forwarded as ``fc1_scale``.
        fc2_scale: Forwarded as ``fc2_scale``.
        fc1_smooth_scale: Forwarded as ``fc1_smooth_scale``.
        fc2_smooth_scale: Forwarded as ``fc2_smooth_scale``.
        a16: Forwarded as ``a16``.
        activation: ``"silu"`` or ``"gelu"``; translated to
            ``ActivationType.Silu`` / ``ActivationType.Gelu``.

    Returns:
        The value returned by ``aiter.fused_moe_bf16_asm.asm_moe``.

    Raises:
        AssertionError: If ``activation`` is neither ``"silu"`` nor
            ``"gelu"``.
    """
    # Validate before the lazy import so a bad argument is reported even
    # when AITER cannot be imported. The message now has a space between
    # the label and the offending value.
    assert activation in ("silu", "gelu"), (
        f"The given activation: {activation} is not supported in AITER.")

    import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
    from aiter import ActivationType

    aiter_activation = (ActivationType.Silu
                        if activation == "silu" else ActivationType.Gelu)

    return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                       w1=w1,
                                       w2=w2,
                                       topk_weight=topk_weight,
                                       topk_ids=topk_ids,
                                       fc1_scale=fc1_scale,
                                       fc2_scale=fc2_scale,
                                       fc1_smooth_scale=fc1_smooth_scale,
                                       fc2_smooth_scale=fc2_smooth_scale,
                                       a16=a16,
                                       activation=aiter_activation)
def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
                            w1: torch.Tensor,
                            w2: torch.Tensor,
                            topk_weight: torch.Tensor,
                            topk_ids: torch.Tensor,
                            fc1_scale: Optional[torch.Tensor] = None,
                            fc2_scale: Optional[torch.Tensor] = None,
                            fc1_smooth_scale: Optional[torch.Tensor] = None,
                            fc2_smooth_scale: Optional[torch.Tensor] = None,
                            a16: bool = False,
                            activation: str = "silu") -> torch.Tensor:
    """Fake (shape-only) counterpart of ``rocm_aiter_asm_moe_impl``.

    Nothing is computed and ``activation`` is not validated here. The result
    is an uninitialised tensor with the shape, dtype and device of
    ``hidden_states``; every other argument is ignored.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
                                 topk_indices: torch.Tensor,
                                 token_expert_indices: torch.Tensor,
                                 gating_output: torch.Tensor,
                                 renormalize: bool) -> None:
    """Invoke AITER's ``topk_softmax`` with the given arguments.

    The five arguments are passed positionally in declaration order. Nothing
    is returned; any result is expected to be written by AITER into the
    tensors it receives (presumably the first three - confirm against the
    AITER documentation).
    """
    import aiter

    aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices,
                       gating_output, renormalize)
def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
                                 topk_indices: torch.Tensor,
                                 token_expert_indices: torch.Tensor,
                                 gating_output: torch.Tensor,
                                 renormalize: bool) -> None:
    """Fake counterpart of ``rocm_aiter_topk_softmax_impl``.

    Deliberately a no-op: it touches none of its arguments and returns
    ``None``, matching the real implementation's return type.
    """
    return None
# Register the AITER-backed MoE helpers as custom ops, but only when the
# current platform reports ROCm. Each entry is
# (op name, real implementation, fake implementation, mutated argument names);
# the ops are registered in the order listed.
if current_platform.is_rocm():
    _aiter_op_table = (
        ("rocm_aiter_asm_moe_tkw1", rocm_aiter_asm_moe_tkw1_impl,
         rocm_aiter_asm_moe_tkw1_fake, []),
        ("rocm_aiter_ck_moe", rocm_aiter_ck_moe_impl, rocm_aiter_ck_moe_fake,
         []),
        ("rocm_aiter_fmoe_fp8_blockscale_g1u1",
         rocm_aiter_fmoe_fp8_blockscale_g1u1_impl,
         rocm_aiter_fmoe_fp8_blockscale_g1u1_fake, []),
        ("rocm_aiter_asm_moe", rocm_aiter_asm_moe_impl,
         rocm_aiter_asm_moe_fake, []),
        # Only the top-k softmax op declares mutated arguments.
        ("rocm_aiter_topk_softmax", rocm_aiter_topk_softmax_impl,
         rocm_aiter_topk_softmax_fake,
         ["topk_weights", "topk_indices", "token_expert_indices"]),
    )

    for _name, _impl, _fake, _mutated in _aiter_op_table:
        direct_register_custom_op(
            op_name=_name,
            op_func=_impl,
            mutates_args=_mutated,
            fake_impl=_fake,
            dispatch_key=current_platform.dispatch_key,
        )

    # Keep the module namespace free of the registration scaffolding.
    del _aiter_op_table, _name, _impl, _fake, _mutated
def
rocm_aiter_fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
per_token_group_quant_fp8
)
if
envs
.
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
and
use_fp8_w8a8
:
# All AITER Fused MoE kernels are expecting the following datatypes
topk_weights
=
topk_weights
.
to
(
torch
.
float32
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
# w8a8 block-scaled
if
block_shape
is
not
None
and
use_fp8_w8a8
:
assert
not
apply_router_weight_on_input
,
(
"apply_router_weight_on_input is not supported for block scaled moe"
)
assert
w1_scale
is
not
None
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
assert
w2_scale
is
not
None
local_E
=
E
=
w1
.
shape
[
0
]
if
expert_mask
is
not
None
:
E
=
expert_mask
.
numel
()
topk
=
topk_ids
.
shape
[
1
]
model_dim
=
w1
.
shape
[
-
1
]
dtype
=
hidden_states
.
dtype
# The default block sizes are 128 in AITER.
# The default block sizes are 128 in AITER.
if
block_shape
is
None
:
block_shape
=
[
128
,
128
]
if
block_shape
is
None
else
block_shape
block_shape
=
[
128
,
128
]
a1
,
a1_scale
=
per_token_group_quant_fp8
(
hidden_states
,
block_shape
[
1
])
scale_blk_k
=
block_shape
[
1
]
return
torch
.
ops
.
vllm
.
rocm_aiter_fmoe_fp8_blockscale_g1u1
(
(
topk_ids
,
topk_weights
,
hidden_states
.
dtype
,
expert_map
,
a1
,
w1
,
sorted_token_ids
,
w2
,
w1_scale
,
w2_scale
,
a1_scale
,
block_shape
,
None
)
sorted_weight_buf
,
sorted_expert_ids
,
# w8a8 per-channel quantization
num_valid_ids
,
elif
per_channel_quant
and
apply_router_weight_on_input
and
use_fp8_w8a8
:
out_asm
,
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
)
=
rocm_aiter_asm_fmoe
.
moe_sorting_ck
(
topk_ids
,
# This applies topk_weights on the GEMM output of the first FC layer
topk_weights
,
# rather than the second FC.
E
,
assert
(
topk_weights
.
dim
()
==
2
model_dim
,
),
"`topk_weights` should be in shape (num_tokens, topk)"
dtype
,
assert
topk_weights
.
shape
[
-
1
]
==
1
,
(
expert_mask
=
expert_mask
)
"Only support topk=1 when"
" `apply_router_weight_on_input` is True"
)
a1
,
a1_scale
=
per_token_group_quant_fp8
(
hidden_states
,
scale_blk_k
)
rocm_aiter
.
fmoe_fp8_blockscale_g1u1
(
return
torch
.
ops
.
vllm
.
rocm_aiter_asm_moe_tkw1
(
out_asm
,
hidden_states
,
a1
,
w1
,
w1
,
w2
,
w2
,
sorted_token_ids
,
topk_weights
,
sorted_weight_buf
,
topk_ids
,
sorted_expert_ids
,
fc1_scale
=
w1_scale
,
num_valid_ids
,
fc2_scale
=
w2_scale
,
topk
,
fc1_smooth_scale
=
None
,
w1_scale
.
view
(
local_E
,
-
1
),
fc2_smooth_scale
=
None
,
w2_scale
.
view
(
local_E
,
-
1
),
a16
=
False
,
a1_scale
.
t
().
contiguous
(),
per_tensor_quant_scale
=
None
,
block_shape
[
0
],
expert_mask
=
expert_map
,
block_shape
[
1
],
activation_str
=
activation
)
None
,
)
# w8a8 per-tensor activation per-tensor weight
return
out_asm
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
:
return
rocm_aiter_asm_fmoe
.
asm_moe
(
hidden_states
=
hidden_states
,
assert
not
apply_router_weight_on_input
,
(
w1
=
w1
,
"apply_router_weight_on_input is not supported for fp8_w8a8"
)
w2
=
w2
,
return
torch
.
ops
.
vllm
.
rocm_aiter_asm_moe
(
hidden_states
=
hidden_states
,
topk_weight
=
topk_weights
,
w1
=
w1
,
topk_ids
=
topk_ids
,
w2
=
w2
,
fc1_scale
=
w1_scale
,
topk_weight
=
topk_weights
,
fc2_scale
=
w2_scale
,
topk_ids
=
topk_ids
,
fc1_smooth_scale
=
None
,
fc1_scale
=
w1_scale
,
fc2_smooth_scale
=
None
,
fc2_scale
=
w2_scale
,
a16
=
False
)
fc1_smooth_scale
=
None
,
fc2_smooth_scale
=
None
,
return
rocm_aiter
.
ck_moe
(
hidden_states
=
hidden_states
,
a16
=
False
,
w1
=
w1
,
activation
=
activation
)
w2
=
w2
,
if
apply_router_weight_on_input
:
topk_weights
=
topk_weights
,
assert
(
topk_weights
.
dim
()
==
2
topk_ids
=
topk_ids
)
),
"`topk_weights` should be in shape (num_tokens, topk)"
_
,
topk
=
topk_weights
.
shape
assert
(
topk
==
1
),
"Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states
=
hidden_states
*
topk_weights
.
to
(
hidden_states
.
dtype
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
topk_weights
=
torch
.
ones_like
(
topk_weights
,
dtype
=
torch
.
float32
)
# w16a16 fallback to rocm_aiter_ck_moe w16a16
return
torch
.
ops
.
vllm
.
rocm_aiter_ck_moe
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
)
def
rocm_aiter_topk_softmax
(
topk_weights
:
torch
.
Tensor
,
def
rocm_aiter_topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
renormalize
:
bool
)
->
tuple
[
torch
.
Tensor
,
...]:
renormalize
:
bool
)
->
Tuple
[
torch
.
Tensor
,
...]:
import
aiter
as
rocm_aiter
torch
.
ops
.
vllm
.
rocm_aiter_topk_softmax
(
topk_weights
,
topk_indices
,
rocm_aiter
.
topk_softmax
(
topk_weights
,
topk_indices
,
token_expert_indices
,
token_expert_indices
,
gating_output
,
gating_output
,
renormalize
)
renormalize
)
return
topk_weights
,
topk_indices
return
topk_weights
,
topk_indices
def
shuffle_weights
(
*
tensors
:
torch
.
Tensor
)
->
t
uple
[
torch
.
Tensor
,
...]:
def
shuffle_weights
(
*
tensors
:
torch
.
Tensor
)
->
T
uple
[
torch
.
Tensor
,
...]:
"""
"""
Applies shuffle_weight function from AITER to each
Applies shuffle_weight function from AITER to each
input tensor and returns them.
input tensor and returns them.
...
@@ -129,15 +388,14 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
...
@@ -129,15 +388,14 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
*tensors: Variable number of torch.Tensor objects.
*tensors: Variable number of torch.Tensor objects.
Returns:
Returns:
A
t
uple of shuffled tensors.
A
T
uple of shuffled tensors.
"""
"""
from
aiter.ops.shuffle
import
shuffle_weight
from
aiter.ops.shuffle
import
shuffle_weight
return
tuple
(
shuffle_weight
(
tensor
)
for
tensor
in
tensors
)
return
tuple
(
shuffle_weight
(
tensor
)
for
tensor
in
tensors
)
def
expand_weights
(
*
tensors
:
torch
.
Tensor
,
def
expand_weights
(
*
tensors
:
torch
.
Tensor
,
expansion_dims
:
list
[
int
])
->
t
uple
[
torch
.
Tensor
,
...]:
expansion_dims
:
list
[
int
])
->
T
uple
[
torch
.
Tensor
,
...]:
"""
"""
Expands the dimensions of input tensors.
Expands the dimensions of input tensors.
...
@@ -147,7 +405,7 @@ def expand_weights(*tensors: torch.Tensor,
...
@@ -147,7 +405,7 @@ def expand_weights(*tensors: torch.Tensor,
corresponding to each tensor.
corresponding to each tensor.
Returns:
Returns:
A
t
uple of tensors with expanded dimensions.
A
T
uple of tensors with expanded dimensions.
"""
"""
assert
len
(
tensors
)
==
len
(
expansion_dims
),
\
assert
len
(
tensors
)
==
len
(
expansion_dims
),
\
...
...
vllm/model_executor/layers/layernorm.py
View file @
081057de
...
@@ -168,7 +168,8 @@ class RMSNorm(CustomOp):
...
@@ -168,7 +168,8 @@ class RMSNorm(CustomOp):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
from
vllm_hpu_extension.ops
import
HPUFusedRMSNorm
from
vllm_hpu_extension.kernels
import
rms_norm
HPUFusedRMSNorm
=
rms_norm
()
if
HPUFusedRMSNorm
is
None
:
if
HPUFusedRMSNorm
is
None
:
return
self
.
forward_native
(
x
,
residual
)
return
self
.
forward_native
(
x
,
residual
)
if
residual
is
not
None
:
if
residual
is
not
None
:
...
...
vllm/model_executor/layers/linear.py
View file @
081057de
...
@@ -6,7 +6,6 @@ from typing import Any, Literal, Optional, Union
...
@@ -6,7 +6,6 @@ from typing import Any, Literal, Optional, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
...
@@ -17,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -17,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.utils
import
dispatch_unquantized_gemm
# yapf: disable
# yapf: disable
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
BlockQuantScaleParameter
,
BlockQuantScaleParameter
,
...
@@ -31,6 +31,8 @@ logger = init_logger(__name__)
...
@@ -31,6 +31,8 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED
=
[
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"CompressedTensorsLinearMethod"
,
"BitBLASLinearMethod"
,
"GPTQBitBLASLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"GPTQMarlinLinearMethod"
,
...
@@ -50,6 +52,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [
...
@@ -50,6 +52,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [
]
]
def
adjust_bitblas_shard
(
param
,
shard_size
,
shard_offset
):
bitblas_tile_size
=
getattr
(
param
,
"bitblas_tile_size"
,
None
)
if
bitblas_tile_size
is
not
None
:
return
(
shard_size
//
bitblas_tile_size
,
shard_offset
//
bitblas_tile_size
)
return
shard_size
,
shard_offset
def
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
):
def
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
):
marlin_tile_size
=
getattr
(
param
,
"marlin_tile_size"
,
None
)
marlin_tile_size
=
getattr
(
param
,
"marlin_tile_size"
,
None
)
if
marlin_tile_size
is
None
:
if
marlin_tile_size
is
None
:
...
@@ -188,7 +199,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -188,7 +199,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
return
dispatch_unquantized_gemm
()
(
x
,
layer
.
weight
,
bias
)
class
LinearBase
(
torch
.
nn
.
Module
):
class
LinearBase
(
torch
.
nn
.
Module
):
...
@@ -615,6 +626,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -615,6 +626,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
param
,
shard_size
,
shard_offset
)
shard_size
,
shard_offset
=
adjust_bitblas_shard
(
param
,
shard_size
,
shard_offset
)
if
use_bitsandbytes_4bit
:
if
use_bitsandbytes_4bit
:
index
=
list
(
itertools
.
accumulate
([
0
]
+
self
.
output_sizes
))
index
=
list
(
itertools
.
accumulate
([
0
]
+
self
.
output_sizes
))
orig_offsets
=
{
orig_offsets
=
{
...
@@ -646,6 +660,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -646,6 +660,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
param
,
shard_size
,
shard_offset
)
shard_size
,
shard_offset
=
adjust_bitblas_shard
(
param
,
shard_size
,
shard_offset
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
False
)
...
@@ -913,6 +929,15 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -913,6 +929,15 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset
=
self
.
_get_shard_offset_mapping
(
loaded_shard_id
)
shard_offset
=
self
.
_get_shard_offset_mapping
(
loaded_shard_id
)
shard_size
=
self
.
_get_shard_size_mapping
(
loaded_shard_id
)
shard_size
=
self
.
_get_shard_size_mapping
(
loaded_shard_id
)
# Note(simon): This is needed for Qwen3's fp8 quantization.
if
isinstance
(
param
,
BlockQuantScaleParameter
):
assert
self
.
quant_method
is
not
None
assert
hasattr
(
self
.
quant_method
,
"quant_config"
)
weight_block_size
=
self
.
quant_method
.
quant_config
.
weight_block_size
block_n
,
_
=
weight_block_size
[
0
],
weight_block_size
[
1
]
shard_offset
=
(
shard_offset
+
block_n
-
1
)
//
block_n
shard_size
=
(
shard_size
+
block_n
-
1
)
//
block_n
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
param
.
load_qkv_weight
(
loaded_weight
=
loaded_weight
,
num_heads
=
self
.
num_kv_head_replicas
,
num_heads
=
self
.
num_kv_head_replicas
,
shard_id
=
loaded_shard_id
,
shard_id
=
loaded_shard_id
,
...
...
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
081057de
...
@@ -10,8 +10,10 @@ from packaging import version
...
@@ -10,8 +10,10 @@ from packaging import version
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.triton_utils
import
HAS_TRITON
TRITON3
=
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
"3.0.0"
)
TRITON3
=
HAS_TRITON
and
(
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
"3.0.0"
))
if
TRITON3
:
if
TRITON3
:
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Type
from
typing
import
Literal
,
Type
,
get_args
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
Q
UANTIZATION_METHODS
:
List
[
str
]
=
[
Q
uantizationMethods
=
Literal
[
"aqlm"
,
"aqlm"
,
"awq"
,
"awq"
,
"deepspeedfp"
,
"deepspeedfp"
,
...
@@ -15,12 +15,12 @@ QUANTIZATION_METHODS: List[str] = [
...
@@ -15,12 +15,12 @@ QUANTIZATION_METHODS: List[str] = [
"fbgemm_fp8"
,
"fbgemm_fp8"
,
"modelopt"
,
"modelopt"
,
"nvfp4"
,
"nvfp4"
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
,
"marlin"
,
"bitblas"
,
"gguf"
,
"gguf"
,
"gptq_marlin_24"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"gptq_marlin"
,
"gptq_bitblas"
,
"awq_marlin"
,
"awq_marlin"
,
"gptq"
,
"gptq"
,
"compressed-tensors"
,
"compressed-tensors"
,
...
@@ -34,6 +34,7 @@ QUANTIZATION_METHODS: List[str] = [
...
@@ -34,6 +34,7 @@ QUANTIZATION_METHODS: List[str] = [
"moe_wna16"
,
"moe_wna16"
,
"torchao"
,
"torchao"
,
]
]
QUANTIZATION_METHODS
:
list
[
str
]
=
list
(
get_args
(
QuantizationMethods
))
# The customized quantization methods which will be added to this dict.
# The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
=
{}
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
=
{}
...
@@ -85,6 +86,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
...
@@ -85,6 +86,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from
.aqlm
import
AQLMConfig
from
.aqlm
import
AQLMConfig
from
.awq
import
AWQConfig
from
.awq
import
AWQConfig
from
.awq_marlin
import
AWQMarlinConfig
from
.awq_marlin
import
AWQMarlinConfig
from
.bitblas
import
BitBLASConfig
from
.bitsandbytes
import
BitsAndBytesConfig
from
.bitsandbytes
import
BitsAndBytesConfig
from
.compressed_tensors.compressed_tensors
import
(
# noqa: E501
from
.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsConfig
)
CompressedTensorsConfig
)
...
@@ -94,6 +96,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
...
@@ -94,6 +96,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from
.fp8
import
Fp8Config
from
.fp8
import
Fp8Config
from
.gguf
import
GGUFConfig
from
.gguf
import
GGUFConfig
from
.gptq
import
GPTQConfig
from
.gptq
import
GPTQConfig
from
.gptq_bitblas
import
GPTQBitBLASConfig
from
.gptq_marlin
import
GPTQMarlinConfig
from
.gptq_marlin
import
GPTQMarlinConfig
from
.gptq_marlin_24
import
GPTQMarlin24Config
from
.gptq_marlin_24
import
GPTQMarlin24Config
from
.hqq_marlin
import
HQQMarlinConfig
from
.hqq_marlin
import
HQQMarlinConfig
...
@@ -107,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
...
@@ -107,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from
.torchao
import
TorchAOConfig
from
.torchao
import
TorchAOConfig
from
.tpu_int8
import
Int8TpuConfig
from
.tpu_int8
import
Int8TpuConfig
method_to_config
:
D
ict
[
str
,
Type
[
QuantizationConfig
]]
=
{
method_to_config
:
d
ict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"awq"
:
AWQConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
...
@@ -116,12 +119,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
...
@@ -116,12 +119,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"nvfp4"
:
ModelOptNvFp4Config
,
"nvfp4"
:
ModelOptNvFp4Config
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
"bitblas"
:
BitBLASConfig
,
"gguf"
:
GGUFConfig
,
"gguf"
:
GGUFConfig
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"gptq_bitblas"
:
GPTQBitBLASConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"gptq"
:
GPTQConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
...
@@ -144,6 +147,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
...
@@ -144,6 +147,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
__all__
=
[
__all__
=
[
"QuantizationConfig"
,
"QuantizationConfig"
,
"QuantizationMethods"
,
"get_quantization_config"
,
"get_quantization_config"
,
"QUANTIZATION_METHODS"
,
"QUANTIZATION_METHODS"
,
]
]
\ No newline at end of file
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
081057de
...
@@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
...
@@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq
)
is_layer_skipped_awq
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.moe_wna16
import
MoeWNA16Config
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_marlin_supported
,
check_marlin_supports_layer
,
marlin_make_empty_g_idx
,
check_marlin_supports_layer
,
check_moe_marlin_supports_layer
,
marlin_make_
workspace
,
marlin_moe_permute_scales
,
marlin_permute_scales
,
marlin_make_
empty_g_idx
,
marlin_make_workspace
,
marlin_
moe_
permute_scales
,
moe_awq_to_marlin_zero_points
,
verify_marlin_supported
,
marlin_permute_scales
,
moe_awq_to_marlin_zero_points
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
PackedvLLMParameter
)
...
@@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig):
...
@@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig):
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
return
AWQMarlinLinearMethod
(
self
)
return
AWQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
if
layer
.
local_num_experts
>
32
:
from
vllm.model_executor.layers.quantization.moe_wna16
import
(
# For MoEs with many experts the moe_wna16 kernel is faster
MoeWNA16Config
)
if
not
check_moe_marlin_supports_layer
(
layer
,
self
.
group_size
):
logger
.
warning_one
(
f
"Layer '
{
prefix
}
' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels."
)
return
MoeWNA16Config
.
from_config
(
return
MoeWNA16Config
.
from_config
(
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
else
:
return
AWQMoEMethod
(
self
)
return
AWQMoEMethod
(
self
)
return
None
return
None
@
classmethod
@
classmethod
...
@@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
layer
.
register_parameter
(
"w2_qzeros"
,
w2_qzeros
)
layer
.
register_parameter
(
"w2_qzeros"
,
w2_qzeros
)
set_weight_attrs
(
w2_qzeros
,
extra_weight_attrs
)
set_weight_attrs
(
w2_qzeros
,
extra_weight_attrs
)
device
=
layer
.
w13_qweight
.
device
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
layer
.
workspace
=
torch
.
zeros
((
sms
*
4
,
),
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
num_experts
=
layer
.
w13_qweight
.
shape
[
0
]
num_experts
=
layer
.
w13_qweight
.
shape
[
0
]
device
=
layer
.
w13_qweight
.
device
device
=
layer
.
w13_qweight
.
device
...
@@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
expert_map
is
not
None
:
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"fused Marlin MoE method."
)
if
apply_router_weight_on_input
:
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Apply router weight on input is not supported for"
"Apply router weight on input is not supported for"
...
@@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
...
@@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits
,
router_logits
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_zeros
=
layer
.
w13_qzeros
,
w1_zeros
=
layer
.
w13_qzeros
,
w2_zeros
=
layer
.
w2_qzeros
,
w2_zeros
=
layer
.
w2_qzeros
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
quant_config
.
weight_bits
,
num_bits
=
self
.
quant_config
.
weight_bits
,
)
)
vllm/model_executor/layers/quantization/bitblas.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.bitblas_utils
import
(
BITBLAS_OPTIMIZE_FEATURES
,
BITBLAS_SUPPORTED_NUM_BITS
,
BITBLAS_SUPPORTED_SYM
,
MINIMUM_BITBLAS_VERSION
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
class BitBLASConfig(QuantizationConfig):
    """Config class for BitBLAS.

    Reference: https://github.com/Microsoft/BitBLAS
    """
    TORCH_DTYPE = torch.float16
    STORAGE_DTYPE = "int8"  # assume int8 storage
    TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
    # "original" or "rescale" or "quantized",
    # gptq_with_bitblas prefer "quantized implementation"
    ZEROS_MODE = "quantized"

    def __init__(
        self,
        weight_bits: int,
        group_size: Optional[int],
        desc_act: Optional[bool],
        is_sym: Optional[bool],
        quant_method: Optional[str],
        lm_head_quantized: bool,
    ) -> None:
        """Validate the environment and the quantization settings.

        Args:
            weight_bits: Bit width of a quantized weight; must be in
                ``BITBLAS_SUPPORTED_NUM_BITS``.
            group_size: Quantization group size; ``-1`` means a single
                group per output channel.
            desc_act: Whether activation reordering is requested. Forced
                to ``False`` when ``group_size == -1``.
            is_sym: Whether quantization is symmetric; must be in
                ``BITBLAS_SUPPORTED_SYM``.
            quant_method: Name of the quantization method of the
                checkpoint (stored as-is).
            lm_head_quantized: Whether the LM head is quantized too.

        Raises:
            ValueError: If ``bitblas`` cannot be imported or is older than
                ``MINIMUM_BITBLAS_VERSION``, or if ``weight_bits`` /
                ``is_sym`` are not supported.
        """
        try:
            import bitblas

            # Compare numerically: a plain string comparison would order
            # "0.10.0" before "0.9.0".
            installed = self._version_key(bitblas.__version__)
            required = self._version_key(MINIMUM_BITBLAS_VERSION)
            if installed < required:
                raise ImportError(
                    "bitblas version is wrong. Please "
                    f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
        except ImportError as e:
            raise ValueError(
                "Trying to use the bitblas backend, but could not import "
                f"with the following error: {e}. "
                "Please install bitblas through the following command: "
                f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`") from e

        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.quant_method = quant_method
        self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"BitBLAS does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
                "are supported.")

        if self.is_sym not in BITBLAS_SUPPORTED_SYM:
            raise ValueError(
                f"BitBLAS does not support is_sym = {self.is_sym}. "
                f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")

        storage_dtype = self.STORAGE_DTYPE
        # Bit width of one storage element, parsed from the dtype name
        # (e.g. "int8" -> 8).
        storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))

        self.storage_dtype = storage_dtype
        self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
        # Number of quantized weights packed into one storage element
        # (e.g. two 4-bit weights per int8 element).
        self.pack_factor = storage_nbit // weight_bits
        self.nbits = weight_bits

        # Zeros type for the quantized weights.
        self.zeros_mode = self.ZEROS_MODE

    @staticmethod
    def _version_key(version: str, width: int = 4) -> List[int]:
        """Turn a dotted version string into a numerically comparable key.

        Only the leading numeric release components are used; parsing stops
        at the first component that is not purely numeric, keeping its
        leading digits (so "0.1.0rc1" and "0.1.0.post2" both map to the
        same key as "0.1.0"). The key is zero-padded to ``width`` entries so
        that "0.1" and "0.1.0" compare equal.
        """
        key: List[int] = []
        for piece in str(version).split("."):
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            if not digits:
                break
            key.append(int(digits))
            if len(digits) != len(piece):
                # Non-numeric suffix (e.g. "0rc1"): nothing after it is
                # part of the release number.
                break
        key.extend([0] * (width - len(key)))
        return key

    def __repr__(self) -> str:
        return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"is_sym={self.is_sym}, "
                f"quant_method={self.quant_method})")

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "bitblas"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.half, torch.bfloat16]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required (70)."""
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["quantize_config.json"]

    @staticmethod
    def get_from_keys(config: Dict[str, Any],
                      keys: List[str],
                      default: Any = None) -> Any:
        """Get a value from the model's quantization config.

        Returns the value of the first entry of ``keys`` present in
        ``config``, or ``default`` when none of them is present.
        """
        for key in keys:
            if key in config:
                return config[key]
        return default

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
        """Build a config from a parsed quantization config dictionary.

        Missing "group_size" defaults to -1, missing "desc_act" and "sym"
        default to False, and a missing "lm_head" defaults to False.
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"], -1)
        desc_act = cls.get_from_keys(config, ["desc_act"], False)
        is_sym = cls.get_from_keys(config, ["sym"], False)
        quant_method = cls.get_from_keys(config, ["quant_method"])
        # get_from_keys_or is not defined in this class; presumably it is
        # inherited from QuantizationConfig.
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
                   lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Select this method for checkpoints serialized in bitblas format.

        Returns the method name when the checkpoint is marked as bitblas
        and the user did not request an incompatible method (only ``None``,
        "gptq" and "bitblas" are compatible); otherwise returns ``None``.
        """
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_bitblas_format: bool
        is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
                             or hf_quant_cfg.get("is_bitblas_format", False))

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "bitblas")

        if is_bitblas_format and is_valid_user_quant:
            msg = ("The model is serialized in {} format. Using {} kernel.".
                   format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["BitBLASLinearMethod"]:
        """Return the linear method for ``layer``, or ``None``.

        A method is returned for ``LinearBase`` layers, and for
        ``ParallelLMHead`` layers only when the LM head is quantized.
        """
        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
                                             and self.lm_head_quantized):
            return BitBLASLinearMethod(self)
        return None
class BitBLASLinearMethod(LinearMethodBase):
    """Linear method for BitBLAS.

    Args:
        quant_config: The BitBLAS quantization config.
    """
    # USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS
    # Instead of BITBLAS_OPTIMIZE_FEATURES
    # If you want to high contiguous batching
    # performance
    OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
    ENABLE_TUNING = True
    # Maps torch dtypes to the dtype names passed to BitBLAS.
    BITBLAS_DTYPES = {
        torch.float32: "float32",
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.half: "float16",
        torch.int8: "int8",
    }

    def __init__(self, quant_config: BitBLASConfig):
        self.quant_config = quant_config

    def create_weights_gptq(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Creates quantized weights for use in linear operations.

        Configures the BitBLAS matmul operator for this layer and registers
        the parameters 'qweight', 'scales' and 'zeros' on ``layer``.
        Nothing is returned.

        Args:
            layer: The module on which the parameters are registered.
            input_size_per_partition: The size of the input partition.
            output_partition_sizes: Output sizes of the partitions; their
                sum is used as the output size per partition.
            input_size: The total size of the input (unused).
            output_size: The total size of the output (unused).
            params_dtype: The data type of the parameters; must be one of
                the dtypes returned by
                ``quant_config.get_supported_act_dtypes()``.
            **extra_weight_attrs: Must contain "weight_loader".

        Raises:
            ValueError: If `params_dtype` is not supported or if the
                input size per partition is not divisible by the group
                size in `quant_config`.
        """
        del input_size, output_size  # Unused arguments.
        weight_loader = extra_weight_attrs["weight_loader"]

        supported_dtypes = self.quant_config.get_supported_act_dtypes()
        if params_dtype not in supported_dtypes:
            raise ValueError(
                f"Parameter data type must be one of {supported_dtypes}, "
                f"but got {params_dtype}")

        group_size = self.quant_config.group_size
        if group_size is None:
            group_size = -1

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if (group_size != -1 and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must "
                f"be divisible by group size ({group_size}).")

        # Initialize or retrieve the BitBLAS matrix multiplication operator.
        self._configure_bitblas_matmul(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype=params_dtype,
            enable_tuning=self.ENABLE_TUNING,
            bias=False,
            layout="nt",
            bits=self.quant_config.weight_bits,
        )

        # Query the operator once for the packed weight shape; it is used
        # both for the allocation and for the tile size.
        weight_shape = self.bitblas_matmul.retrieve_weight_shape()

        # Initialize quantized weights with dimensions
        # Quantized 4Bit weights packed.
        qweight = PackedvLLMParameter(
            data=torch.empty(
                weight_shape,
                device="cuda",
                dtype=self.quant_config.storage_torch_dtype,
                requires_grad=False,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            bitblas_tile_size=(weight_shape[-2]
                               if self.bitblas_matmul.propagate_b else None),
            weight_loader=weight_loader,
        )

        # Compute the number of input groups for channel-wise quantization.
        input_groups = (1 if group_size == -1 else input_size_per_partition //
                        group_size)

        # Initialize scales and zeros for the quantized weights.
        weight_scale_args = {
            "data":
            torch.empty(
                output_size_per_partition,
                input_groups,
                device="cuda",
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(output_dim=0,
                                                **weight_scale_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=0,
                                              input_dim=1,
                                              **weight_scale_args)

        if self.quant_config.zeros_mode == "quantized":
            zeros = PackedvLLMParameter(
                data=torch.empty(
                    input_groups,
                    output_size_per_partition // self.quant_config.pack_factor,
                    device="cuda",
                    dtype=self.quant_config.storage_torch_dtype,
                    requires_grad=False,
                ),
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                weight_loader=weight_loader,
            )

        else:
            zeros = BasevLLMParameter(
                torch.empty(output_size_per_partition,
                            input_groups,
                            device="cuda",
                            dtype=params_dtype),
                weight_loader=weight_loader,
            )
            # Set attributes to indicate how scales and zeros are applied.
            set_weight_attrs(zeros, {
                "input_dim": None if input_groups == 1 else 1,
                "output_dim": 0,
            })

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros", zeros)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the layer's weights according to ``quant_method``.

        Only "gptq" is handled; see ``create_weights_gptq``.

        Raises:
            ValueError: If the configured quant method is not "gptq".
        """
        if self.quant_config.quant_method == "gptq":
            return self.create_weights_gptq(layer, input_size_per_partition,
                                            output_partition_sizes, input_size,
                                            output_size, params_dtype,
                                            **extra_weight_attrs)
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")

    def _configure_bitblas_matmul(
        self,
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
        out_dtype="float16",
    ):
        """Build the ``MatmulConfig`` and store the matching operator.

        The resulting operator is stored in ``self.bitblas_matmul``.

        Args:
            infeatures: Passed to ``MatmulConfig`` as ``K``.
            outfeatures: Passed to ``MatmulConfig`` as ``N``.
            params_dtype: torch dtype, looked up in ``BITBLAS_DTYPES``.
            enable_tuning: Forwarded to the operator lookup/creation.
            bias: Passed to ``MatmulConfig`` as ``with_bias``.
            layout: Passed to ``MatmulConfig`` as ``layout``.
            bits: Bit width used to build the weight dtype name.
            out_dtype: Output dtype name passed to ``MatmulConfig``.

        Raises:
            ValueError: If the configured quant method is not "gptq".
        """
        from bitblas import MatmulConfig

        bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]

        with_scaling = False
        with_zeros = False
        # Treat a missing group size as -1, consistent with
        # create_weights_gptq.
        group_size = self.quant_config.group_size
        if group_size is None:
            group_size = -1
        zeros_mode = self.quant_config.zeros_mode
        if self.quant_config.quant_method == "gptq":
            with_scaling = True
            with_zeros = True
            W_dtype = f"uint{bits}"
            if self.quant_config.is_sym:
                # Symmetric quantization: signed weights, no zero points.
                with_zeros = False
                W_dtype = f"int{bits}"
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")

        matmul_config = MatmulConfig(
            N=outfeatures,
            K=infeatures,
            A_dtype=bitblas_dtype,
            W_dtype=W_dtype,
            out_dtype=out_dtype,
            accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
            storage_dtype=self.quant_config.STORAGE_DTYPE,
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=group_size,
            with_bias=bias,
            layout=layout,
            zeros_mode=zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning)

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        """Return the operator for ``config``, creating it if needed.

        The global operator cache is loaded from the database when it is
        empty. On a cache miss a new ``Matmul`` is created; when
        ``enable_tuning`` is set it is tuned, added to the cache and the
        cache is saved back to the database.
        """
        from bitblas import Matmul, auto_detect_nvidia_target
        from bitblas.cache import get_database_path, global_operator_cache

        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
                                                     BITBLAS_TARGET)

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is None:
            bitblas_matmul = Matmul(config,
                                    target=BITBLAS_TARGET,
                                    enable_tuning=False)
            if enable_tuning:
                logger.info("BitBLAS Operator %s is tuning ...", config)
                bitblas_matmul.hardware_aware_finetune(topk=20)
                global_operator_cache.add(config, bitblas_matmul)
                global_operator_cache.save_into_database(
                    BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
                logger.info(
                    "BitBLAS Operator %s tuned and saved to database.",
                    config)
            else:
                logger.info("BitBLAS Operator %s created.", config)
        else:
            logger.info("BitBLAS Operator %s found in global_operator_cache.",
                        config)
        return bitblas_matmul

    def apply_gptq(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized matmul of ``layer`` on ``x``.

        ``x`` is flattened to 2D over its last dimension, multiplied with
        the BitBLAS operator (zero points are passed only for asymmetric
        quantization), and the result is reshaped back to the leading
        dimensions of ``x``. ``bias``, when given, is added in place.
        """
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.zeros

        # reshape (rather than view) also accepts non-contiguous inputs.
        x_2d = x.reshape(-1, x.shape[-1])

        if self.quant_config.is_sym:
            output_2d = self.bitblas_matmul(x_2d, qweight, scales)
        else:
            output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output

    def apply(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Apply the layer according to ``quant_method``.

        Only "gptq" is handled; arguments are forwarded to ``apply_gptq``.

        Raises:
            ValueError: If the configured quant method is not "gptq".
        """
        if self.quant_config.quant_method == "gptq":
            return self.apply_gptq(*args, **kwargs)
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
081057de
...
@@ -72,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -72,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return
70
return
70
def
get_name
(
self
)
->
str
:
def
get_name
(
self
)
->
str
:
return
"compressed
_
tensors"
return
"compressed
-
tensors"
def
get_quant_method
(
def
get_quant_method
(
self
,
self
,
...
@@ -302,14 +302,12 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -302,14 +302,12 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
BaseModel
,
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
input_quant_none
=
input_quant
is
None
input_quant_none
=
input_quant
is
None
is_symmetric
=
weight_quant
.
symmetric
is_channel_group
=
(
is_channel_group
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
or
weight_quant
.
strategy
==
QuantizationStrategy
.
GROUP
.
value
)
or
weight_quant
.
strategy
==
QuantizationStrategy
.
GROUP
.
value
)
is_static
=
not
weight_quant
.
dynamic
is_static
=
not
weight_quant
.
dynamic
return
(
is_channel_group
and
input_quant_none
and
is_symmetric
return
(
is_channel_group
and
input_quant_none
and
is_static
)
and
is_static
)
def
_get_scheme_from_parts
(
def
_get_scheme_from_parts
(
self
,
weight_quant
:
BaseModel
,
self
,
weight_quant
:
BaseModel
,
...
@@ -319,6 +317,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -319,6 +317,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
assert
weight_quant
.
symmetric
return
CompressedTensorsW4A16Sparse24
(
return
CompressedTensorsW4A16Sparse24
(
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
,
num_bits
=
weight_quant
.
num_bits
,
num_bits
=
weight_quant
.
num_bits
,
...
@@ -328,6 +327,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -328,6 +327,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return
CompressedTensorsWNA16
(
return
CompressedTensorsWNA16
(
num_bits
=
weight_quant
.
num_bits
,
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
,
symmetric
=
weight_quant
.
symmetric
,
group_size
=
weight_quant
.
group_size
,
group_size
=
weight_quant
.
group_size
,
actorder
=
weight_quant
.
actorder
)
actorder
=
weight_quant
.
actorder
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
081057de
...
@@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
else
:
else
:
return
CompressedTensorsWNA16MarlinMoEMethod
(
quant_config
)
return
CompressedTensorsWNA16MarlinMoEMethod
(
quant_config
)
elif
(
quant_config
.
_is_fp8_w8a8_sm90
(
weight_quant
,
input_quant
)
elif
(
quant_config
.
_is_fp8_w8a8_sm90
(
weight_quant
,
input_quant
)
and
layer
.
activation
==
"silu"
and
layer
.
expert_map
is
None
):
and
layer
.
activation
==
"silu"
):
return
CompressedTensorsW8A8Fp8MoECutlassMethod
(
quant_config
)
return
CompressedTensorsW8A8Fp8MoECutlassMethod
(
quant_config
)
elif
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
elif
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
...
@@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
requires_grad
=
False
)
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
)
# Property to determine if AITER is used
if
is_rocm_aiter_moe_enabled
():
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
# noqa E501
rocm_aiter_fused_experts
,
shuffle_weights
)
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
self
.
fused_experts_func
=
rocm_aiter_fused_experts
else
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
self
.
fused_experts_func
=
fused_experts
def
apply
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -282,10 +303,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -282,10 +303,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
return
self
.
fused_experts
_func
(
x
,
hidden_states
=
x
,
layer
.
w13_weight
,
w1
=
layer
.
w13_weight
,
layer
.
w2_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
...
@@ -489,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
...
@@ -489,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
assert
activation
==
"silu"
assert
global_num_experts
==
layer
.
w13_weight
.
shape
[
0
]
assert
expert_map
is
None
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -521,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
...
@@ -521,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
out_dtype
=
x
.
dtype
,
out_dtype
=
x
.
dtype
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
081057de
...
@@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
...
@@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
marlin_repeat_scales_on_all_ranks
)
marlin_repeat_scales_on_all_ranks
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
PackedvLLMParameter
,
RowvLLMParameter
)
RowvLLMParameter
)
# yapf: enable
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = {
...
@@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = {
4
:
scalar_types
.
uint4b8
,
4
:
scalar_types
.
uint4b8
,
8
:
scalar_types
.
uint8b128
8
:
scalar_types
.
uint8b128
}
}
WNA16_ZP_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
uint4
,
8
:
scalar_types
.
uint8
}
WNA16_SUPPORTED_BITS
=
list
(
WNA16_SUPPORTED_TYPES_MAP
.
keys
())
WNA16_SUPPORTED_BITS
=
list
(
WNA16_SUPPORTED_TYPES_MAP
.
keys
())
...
@@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
strategy
:
str
,
strategy
:
str
,
num_bits
:
int
,
num_bits
:
int
,
group_size
:
Optional
[
int
]
=
None
,
group_size
:
Optional
[
int
]
=
None
,
symmetric
:
Optional
[
bool
]
=
True
,
actorder
:
Optional
[
ActivationOrdering
]
=
None
):
actorder
:
Optional
[
ActivationOrdering
]
=
None
):
self
.
pack_factor
=
32
//
num_bits
self
.
pack_factor
=
32
//
num_bits
self
.
strategy
=
strategy
self
.
strategy
=
strategy
self
.
symmetric
=
symmetric
self
.
group_size
=
-
1
if
group_size
is
None
else
group_size
self
.
group_size
=
-
1
if
group_size
is
None
else
group_size
self
.
has_g_idx
=
actorder
==
ActivationOrdering
.
GROUP
self
.
has_g_idx
=
actorder
==
ActivationOrdering
.
GROUP
...
@@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
f
"Unsupported num_bits =
{
num_bits
}
. "
f
"Unsupported num_bits =
{
num_bits
}
. "
f
"Supported num_bits =
{
WNA16_SUPPORTED_TYPES_MAP
.
keys
()
}
"
)
f
"Supported num_bits =
{
WNA16_SUPPORTED_TYPES_MAP
.
keys
()
}
"
)
self
.
quant_type
=
WNA16_SUPPORTED_TYPES_MAP
[
num_bits
]
self
.
quant_type
=
(
WNA16_ZP_SUPPORTED_TYPES_MAP
[
num_bits
]
if
not
self
.
symmetric
else
WNA16_SUPPORTED_TYPES_MAP
[
num_bits
])
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
...
@@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight_type
=
self
.
quant_type
,
weight_type
=
self
.
quant_type
,
act_type
=
params_dtype
,
act_type
=
params_dtype
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
zero_points
=
False
,
zero_points
=
not
self
.
symmetric
,
has_g_idx
=
self
.
has_g_idx
has_g_idx
=
self
.
has_g_idx
)
)
...
@@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
dtype
=
params_dtype
,
dtype
=
params_dtype
,
)
)
}
}
zeros_args
=
{
"weight_loader"
:
weight_loader
,
"data"
:
torch
.
zeros
(
output_size_per_partition
//
self
.
pack_factor
,
scales_and_zp_size
,
dtype
=
torch
.
int32
,
)
}
if
not
partition_scales
:
if
not
partition_scales
:
weight_scale
=
ChannelQuantScaleParameter
(
output_dim
=
0
,
weight_scale
=
ChannelQuantScaleParameter
(
output_dim
=
0
,
**
weight_scale_args
)
**
weight_scale_args
)
if
not
self
.
symmetric
:
qzeros
=
PackedColumnParameter
(
output_dim
=
0
,
packed_dim
=
0
,
packed_factor
=
self
.
pack_factor
,
**
zeros_args
)
else
:
else
:
weight_scale
=
GroupQuantScaleParameter
(
output_dim
=
0
,
weight_scale
=
GroupQuantScaleParameter
(
output_dim
=
0
,
input_dim
=
1
,
input_dim
=
1
,
**
weight_scale_args
)
**
weight_scale_args
)
if
not
self
.
symmetric
:
qzeros
=
PackedvLLMParameter
(
input_dim
=
1
,
output_dim
=
0
,
packed_dim
=
0
,
packed_factor
=
self
.
pack_factor
,
**
zeros_args
)
# A 2D array defining the original shape of the weights
# A 2D array defining the original shape of the weights
# before packing
# before packing
...
@@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
if
not
self
.
symmetric
:
layer
.
register_parameter
(
"weight_zero_point"
,
qzeros
)
# group index (for activation reordering)
# group index (for activation reordering)
if
self
.
has_g_idx
:
if
self
.
has_g_idx
:
weight_g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
weight_g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
...
@@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self
.
kernel
=
kernel_type
(
mp_linear_kernel_config
,
self
.
kernel
=
kernel_type
(
mp_linear_kernel_config
,
w_q_param_name
=
"weight_packed"
,
w_q_param_name
=
"weight_packed"
,
w_s_param_name
=
"weight_scale"
,
w_s_param_name
=
"weight_scale"
,
w_zp_param_name
=
None
,
w_zp_param_name
=
"weight_zero_point"
,
w_gidx_param_name
=
"weight_g_idx"
)
w_gidx_param_name
=
"weight_g_idx"
)
# Checkpoints are serialized in compressed-tensors format, which is
# Checkpoints are serialized in compressed-tensors format, which is
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
081057de
...
@@ -140,6 +140,11 @@ class Fp8Config(QuantizationConfig):
...
@@ -140,6 +140,11 @@ class Fp8Config(QuantizationConfig):
return
name
.
replace
(
".k_proj.output_scale"
,
".attn.k_scale"
)
return
name
.
replace
(
".k_proj.output_scale"
,
".attn.k_scale"
)
if
name
.
endswith
(
".output_scale"
)
and
".v_proj"
in
name
:
if
name
.
endswith
(
".output_scale"
)
and
".v_proj"
in
name
:
return
name
.
replace
(
".v_proj.output_scale"
,
".attn.v_scale"
)
return
name
.
replace
(
".v_proj.output_scale"
,
".attn.v_scale"
)
if
name
.
endswith
(
".output_scale"
)
and
".q_proj"
in
name
:
return
name
.
replace
(
".q_proj.output_scale"
,
".attn.q_scale"
)
if
name
.
endswith
(
"self_attn.prob_output_scale"
):
return
name
.
replace
(
".prob_output_scale"
,
".attn.prob_scale"
)
# If no matches, return None
return
None
return
None
...
@@ -575,8 +580,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -575,8 +580,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Lazy import to avoid importing triton too early.
# Lazy import to avoid importing triton too early.
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
expand_weights
,
is_rocm_aiter_block_scaled_moe_enabled
,
expand_weights
,
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
# TODO (rob): refactor block quant into separate class.
# TODO (rob): refactor block quant into separate class.
if
self
.
block_quant
:
if
self
.
block_quant
:
...
@@ -603,7 +607,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -603,7 +607,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w2_weight
=
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale_inv
=
Parameter
(
w2_weight_scale_inv
,
layer
.
w2_weight_scale_inv
=
Parameter
(
w2_weight_scale_inv
,
requires_grad
=
False
)
requires_grad
=
False
)
if
is_rocm_aiter_
block_scaled_
moe_enabled
():
if
is_rocm_aiter_moe_enabled
():
# reshaping weights is required for aiter moe kernel.
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
...
...
vllm/model_executor/layers/quantization/gptq_bitblas.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision
import
(
BitBLASLinearKernel
,
MPLinearLayerConfig
)
from
vllm.model_executor.layers.quantization.utils.bitblas_utils
import
(
BITBLAS_SUPPORTED_NUM_BITS
as
GPTQ_BITBLAS_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.bitblas_utils
import
(
BITBLAS_SUPPORTED_SYM
as
GPTQ_BITBLAS_SUPPORTED_SYM
)
from
vllm.model_executor.layers.quantization.utils.bitblas_utils
import
(
MINIMUM_BITBLAS_VERSION
,
bitblas_repeat_scales_on_all_ranks
,
check_bitblas_supported
,
verify_bitblas_supported
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
RowvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
class GPTQBitBLASConfig(QuantizationConfig):
    """Config class for GPTQ BitBLAS.

    Stores the GPTQ checkpoint settings (bit width, group size, act-order,
    symmetry) and derives the BitBLAS-side values from them: storage dtype,
    pack factor, zeros mode and the scalar quant type looked up in
    ``TYPE_MAP``.
    """

    # (num_bits, is_sym) -> quant_type
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

    TORCH_DTYPE = torch.float16
    GPTQ_CKPT_STORAGE_DTYPE = (
        "int32"  # GPTQ Default Checkpoints use int32 as storage dtype
    )
    GPTQ_BITBLAS_STORAGE_DTYPE = "int8"  # BitBLAS uses int8 as storage dtype
    TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE)
    # "original" or "rescale" or "quantized",
    # the gptq_bitblas prefer "quantized"
    ZEROS_MODE = "quantized"

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        is_sym: bool,
        quant_method: Optional[str],
        lm_head_quantized: bool,
    ) -> None:
        """Validate the settings and derive the BitBLAS-specific fields.

        Args:
            weight_bits: Number of bits per quantized weight.
            group_size: Quantization group size; -1 means a single group
                per output channel.
            desc_act: Whether the checkpoint uses activation reordering.
            is_sym: Whether the quantization is symmetric.
            quant_method: Quantization method name from the checkpoint.
            lm_head_quantized: Whether the LM head is quantized as well.

        Raises:
            ValueError: If bitblas cannot be imported (or is too old), or if
                the bit width / symmetry combination is not supported.
        """
        try:
            import bitblas

            # NOTE(review): this compares version *strings*, so the ordering
            # is lexicographic (e.g. "0.1.10" < "0.1.9"). Confirm that is
            # acceptable for the versions bitblas publishes.
            if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
                raise ImportError(
                    "bitblas version is wrong. Please "
                    f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
        except ImportError as e:
            raise ValueError(
                "Trying to use the bitblas backend, but could not import "
                f"with the following error: {e}. "
                "Please install bitblas through the following command: "
                f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`") from e

        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.quant_method = quant_method
        self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"BitBLAS does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} "
                "are supported.")

        if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM:
            raise ValueError(
                f"BitBLAS does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.")

        self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE

        # Bit width of the checkpoint storage dtype, taken from the digits
        # in its name ("int32" -> 32).
        storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE
                                   if c.isdigit()))

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = storage_nbit // weight_bits
        self.nbits = weight_bits

        # Zeros type for the quantized weights.
        self.zeros_mode = self.ZEROS_MODE

        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={weight_bits}, sym={is_sym}")

        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

    def __repr__(self) -> str:
        """Return a readable summary of the main configuration fields."""
        return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"is_sym={self.is_sym}, "
                f"quant_method={self.quant_method})")

    @classmethod
    def get_name(cls) -> str:
        """Return the name under which this quantization method is known."""
        return "gptq_bitblas"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config declares as supported."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability, encoded as major*10+minor."""
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the checkpoint file names that may hold the quant config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig":
        """Build a config from a GPTQ quantization config dictionary.

        ``lm_head`` is optional and defaults to False; every other key is
        fetched with ``get_from_keys``.
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        quant_method = cls.get_from_keys(config, ["quant_method"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
                   lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Decide whether this method should take over a GPTQ checkpoint.

        Returns this method's name when the checkpoint is convertible and the
        user either gave no quantization or asked for ``bitblas`` /
        ``gptq_bitblas``; otherwise returns None.
        """
        can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)

        is_valid_user_quant = user_quant in (None, "bitblas", "gptq_bitblas")

        if can_convert and is_valid_user_quant:
            logger.info(
                "The model is convertible to %s during runtime."
                " Using %s kernel.", cls.get_name(), cls.get_name())
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_bitblas"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_bitblas for"
                        " faster inference")
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["GPTQBitBLASLinearMethod"]:
        """Return the linear method for ``layer``, or None if not handled.

        Linear layers are always handled; the LM head only when
        ``lm_head_quantized`` is set.
        """
        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
                                             and self.lm_head_quantized):
            return GPTQBitBLASLinearMethod(self)
        return None

    @property
    def torch_storage_dtype(self) -> torch.dtype:
        """The torch dtype matching ``GPTQ_BITBLAS_STORAGE_DTYPE``."""
        return self.TORCH_BITBLAS_STORAGE_DTYPE

    @classmethod
    def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]):
        """Return whether a GPTQ quant config can run through BitBLAS."""
        # Extract data from quant config.
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        # Without a CUDA device the capability query below would raise;
        # treat that situation as "cannot convert" instead.
        if not torch.cuda.is_available():
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies bitblas constraints.
        return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits,
                                                                sym)],
                                       group_size=group_size)
class GPTQBitBLASLinearMethod(LinearMethodBase):
    """Linear method for GPTQ BitBLAS.

    Args:
        quant_config: The GPTQ BitBLAS quantization config.
    """

    kernel_type = BitBLASLinearKernel
    # Shared across instances so the "Using ..." log line is emitted only
    # once per kernel class.
    _kernel_backends_being_used: Set[str] = set()

    def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
        self.quant_config = quant_config
        # Verify supported on platform.
        verify_bitblas_supported(quant_type=self.quant_config.quant_type,
                                 group_size=self.quant_config.group_size)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Create the quantized parameters and the BitBLAS kernel for a layer.

        Registers ``qweight``, ``g_idx``, ``scales`` and ``qzeros`` on
        ``layer``, builds the kernel object and configures its matmul
        operator. Nothing is returned.

        Args:
            layer: The module the parameters are registered on.
            input_size_per_partition: The size of the input partition.
            output_partition_sizes: The sizes of the output partitions; their
                sum is the output size of this partition.
            input_size: The total size of the input.
            output_size: The total size of the output.
            params_dtype: The data type of the parameters
                (must be torch.float16).
            **extra_weight_attrs: Extra attributes; ``weight_loader`` is read
                from here and attached to every parameter.

        Raises:
            ValueError: If `params_dtype` is not `torch.float16` or
                if the input size per partition is not divisible by the
                group size in `quant_config`.
        """
        if params_dtype != torch.float16:
            raise ValueError("Parameter data type must be torch.float16, "
                             f"but got {params_dtype}")

        # Normalize group_size: -1 stands for a single group that spans the
        # whole input dimension.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must "
                f"be divisible by group size ({group_size}).")

        kernel_type = self.kernel_type
        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)

        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(input_size_per_partition,
                                    output_size_per_partition),
            weight_type=self.quant_config.quant_type,
            act_type=params_dtype,
            group_size=self.quant_config.group_size,
            zero_points=False,
            has_g_idx=self.quant_config.desc_act)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for GPTQBitBLASLinearMethod",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # Determine sharding
        if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act,
                                              self.quant_config.group_size,
                                              is_row_parallel):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each GPU in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Init buffers
        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        # Activation order
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            dtype=torch.int32,
        ),
                                 input_dim=0,
                                 weight_loader=weight_loader)

        # Quantized zero-points
        qzeros_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        # Scales
        weight_scale_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }

        # The parameter classes encode whether scales / zero-points are
        # repeated on every rank (no input_dim) or sharded along dim 0.
        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

        self.kernel = kernel_type(
            mp_linear_kernel_config,
            w_q_param_name="qweight",
            w_s_param_name="scales",
            w_zp_param_name="qzeros",
            w_gidx_param_name="g_idx",
            bitblas_quant_config=self.quant_config,
        )

        # Initialize or retrieve the BitBLAS matrix multiplication operator.
        self.kernel.configure_bitblas_matmul(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype=params_dtype,
            bias=False,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Delegate post-load weight processing to the kernel."""
        self.kernel.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized linear layer on ``x`` and add ``bias`` if given.

        The bias is added in place to the kernel output.
        """
        out = self.kernel.apply_gptq_bitblas_linear(layer, x)
        if bias is not None:
            out.add_(bias)
        return out
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
081057de
...
@@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision
import
(
from
vllm.model_executor.layers.quantization.kernels.mixed_precision
import
(
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.moe_wna16
import
MoeWNA16Config
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.gptq_utils
import
(
from
vllm.model_executor.layers.quantization.utils.gptq_utils
import
(
get_linear_quant_method
)
get_linear_quant_method
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_marlin_supported
,
marlin_moe_permute_scales
,
check_marlin_supported
,
check_moe_marlin_supports_layer
,
marlin_repeat_scales_on_all_ranks
,
verify_marlin_supported
)
marlin_moe_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
verify_marlin_supported
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedColumnParameter
,
...
@@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig):
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
FusedMoE
):
if
isinstance
(
layer
,
FusedMoE
):
if
layer
.
local_num_experts
>
32
:
from
vllm.model_executor.layers.quantization.moe_wna16
import
(
# For MoEs with many experts the moe_wna16 kernel is faster
MoeWNA16Config
)
if
not
check_moe_marlin_supports_layer
(
layer
,
self
.
group_size
):
logger
.
warning_one
(
f
"Layer '
{
prefix
}
' is not supported by GPTQMoeMarlin. "
"Falling back to Moe WNA16 kernels."
)
return
MoeWNA16Config
.
from_config
(
return
MoeWNA16Config
.
from_config
(
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
else
:
return
GPTQMarlinMoEMethod
(
self
)
return
GPTQMarlinMoEMethod
(
self
)
return
get_linear_quant_method
(
self
,
layer
,
prefix
,
return
get_linear_quant_method
(
self
,
layer
,
prefix
,
GPTQMarlinLinearMethod
)
GPTQMarlinLinearMethod
)
...
@@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch
.
empty
(
num_experts
,
torch
.
empty
(
num_experts
,
scales_size13
,
scales_size13
,
2
*
intermediate_size_per_partition
,
2
*
intermediate_size_per_partition
,
dtype
=
torch
.
half
),
dtype
=
params_dtype
),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
layer
.
register_parameter
(
"w13_scales"
,
w13_scales
)
layer
.
register_parameter
(
"w13_scales"
,
w13_scales
)
...
@@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch
.
empty
(
num_experts
,
torch
.
empty
(
num_experts
,
scales_size2
,
scales_size2
,
hidden_size
,
hidden_size
,
dtype
=
torch
.
half
),
dtype
=
params_dtype
),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
layer
.
register_parameter
(
"w2_scales"
,
w2_scales
)
layer
.
register_parameter
(
"w2_scales"
,
w2_scales
)
...
@@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_g_idx_sort_indices
)
w2_g_idx_sort_indices
)
set_weight_attrs
(
w2_g_idx_sort_indices
,
extra_weight_attrs
)
set_weight_attrs
(
w2_g_idx_sort_indices
,
extra_weight_attrs
)
device
=
layer
.
w13_qweight
.
device
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
layer
.
workspace
=
torch
.
zeros
((
sms
*
4
,
),
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Process act_order
# Process act_order
...
@@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"Apply router weight on input is not supported for"
"Apply router weight on input is not supported for"
"fused Marlin MoE method."
)
"fused Marlin MoE method."
)
# The input must currently be float16
orig_dtype
=
x
.
dtype
x
=
x
.
half
()
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
...
@@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits
,
router_logits
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
g_idx1
=
layer
.
w13_g_idx
,
g_idx1
=
layer
.
w13_g_idx
,
g_idx2
=
layer
.
w2_g_idx
,
g_idx2
=
layer
.
w2_g_idx
,
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
is_k_full
=
self
.
is_k_full
).
to
(
orig_dtype
)
workspace
=
layer
.
workspace
,
is_k_full
=
self
.
is_k_full
)
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
View file @
081057de
...
@@ -5,6 +5,8 @@ from typing import List, Optional, Type
...
@@ -5,6 +5,8 @@ from typing import List, Optional, Type
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark
import
(
# noqa: E501
AllSparkLinearKernel
)
AllSparkLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas
import
(
# noqa: E501
BitBLASLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama
import
(
# noqa: E501
ExllamaLinearKernel
)
ExllamaLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.machete
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.machete
import
(
# noqa: E501
...
@@ -20,6 +22,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
...
@@ -20,6 +22,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel
,
MacheteLinearKernel
,
AllSparkLinearKernel
,
AllSparkLinearKernel
,
MarlinLinearKernel
,
MarlinLinearKernel
,
BitBLASLinearKernel
,
ExllamaLinearKernel
,
ExllamaLinearKernel
,
]
]
...
@@ -76,4 +79,4 @@ def choose_mp_linear_kernel(
...
@@ -76,4 +79,4 @@ def choose_mp_linear_kernel(
raise
ValueError
(
raise
ValueError
(
"Failed to find a kernel that can implement the "
\
"Failed to find a kernel that can implement the "
\
"WNA16 linear layer. Reasons:
\n
"
"WNA16 linear layer. Reasons:
\n
"
+
'
\n
'
.
join
(
failure_reasons
))
+
'
\n
'
.
join
(
failure_reasons
))
\ No newline at end of file
vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.bitblas_utils
import
(
BITBLAS_OPTIMIZE_FEATURES
,
BITBLAS_SUPPORTED_GROUP_SIZES
,
MINIMUM_BITBLAS_VERSION
,
bitblas_make_empty_g_idx
,
bitblas_sort_g_idx
,
check_bitblas_supports_shape
,
query_bitblas_supported_quant_types
,
unpack_gptq_qweight
,
unpack_gptq_qzeros
)
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
logger
=
init_logger
(
__name__
)
class BitBLASLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel backed by a BitBLAS matmul operator.

    The operator is created (or fetched from the BitBLAS operator cache) by
    ``configure_bitblas_matmul``; GPTQ-format weights are repacked for it in
    ``process_weights_after_loading``.
    """

    OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
    ENABLE_TUNING: bool = True
    MATMUL_LAYOUT: str = "nt"
    # torch dtype -> dtype name passed to the BitBLAS MatmulConfig.
    BITBLAS_DTYPES: Dict[torch.dtype, str] = {
        torch.float32: "float32",
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.half: "float16",
        torch.int8: "int8",
    }
    # Set by _configure_bitblas_matmul; None until then.
    bitblas_matmul: object = None

    def __init__(
        self,
        c: MPLinearLayerConfig,
        w_q_param_name: str,
        w_s_param_name: str,
        w_zp_param_name: Optional[str] = None,
        w_gidx_param_name: Optional[str] = None,
        bitblas_quant_config: Optional[QuantizationConfig] = None,
    ):
        """Store the quant config and forward the rest to the base class."""
        self.quant_config = bitblas_quant_config
        super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
                         w_gidx_param_name)

    def repack_bitblas_from_gptq(
        self,
        b_q_weight: torch.Tensor,
        scales: torch.Tensor,
        qzeros: Optional[torch.Tensor] = None,
    ):
        """Convert GPTQ-format tensors into the layout BitBLAS consumes.

        Args:
            b_q_weight: Packed GPTQ quantized weight.
            scales: GPTQ scales.
            qzeros: Packed GPTQ zero-points, or None when there are none.

        Returns:
            A ``(qweight, scales, zeros)`` tuple; ``zeros`` is None when
            ``qzeros`` is None.

        Raises:
            ValueError: If the operator's ``zeros_mode`` is not one of
                "original", "rescale" or "quantized".
        """
        from bitblas.quantization.utils import general_compress

        assert self.bitblas_matmul is not None, "bitblas_matmul is None"

        quant_config = self.quant_config
        # qweight in gptq old quant linear stored with
        # (outfeatures, infeatures), should be transposed.
        qweight = b_q_weight.T.contiguous().view(
            quant_config.torch_storage_dtype)  # type: ignore[union-attr]
        intweight = unpack_gptq_qweight(
            qweight,
            quant_config.weight_bits).contiguous()  # type: ignore[union-attr]
        if self.bitblas_matmul.weight_transform is not None:  # type: ignore[attr-defined]
            # The transform is applied on a CPU copy and the result is moved
            # back to the GPU.
            qweight = self.bitblas_matmul.weight_transform(  # type: ignore[attr-defined]
                intweight.cpu()).cuda()
        # scales in gptq old quant linear stored with
        # (infeatures // group_size, outfeatures), should be transposed.
        scales = scales.T.contiguous()

        if qzeros is None:
            return qweight, scales, None

        # qzeros should be de-quantized to int zeros.
        weight_bits = quant_config.weight_bits  # type: ignore[union-attr]
        intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
        zeros: Optional[torch.Tensor] = None
        zeros_mode = self.bitblas_matmul.config.zeros_mode  # type: ignore[attr-defined]
        if zeros_mode == "original":
            zeros = intzeros.to(torch.float16).contiguous()
        elif zeros_mode == "rescale":
            # Rescaled zeros are the integer zeros multiplied by the scales.
            # (Previously this branch asserted on a value that was always
            # None and therefore could never succeed.)
            zeros = (intzeros.to(torch.float16) * scales).contiguous()
        elif zeros_mode == "quantized":
            zeros = (
                torch.Tensor(
                    general_compress(
                        intzeros.T.contiguous().cpu().numpy(),
                        weight_bits,
                    )).to(qweight.device).to(
                        quant_config.torch_storage_dtype  # type: ignore[union-attr]
                    ).contiguous())
        else:
            raise ValueError(f"Unsupported zeros type: {zeros_mode}")

        return qweight, scales, zeros

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability, encoded as major*10+minor."""
        return 70

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Report whether this kernel can implement the layer config ``c``.

        Returns:
            ``(True, None)``-style success from the shape check, or
            ``(False, reason)`` when bitblas is missing/too old or the quant
            type or group size is unsupported.
        """
        not_installed_reason = (
            "bitblas is not installed. Please install bitblas "
            "by running `pip install bitblas>="
            f"{MINIMUM_BITBLAS_VERSION}`")

        try:
            import bitblas
        except ImportError:
            return False, not_installed_reason

        # NOTE(review): this compares version *strings*, so the ordering is
        # lexicographic (e.g. "0.1.10" < "0.1.9"). Confirm that is acceptable
        # for the versions bitblas publishes.
        if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
            return False, not_installed_reason

        quant_types = query_bitblas_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, (f"Quant type ({c.weight_type}) not supported by"
                           f" BitBLAS, supported types are: {quant_types}")

        if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
            return False, (f"Group size ({c.group_size}) not supported by "
                           "BitBLAS, supported group sizes are: "
                           f"{BITBLAS_SUPPORTED_GROUP_SIZES}")

        return check_bitblas_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size)

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare g_idx / zero-point tensors and repack weights for BitBLAS.

        Raises:
            NotImplementedError: If the layer config requests zero points.
        """
        device = getattr(layer, self.w_q_name).device
        c = self.config
        quant_config = self.quant_config

        # Default names since bitblas requires empty parameters for these,
        # TODO: remove this requirement from bitblas (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "qzeros"

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
            layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)

        if c.zero_points:
            raise NotImplementedError("Zero points not supported by BitBLAS")
        else:
            setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))

        # Repack weights
        # NOTE(review): the tensors are read through the fixed attribute
        # names qweight / scales / qzeros rather than the configured
        # parameter names; this matches the only names visible at the call
        # site in this file set.
        bitblas_qweight, bitblas_scales, bitblas_qzeros = (
            self.repack_bitblas_from_gptq(
                layer.qweight,
                layer.scales,
                None if quant_config.is_sym else  # type: ignore[union-attr]
                layer.qzeros,  # type: ignore[union-attr]
            ))
        replace_parameter(layer, self.w_q_name, bitblas_qweight)
        replace_parameter(layer, self.w_s_name, bitblas_scales)
        if bitblas_qzeros is not None:
            replace_parameter(layer, self.w_zp_name, bitblas_qzeros)

    def configure_bitblas_matmul(
        self,
        infeatures: int,
        outfeatures: int,
        params_dtype: torch.dtype,
        bias: bool,
    ) -> None:
        """Create or fetch the BitBLAS matmul operator for the given shape.

        Uses the class-level tuning flag and layout, and the bit width from
        the quant config.
        """
        enable_tuning = self.ENABLE_TUNING
        layout = self.MATMUL_LAYOUT
        bits = self.quant_config.weight_bits  # type: ignore[union-attr]
        self._configure_bitblas_matmul(
            infeatures,
            outfeatures,
            params_dtype,
            enable_tuning,
            bias,
            layout,
            bits,
        )

    def _configure_bitblas_matmul(
        self,
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
    ):
        """Build the MatmulConfig and store the resulting operator on self.

        Raises:
            ValueError: If the quant config's ``quant_method`` is not "gptq".
        """
        from bitblas import MatmulConfig

        bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
        quant_config = self.quant_config
        with_scaling = False
        with_zeros = False
        group_size = quant_config.group_size  # type: ignore[union-attr]
        zeros_mode = quant_config.zeros_mode  # type: ignore[union-attr]
        if quant_config.quant_method == "gptq":  # type: ignore[union-attr]
            with_scaling = True
            with_zeros = True
            W_dtype = f"uint{bits}"
            # Symmetric quantization: signed weight dtype, no zeros input.
            if quant_config.is_sym:  # type: ignore[union-attr]
                with_zeros = False
                W_dtype = f"int{bits}"
        else:
            raise ValueError(
                f"Unsupported quant_method {quant_config.quant_method}"  # type: ignore[union-attr]
            )  # type: ignore[union-attr]

        matmul_config = MatmulConfig(
            M=self.OPT_FEATURES,
            N=outfeatures,
            K=infeatures,
            A_dtype=bitblas_dtype,
            W_dtype=W_dtype,
            out_dtype=bitblas_dtype,
            accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
            storage_dtype=quant_config.  # type: ignore[union-attr]
            storage_dtype,  # type: ignore[union-attr]
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=group_size,
            with_bias=bias,
            layout=layout,
            zeros_mode=zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning)

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        """Return a cached operator for ``config`` or build a new one.

        A newly built operator is added to the cache and saved to the
        database only when ``enable_tuning`` is true.
        """
        from bitblas import Matmul, auto_detect_nvidia_target
        from bitblas.cache import get_database_path, global_operator_cache

        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()

        # Populate the in-memory cache from the on-disk database once.
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
                                                     BITBLAS_TARGET)

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is not None:
            logger.info("BitBLAS Operator %s retrieved from cache.", config)
            return bitblas_matmul

        bitblas_matmul = Matmul(config,
                                target=BITBLAS_TARGET,
                                enable_tuning=False)
        if enable_tuning:
            bitblas_matmul.hardware_aware_finetune(topk=20)
            global_operator_cache.add(config, bitblas_matmul)
            global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH,
                                                     BITBLAS_TARGET)
            logger.info("BitBLAS Operator %s tuned and saved to database.",
                        config)
        else:
            logger.info("BitBLAS Operator %s created without tuning. ",
                        config)
        return bitblas_matmul

    def apply_gptq_bitblas_linear(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Run the BitBLAS matmul on ``x`` with the layer's repacked weights.

        The result is viewed with ``x``'s leading dimensions and the
        per-partition output size as the last dimension.
        """
        output_size_per_partition = self.config.partition_weight_shape[1]
        out_shape = x.shape[:-1] + (output_size_per_partition, )
        args = [x, layer.qweight, layer.scales]
        if self.bitblas_matmul.config.with_zeros:  # type: ignore[attr-defined]
            args.append(layer.qzeros)
        output = self.bitblas_matmul(*args)  # type: ignore[operator]
        return output.view(out_shape)

    def apply_weights(self, layer, x, bias=None):
        """Not supported; use ``apply_gptq_bitblas_linear`` instead.

        Raises:
            NotImplementedError: Always.
        """
        NOT_IMPLEMENT_MESSAGE = (
            f"{self.__class__.__name__}.apply_weights is not implemented. "
            "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
        raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)
vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
View file @
081057de
...
@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
...
@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
@
classmethod
@
classmethod
def
can_implement
(
cls
,
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
c
.
has_g_idx
and
\
if
c
.
has_g_idx
and
\
c
.
partition_weight_shape
[
0
]
!=
c
.
full_weight_shape
[
0
]:
c
.
partition_weight_shape
[
0
]
!=
c
.
full_weight_shape
[
0
]:
return
False
,
"Act reordering currently not supported by Machete, "
\
return
False
,
"Act reordering currently not supported by Machete, "
\
"when the input features are partitioned across "
\
"when the input features are partitioned across "
\
"devices"
"devices"
if
c
.
zero_points
:
if
c
.
zero_points
:
return
False
,
"Zero points currently not supported by "
\
return
False
,
"Zero points currently not supported by Machete"
" Compressed Tensors + Machete. (Kernel supports it"
\
" but CompressedTensorsWNA16 does not so support has"
\
" not been added to MacheteWNA16Kernel yet"
if
c
.
weight_type
not
in
query_machete_supported_quant_types
(
if
c
.
weight_type
not
in
query_machete_supported_quant_types
(
c
.
zero_points
):
c
.
zero_points
):
...
...
Prev
1
…
20
21
22
23
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment