Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
448
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1078 additions
and
501 deletions
+1078
-501
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
...168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
+0
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
...192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
+0
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+93
-54
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+180
-136
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+23
-6
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+269
-3
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+9
-5
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+7
-0
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+5
-7
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+46
-52
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+35
-44
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+25
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+9
-12
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+11
-9
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+53
-59
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
...ompressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+32
-19
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+31
-20
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+27
-20
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+48
-54
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+175
-0
No files found.
Too many changes to show.
To preserve performance only
448 of 448+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
loat
8.json
→
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
p8_w8a
8.json
View file @
af7f4372
File moved
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
loat
8.json
→
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
p8_w8a
8.json
View file @
af7f4372
File moved
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
af7f4372
...
...
@@ -11,48 +11,51 @@ import triton.language as tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
@
triton
.
jit
def
fused_moe_kernel
(
# Pointers to matrices
a_ptr
,
b_ptr
,
c_ptr
,
a_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
,
K
,
EM
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8
:
tl
.
constexpr
,
):
# Pointers to matrices
a_ptr
,
b_ptr
,
c_ptr
,
a_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
,
K
,
EM
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_bse
,
stride_bsn
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
...
...
@@ -113,8 +116,12 @@ def fused_moe_kernel(
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
if
use_int8_w8a16
:
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bn
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8
:
if
use_fp8
_w8a8
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
...
...
@@ -136,7 +143,9 @@ def fused_moe_kernel(
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
# We accumulate along the K dimension.
if
use_fp8
:
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
...
...
@@ -149,8 +158,9 @@ def fused_moe_kernel(
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
if
use_fp8
:
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
:
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
else
:
accumulator
=
accumulator
.
to
(
compute_type
)
...
...
@@ -229,16 +239,18 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_tokens_post_padded
:
torch
.
Tensor
,
mul_routed_weight
:
bool
,
top_k
:
int
,
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8
:
bool
)
->
None
:
use_fp8
_w8a8
:
bool
,
use_int8_w8a16
:
bool
)
->
None
:
assert
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
not
use_fp8
:
assert
A_scale
is
None
assert
B_scale
is
None
else
:
if
use_fp8_w8a8
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
assert
B_scale
is
not
None
elif
use_int8_w8a16
:
assert
B_scale
is
not
None
else
:
assert
A_scale
is
None
assert
B_scale
is
None
grid
=
lambda
META
:
(
triton
.
cdiv
(
sorted_token_ids
.
shape
[
0
],
META
[
'BLOCK_SIZE_M'
])
*
triton
.
cdiv
(
B
.
shape
[
1
],
META
[
'BLOCK_SIZE_N'
]),
)
...
...
@@ -264,16 +276,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
use_int8_w8a16
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
use_int8_w8a16
else
0
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
def
get_config_file_name
(
E
:
int
,
N
:
int
,
dtype
:
Optional
[
str
])
->
str
:
device_name
=
torch
.
cuda
.
get_device_name
().
replace
(
" "
,
"_"
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
dtype_selector
=
""
if
not
dtype
else
f
",dtype=
{
dtype
}
"
return
f
"E=
{
E
}
,N=
{
N
}
,device_name=
{
device_name
}{
dtype_selector
}
.json"
...
...
@@ -426,6 +441,20 @@ def grouped_topk(hidden_states: torch.Tensor,
return
topk_weights
,
topk_ids
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_fp8_w8a8
:
return
"fp8_w8a8"
elif
use_int8_w8a16
:
return
"int8_w8a16"
elif
dtype
==
torch
.
float
:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
return
"float32"
return
None
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -433,7 +462,8 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_fp8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -454,13 +484,16 @@ def fused_experts(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
dtype
=
hidden_states
.
dtype
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
Non
e
,
config_dtyp
e
,
override_config
=
override_config
,
)
...
...
@@ -524,7 +557,8 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids
.
shape
[
1
],
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
...
...
@@ -542,7 +576,8 @@ def fused_experts(hidden_states: torch.Tensor,
1
,
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
)
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
,
...
...
@@ -562,7 +597,8 @@ def fused_moe(
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
use_fp8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -588,7 +624,9 @@ def fused_moe(
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use int8 weights with 16-bit activations to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
...
...
@@ -617,7 +655,8 @@ def fused_moe(
topk_ids
,
inplace
=
inplace
,
override_config
=
override_config
,
use_fp8
=
use_fp8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
af7f4372
...
...
@@ -24,15 +24,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
@
abstractmethod
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -61,66 +55,78 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
,
use_grouped_topk
,
num_expert_group
,
topk_group
)
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_moe
return
fused_moe
(
x
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
=
renormalize
,
inplace
=
True
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
router_logits
=
router_logits
,
top_k
=
top_k
,
renormalize
=
renormalize
,
use_grouped_topk
=
use_grouped_topk
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
)
def
forward_cuda
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
)
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
)
def
forward_cpu
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"The CPU backend currently does not support MoE."
)
def
forward_tpu
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
def
forward_tpu
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
use_grouped_topk
:
bool
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.moe_pallas
import
fused_moe
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
return
fused_moe
(
x
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
)
return
fused_moe
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk
=
top_k
,
gating_output
=
router_logits
,
renormalize
=
renormalize
)
class
FusedMoE
(
torch
.
nn
.
Module
):
...
...
@@ -195,52 +201,83 @@ class FusedMoE(torch.nn.Module):
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
int
,
expert_id
:
int
):
param_data
=
param
.
data
# Input scales can be loaded directly and should be equal.
if
"input_scale"
in
weight_name
:
if
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
param_data
[
expert_id
]
=
loaded_weight
# Weight scales
elif
"weight_scale"
in
weight_name
:
# If we are in merged column case (gate_up_proj)
# shard_id 0 == gate_proj / w1
# shard_id 2 == up_proj / w3
if
shard_id
==
0
or
shard_id
==
2
:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx
=
0
if
shard_id
==
0
else
1
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
# shard_id 1 == down_proj / w2
else
:
param_data
[
expert_id
]
=
loaded_weight
# Weights
shard_id
:
str
,
expert_id
:
int
)
->
None
:
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
f
"got
{
shard_id
}
."
)
# Special case for fp8 scales.
if
getattr
(
param
,
"is_fp8_scale"
,
False
):
self
.
_load_fp8_scale
(
param
.
data
,
loaded_weight
,
weight_name
,
shard_id
,
expert_id
)
return
expert_data
=
param
.
data
[
expert_id
]
tp_rank
=
get_tensor_model_parallel_rank
()
# If transposed, weight is saved as [input_dim, output_dim]
# Otherwise, weight is saved as [output_dim, input_dim]
# Default is not transposed/input dim is dim 1
input_dim
=
getattr
(
param
,
"input_dim"
,
1
)
output_dim
=
getattr
(
param
,
"output_dim"
,
0
)
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
if
shard_id
==
"w2"
:
shard_dim
=
input_dim
shard_size
=
expert_data
.
shape
[
shard_dim
]
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
elif
shard_id
in
(
"w1"
,
"w3"
):
shard_dim
=
output_dim
shard_size
=
expert_data
.
shape
[
output_dim
]
//
2
offset
=
shard_size
*
tp_rank
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
offset
,
shard_size
)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if
shard_id
==
"w1"
:
expert_data
=
expert_data
.
narrow
(
shard_dim
,
0
,
shard_size
)
expert_data
.
copy_
(
loaded_weight
)
# w3, up_proj: Load into second logical weight of w13.
elif
shard_id
==
"w3"
:
expert_data
=
expert_data
.
narrow
(
shard_dim
,
shard_size
,
shard_size
)
expert_data
.
copy_
(
loaded_weight
)
# w2, down_proj: Load into only logical weight of w2.
elif
shard_id
==
"w2"
:
expert_data
.
copy_
(
loaded_weight
)
else
:
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
intermediate_size_per_partition
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
# w1, gate_proj case: Load into first shard of w13.
if
shard_id
==
0
:
param_data
[
expert_id
,
0
:
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
# w3, up_proj case: Load into second shard of w13.
elif
shard_id
==
2
:
param_data
[
expert_id
,
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
# w2, down_proj case: Load into only shard of w2.
elif
shard_id
==
1
:
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
else
:
raise
ValueError
(
f
"Shard id must be in [0,1,2] but got
{
shard_id
}
"
)
raise
ValueError
(
f
"Expected shard_id w1,w2 or w3 but got
{
shard_id
}
"
)
@
staticmethod
def
select_experts
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
use_grouped_topk
:
bool
,
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
):
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
grouped_topk
)
# DeepSeekV2 uses grouped_topk
if
use_grouped_topk
:
assert
topk_group
is
not
None
assert
num_expert_group
is
not
None
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
top_k
,
renormalize
=
renormalize
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
else
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
top_k
,
renormalize
=
renormalize
)
return
topk_weights
,
topk_ids
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
...
...
@@ -248,14 +285,14 @@ class FusedMoE(torch.nn.Module):
# Matrix multiply.
final_hidden_states
=
self
.
quant_method
.
apply
(
self
,
layer
=
self
,
x
=
hidden_states
,
router_logits
=
router_logits
,
top_k
=
self
.
top_k
,
renormalize
=
self
.
renormalize
,
use_grouped_topk
=
self
.
use_grouped_topk
,
num_expert
_group
=
self
.
num_expert
_group
,
topk
_group
=
self
.
topk
_group
)
topk
_group
=
self
.
topk
_group
,
num_expert
_group
=
self
.
num_expert
_group
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
...
...
@@ -267,35 +304,42 @@ class FusedMoE(torch.nn.Module):
def
make_expert_params_mapping
(
cls
,
ckpt_gate_proj_name
:
str
,
ckpt_down_proj_name
:
str
,
ckpt_up_proj_name
:
str
,
num_experts
:
int
)
->
List
[
Tuple
[
str
,
str
,
int
,
int
]]:
gate_up
=
[
ckpt_gate_proj_name
,
ckpt_up_proj_name
]
gate_down_up
=
[
ckpt_gate_proj_name
,
ckpt_down_proj_name
,
ckpt_up_proj_name
]
num_experts
:
int
)
->
List
[
Tuple
[
str
,
str
,
int
,
str
]]:
return
[
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_scale"
if
weight_name
in
gate_up
else
"experts.w2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight_scale"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
+
[
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_weight"
if
weight_name
in
gate_up
else
"experts.w2_weight"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
]
+
[
# These are the input (activation) scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.a13_scale"
if
weight_name
in
gate_up
else
"experts.a2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.input_scale"
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
enumerate
(
gate_down_up
)
(
"experts.w13_"
if
weight_name
in
[
ckpt_gate_proj_name
,
ckpt_up_proj_name
]
else
"experts.w2_"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
."
,
expert_id
,
shard_id
)
for
expert_id
in
range
(
num_experts
)
for
shard_id
,
weight_name
in
[
(
"w1"
,
ckpt_gate_proj_name
),
(
"w2"
,
ckpt_down_proj_name
),
(
"w3"
,
ckpt_up_proj_name
),
]
]
def
_load_fp8_scale
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
param_data
=
param
.
data
# Input scales can be loaded directly and should be equal.
if
"input_scale"
in
weight_name
:
if
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
param_data
[
expert_id
]
=
loaded_weight
# Weight scales
elif
"weight_scale"
in
weight_name
:
# If we are in merged column case (gate_up_proj)
if
shard_id
in
(
"w1"
,
"w3"
):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx
=
0
if
shard_id
==
"w1"
else
1
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
else
:
param_data
[
expert_id
]
=
loaded_weight
vllm/model_executor/layers/layernorm.py
View file @
af7f4372
...
...
@@ -131,10 +131,12 @@ class GemmaRMSNorm(CustomOp):
self
.
weight
=
nn
.
Parameter
(
torch
.
zeros
(
hidden_size
))
self
.
variance_epsilon
=
eps
def
forward_native
(
self
,
@
staticmethod
def
forward_static
(
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype
=
x
.
dtype
...
...
@@ -144,17 +146,32 @@ class GemmaRMSNorm(CustomOp):
x
=
x
.
float
()
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
self
.
variance_epsilon
)
x
=
x
*
torch
.
rsqrt
(
variance
+
variance_epsilon
)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x
=
x
*
(
1.0
+
self
.
weight
.
float
())
x
=
x
*
(
1.0
+
weight
.
float
())
x
=
x
.
to
(
orig_dtype
)
return
x
if
residual
is
None
else
(
x
,
residual
)
def
forward_native
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""PyTorch-native implementation equivalent to forward()."""
return
self
.
forward_static
(
self
.
weight
.
data
,
self
.
variance_epsilon
,
x
,
residual
)
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
# TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
if
torch
.
compiler
.
is_compiling
():
return
self
.
forward_native
(
x
,
residual
)
if
not
getattr
(
self
,
"_is_compiled"
,
False
):
self
.
forward_static
=
torch
.
compile
(
# type: ignore
self
.
forward_static
)
self
.
_is_compiled
=
True
return
self
.
forward_native
(
x
,
residual
)
vllm/model_executor/layers/linear.py
View file @
af7f4372
...
...
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
import
torch
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
...
...
@@ -13,6 +13,9 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
PackedvLLMParameter
,
PerTensorScaleParameter
)
from
vllm.model_executor.utils
import
set_weight_attrs
import
os
...
...
@@ -20,6 +23,12 @@ from vllm.model_executor.utils import gemm_bank_conf
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
]
def
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
):
marlin_tile_size
=
getattr
(
param
,
"marlin_tile_size"
,
None
)
...
...
@@ -307,6 +316,7 @@ class ColumnParallelLinear(LinearBase):
if
output_sizes
is
None
:
output_sizes
=
[
output_size
]
self
.
quant_method
.
create_weights
(
layer
=
self
,
input_size_per_partition
=
self
.
input_size
,
...
...
@@ -314,7 +324,9 @@ class ColumnParallelLinear(LinearBase):
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
,
weight_loader
=
(
self
.
weight_loader_v2
if
self
.
quant_method
.
__class__
.
__name__
in
WEIGHT_LOADER_V2_SUPPORTED
else
self
.
weight_loader
),
prefix
=
prefix
)
if
bias
:
self
.
bias
=
Parameter
(
...
...
@@ -330,6 +342,17 @@ class ColumnParallelLinear(LinearBase):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for GGUF
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight_type
=
getattr
(
param
,
"is_gguf_weight_type"
,
False
)
if
is_gguf_weight_type
:
param
.
weight_type
=
loaded_weight
.
item
()
# Materialize GGUF UninitializedParameter
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
param
.
materialize
(
loaded_weight
.
shape
,
dtype
=
loaded_weight
.
dtype
)
param_data
=
param
.
data
if
output_dim
is
not
None
:
shard_size
=
param_data
.
shape
[
output_dim
]
...
...
@@ -345,6 +368,14 @@ class ColumnParallelLinear(LinearBase):
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
def
weight_loader_v2
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if
len
(
loaded_weight
.
shape
)
==
0
:
assert
loaded_weight
.
numel
()
==
1
loaded_weight
=
loaded_weight
.
reshape
(
1
)
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
self
,
input_
):
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
...
...
@@ -417,6 +448,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
Optional
[
int
]
=
None
):
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight_type
=
getattr
(
param
,
"is_gguf_weight_type"
,
False
)
if
is_gguf_weight_type
:
param
.
data
[
loaded_shard_id
].
copy_
(
loaded_weight
)
param
.
shard_weight_type
[
loaded_shard_id
]
=
loaded_weight
.
item
()
return
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
from
gguf.constants
import
GGML_QUANT_SIZES
ori_shape
=
param
.
tensor_shape
weight_types
=
self
.
qweight_type
.
shard_weight_type
.
values
()
row_size
=
[]
for
weight_type
in
weight_types
:
block_size
,
type_size
=
GGML_QUANT_SIZES
[
weight_type
]
row_size
.
append
(
ori_shape
[
1
]
//
block_size
*
type_size
)
q_shape
=
(
ori_shape
[
0
],
max
(
row_size
))
param
.
materialize
(
q_shape
,
dtype
=
loaded_weight
.
dtype
)
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for AQLM codebooks.
...
...
@@ -479,6 +531,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
\
loaded_shard_id
if
is_gguf_weight
:
tp_size
=
get_tensor_model_parallel_world_size
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
shard_shape
=
list
(
loaded_weight
.
shape
)
shard_shape
[
output_dim
]
=
shard_shape
[
output_dim
]
//
tp_size
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_size
[
loaded_shard_id
]
=
shard_shape
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_size
=
loaded_weight
.
shape
[
input_dim
]
param_data
=
param_data
.
narrow
(
input_dim
,
0
,
input_size
)
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
start_idx
=
tp_rank
*
shard_size
...
...
@@ -507,6 +571,65 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                       loaded_weight: torch.Tensor):
    """
    Split an on-disk fused MLP weight into per-shard slices and load each.

    Some checkpoints store the merged MLP projections already fused, so no
    shard id accompanies the tensor. This method determines the shard ids
    itself: it walks ``self.output_sizes``, slices the fused tensor along
    ``param.output_dim`` and forwards every slice to ``weight_loader_v2``
    together with its positional shard id.

    An example of a model with these fused layers:
    https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
    """
    running_offset = 0
    # Snapshot the sizes up front so the walk is fixed before any slice
    # is handed to the per-shard loader.
    for shard_id, width in enumerate(tuple(self.output_sizes)):
        shard_offset, shard_size = running_offset, width
        running_offset += width

        # Special case for quantization: when the parameter is packed
        # along its output dimension, the offset and size must be
        # adjusted to account for the packing.
        packed_along_output = (isinstance(param, PackedvLLMParameter)
                               and param.packed_dim == param.output_dim)
        if packed_along_output:
            shard_size, shard_offset = \
                param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

        self.weight_loader_v2(
            param,
            loaded_weight.narrow(param.output_dim, shard_offset, shard_size),
            shard_id)
def weight_loader_v2(self,
                     param: BasevLLMParameter,
                     loaded_weight: torch.Tensor,
                     loaded_shard_id: Optional[int] = None):
    """
    Load ``loaded_weight`` into a merged-column parameter.

    With a shard id, the shard's offset and size are derived from
    ``self.output_sizes`` divided by the tensor-parallel world size and
    passed to ``param.load_merged_column_weight``.

    Without a shard id there are three cases:
      * ``PerTensorScaleParameter``: loaded as shard 0;
      * exactly ``BasevLLMParameter`` (not a subclass): loaded as-is;
      * anything else: treated as a fused checkpoint tensor and split by
        ``_load_fused_module_from_checkpoint``.
    """
    if loaded_shard_id is not None:
        assert loaded_shard_id < len(self.output_sizes)

        world_size = get_tensor_model_parallel_world_size()
        preceding = sum(self.output_sizes[:loaded_shard_id])
        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=preceding // world_size,
            shard_size=self.output_sizes[loaded_shard_id] // world_size)
        return

    if isinstance(param, PerTensorScaleParameter):
        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=0)
    elif type(param) is BasevLLMParameter:
        # Deliberately an exact type check: subclasses fall through to
        # the fused-checkpoint path below.
        param.load_merged_column_weight(loaded_weight=loaded_weight)
    else:
        self._load_fused_module_from_checkpoint(param, loaded_weight)
class
QKVParallelLinear
(
ColumnParallelLinear
):
"""Linear layers for the attention's QKV transformation.
...
...
@@ -578,10 +701,112 @@ class QKVParallelLinear(ColumnParallelLinear):
quant_config
=
quant_config
,
prefix
=
prefix
)
def
_get_shard_offset_mapping
(
self
,
loaded_shard_id
:
str
):
shard_offset_mapping
=
{
"q"
:
0
,
"k"
:
self
.
num_heads
*
self
.
head_size
,
"v"
:
(
self
.
num_heads
+
self
.
num_kv_heads
)
*
self
.
head_size
,
"total"
:
(
self
.
num_heads
+
2
*
self
.
num_kv_heads
)
*
self
.
head_size
}
return
shard_offset_mapping
.
get
(
loaded_shard_id
)
def
_get_shard_size_mapping
(
self
,
loaded_shard_id
:
str
):
shard_size_mapping
=
{
"q"
:
self
.
num_heads
*
self
.
head_size
,
"k"
:
self
.
num_kv_heads
*
self
.
head_size
,
"v"
:
self
.
num_kv_heads
*
self
.
head_size
,
}
return
shard_size_mapping
.
get
(
loaded_shard_id
)
def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                       loaded_weight: torch.Tensor):
    """
    Split an on-disk fused QKV weight into q/k/v slices and load each one.

    Some checkpoints store the QKV projection already fused, so no shard id
    accompanies the tensor. This method determines the shard ids itself by
    slicing the fused tensor along ``param.output_dim`` and then forwards
    every slice to ``weight_loader_v2`` together with its shard id.

    An example of a model with these fused layers:
    https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
    """
    q_width = self.total_num_heads * self.head_size
    kv_width = self.total_num_kv_heads * self.head_size
    # (shard_id, shard_offset, shard_size), laid out as q | k | v.
    layout = (
        ("q", 0, q_width),
        ("k", q_width, kv_width),
        ("v", q_width + kv_width, kv_width),
    )

    for shard_id, shard_offset, shard_size in layout:
        # Special case for quantization: when the parameter is packed
        # along its output dimension, the offset and size must be
        # adjusted to account for the packing.
        packed_along_output = (isinstance(param, PackedvLLMParameter)
                               and param.packed_dim == param.output_dim)
        if packed_along_output:
            shard_size, shard_offset = \
                param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

        self.weight_loader_v2(
            param,
            loaded_weight.narrow(param.output_dim, shard_offset, shard_size),
            shard_id)
def weight_loader_v2(self,
                     param: BasevLLMParameter,
                     loaded_weight: torch.Tensor,
                     loaded_shard_id: Optional[str] = None):
    """
    Load ``loaded_weight`` into a fused QKV parameter.

    With a shard id (``"q"``, ``"k"`` or ``"v"``), the shard's offset and
    size come from ``_get_shard_offset_mapping`` and
    ``_get_shard_size_mapping`` and are passed to ``param.load_qkv_weight``.

    Without a shard id (special case for certain models):
      * ``PerTensorScaleParameter``: loaded as shard 0;
      * exactly ``BasevLLMParameter`` (not a subclass): loaded as-is;
      * anything else: treated as a fused checkpoint tensor and split by
        ``_load_fused_module_from_checkpoint``.
    """
    if loaded_shard_id is not None:
        assert loaded_shard_id in ("q", "k", "v")

        offset = self._get_shard_offset_mapping(loaded_shard_id)
        width = self._get_shard_size_mapping(loaded_shard_id)
        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=offset,
                              shard_size=width)
        return

    if isinstance(param, PerTensorScaleParameter):
        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=0)
    elif type(param) is BasevLLMParameter:
        # Deliberately an exact type check: subclasses fall through to
        # the fused-checkpoint path below.
        param.load_merged_column_weight(loaded_weight=loaded_weight)
    else:
        self._load_fused_module_from_checkpoint(param, loaded_weight)
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
Optional
[
str
]
=
None
):
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight_type
=
getattr
(
param
,
"is_gguf_weight_type"
,
False
)
if
is_gguf_weight_type
and
loaded_shard_id
is
not
None
:
idx_map
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
param
.
data
[
idx_map
[
loaded_shard_id
]].
copy_
(
loaded_weight
)
param
.
shard_weight_type
[
loaded_shard_id
]
=
loaded_weight
.
item
()
return
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
from
gguf.constants
import
GGML_QUANT_SIZES
ori_shape
=
param
.
tensor_shape
weight_types
=
self
.
qweight_type
.
shard_weight_type
.
values
()
row_size
=
[]
for
weight_type
in
weight_types
:
block_size
,
type_size
=
GGML_QUANT_SIZES
[
weight_type
]
row_size
.
append
(
ori_shape
[
1
]
//
block_size
*
type_size
)
q_shape
=
(
ori_shape
[
0
],
max
(
row_size
))
param
.
materialize
(
q_shape
,
dtype
=
loaded_weight
.
dtype
)
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for AQLM codebooks.
...
...
@@ -669,6 +894,18 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_bitsandbytes_shard
(
param
,
orig_qkv_offsets
,
loaded_shard_id
)
if
is_gguf_weight
:
tp_size
=
get_tensor_model_parallel_world_size
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
shard_shape
=
list
(
loaded_weight
.
shape
)
shard_shape
[
output_dim
]
=
shard_shape
[
output_dim
]
//
tp_size
param
.
shard_id
.
append
(
loaded_shard_id
)
param
.
shard_size
[
loaded_shard_id
]
=
shard_shape
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_size
=
loaded_weight
.
shape
[
input_dim
]
param_data
=
param_data
.
narrow
(
input_dim
,
0
,
input_size
)
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
if
loaded_shard_id
==
"q"
:
...
...
@@ -748,6 +985,7 @@ class RowParallelLinear(LinearBase):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
layer
=
self
,
input_size_per_partition
=
self
.
input_size_per_partition
,
...
...
@@ -755,7 +993,9 @@ class RowParallelLinear(LinearBase):
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
,
weight_loader
=
(
self
.
weight_loader_v2
if
self
.
quant_method
.
__class__
.
__name__
in
WEIGHT_LOADER_V2_SUPPORTED
else
self
.
weight_loader
),
prefix
=
prefix
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
...
...
@@ -773,7 +1013,22 @@ class RowParallelLinear(LinearBase):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
# Special case for GGUF
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight_type
=
getattr
(
param
,
"is_gguf_weight_type"
,
False
)
if
is_gguf_weight_type
:
param
.
weight_type
=
loaded_weight
.
item
()
# Materialize GGUF UninitializedParameter
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
weight_shape
=
list
(
loaded_weight
.
shape
)
if
input_dim
:
weight_shape
[
input_dim
]
=
weight_shape
[
input_dim
]
//
tp_size
param
.
materialize
(
tuple
(
weight_shape
),
dtype
=
loaded_weight
.
dtype
)
param_data
=
param
.
data
if
input_dim
is
not
None
:
shard_size
=
param_data
.
shape
[
input_dim
]
...
...
@@ -789,6 +1044,17 @@ class RowParallelLinear(LinearBase):
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
def weight_loader_v2(self, param: BasevLLMParameter,
                     loaded_weight: torch.Tensor):
    """
    Load ``loaded_weight`` into a row-parallel parameter.

    Scales loaded off disk often have no shape (such as in the case of
    AutoFP8); such zero-dimensional tensors are reshaped to a one-element
    1-D tensor before being handed to ``param.load_row_parallel_weight``.
    """
    is_scalar = len(loaded_weight.shape) == 0
    if is_scalar:
        assert loaded_weight.numel() == 1
    weight = loaded_weight.reshape(1) if is_scalar else loaded_weight

    param.load_row_parallel_weight(loaded_weight=weight)
def
forward
(
self
,
input_
):
if
self
.
input_is_parallel
:
input_parallel
=
input_
...
...
vllm/model_executor/layers/logits_processor.py
View file @
af7f4372
...
...
@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Optional
[
torch
.
Tensor
]
:
if
self
.
logits_as_input
:
logits
=
hidden_states
else
:
...
...
@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
return
logits
def
_get_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
lm_head
:
VocabParallelEmbedding
,
embedding_bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
def
_get_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
lm_head
:
VocabParallelEmbedding
,
embedding_bias
:
Optional
[
torch
.
Tensor
],
)
->
Optional
[
torch
.
Tensor
]:
# Get the logits for the next tokens.
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
hidden_states
,
bias
=
embedding_bias
)
if
self
.
use_gather
:
# None may be returned for rank > 0
logits
=
tensor_model_parallel_gather
(
logits
)
else
:
# Gather is not supported for some devices such as TPUs.
...
...
@@ -91,7 +95,7 @@ class LogitsProcessor(nn.Module):
logits
=
tensor_model_parallel_all_gather
(
logits
)
# Remove paddings in vocab (if any).
if
logits
is
not
None
:
logits
=
logits
[
:
,
:
self
.
org_vocab_size
]
logits
=
logits
[
...
,
:
self
.
org_vocab_size
]
return
logits
def
extra_repr
(
self
)
->
str
:
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
af7f4372
...
...
@@ -11,8 +11,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsConfig
)
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
DeepSpeedFPConfig
)
from
vllm.model_executor.layers.quantization.experts_int8
import
(
ExpertsInt8Config
)
from
vllm.model_executor.layers.quantization.fbgemm_fp8
import
FBGEMMFp8Config
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.gguf
import
GGUFConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
...
...
@@ -21,16 +24,19 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"tpu_int8"
:
Int8TpuConfig
,
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
"gguf"
:
GGUFConfig
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
...
...
@@ -39,6 +45,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
"experts_int8"
:
ExpertsInt8Config
,
}
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
af7f4372
...
...
@@ -95,7 +95,7 @@ def generic_dequantize_gemm(
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
output_partition_sizes
:
torch
.
IntTensor
,
output_partition_sizes
:
List
[
int
]
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
output_shape
=
input
.
shape
[:
-
1
]
+
(
scales
.
shape
[
0
],
)
...
...
@@ -133,7 +133,7 @@ def optimized_dequantize_gemm(
codebooks
:
torch
.
Tensor
,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales
:
torch
.
Tensor
,
# [num_out_groups, 1, 1, 1]
output_partition_sizes
:
torch
.
IntTensor
,
output_partition_sizes
:
List
[
int
]
,
bias
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
weights
=
ops
.
aqlm_dequant
(
codes
,
codebooks
,
output_partition_sizes
)
...
...
@@ -288,10 +288,8 @@ class AQLMLinearMethod(LinearMethodBase):
codebooks
,
{
# metadata indicates fixed size concatenated along dim 0
"is_metadata"
:
True
,
"output_partition_sizes"
:
torch
.
tensor
(
output_partition_sizes
,
device
=
'cpu'
),
"is_metadata"
:
True
,
"output_partition_sizes"
:
output_partition_sizes
},
)
...
...
@@ -334,7 +332,7 @@ class AQLMLinearMethod(LinearMethodBase):
codes
=
layer
.
codes
scales
=
layer
.
scales
output_partition_sizes
=
getattr
(
codebooks
,
"output_partition_sizes"
,
None
)
[]
)
nbooks
=
codes
.
shape
[
2
]
ingroups
=
codebooks
.
shape
[
3
]
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
af7f4372
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
class
AWQShareWorkSpace
:
...
...
@@ -117,70 +117,64 @@ class AWQLinearMethod(LinearMethodBase):
"weight shape. This can be caused by too large "
"tensor parallel size."
)
qweight
=
Parameter
(
torch
.
empty
(
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
qweight
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
qzeros
=
Parameter
(
torch
.
empty
(
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
)
qzeros
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
scales
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
)
scales
=
GroupQuantScaleParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
0
,
output_dim
=
1
,
weight_loader
=
weight_loader
)
zeros_and_scales
=
Parameter
(
torch
.
empty
(
(
input_size_per_partition
//
self
.
quant_config
.
group_size
),
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
zeros_and_scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
zeros_and_scales
=
GroupQuantScaleParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
0
,
output_dim
=
1
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"zeros_and_scales"
,
zeros_and_scales
)
set_weight_attrs
(
zeros_and_scales
,
extra_weight_attrs
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """
    Re-wrap the layer's AWQ tensors as plain ``torch.nn.Parameter`` objects.

    ``qweight``, ``qzeros``, ``scales`` and ``zeros_and_scales`` are each
    replaced, in that order, by a fresh non-trainable ``torch.nn.Parameter``
    built from the existing parameter's ``.data``.
    """
    # NOTE(review): presumably this drops the vLLM parameter subclasses
    # used during loading once they are no longer needed - confirm.
    for attr in ("qweight", "qzeros", "scales", "zeros_and_scales"):
        loaded = getattr(layer, attr)
        setattr(layer, attr,
                torch.nn.Parameter(loaded.data, requires_grad=False))
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
af7f4372
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
...
...
@@ -14,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -126,8 +126,7 @@ class AWQMarlinConfig(QuantizationConfig):
return
check_marlin_supported
(
quant_type
=
cls
.
TYPE_MAP
[
num_bits
],
group_size
=
group_size
,
has_zp
=
has_zp
,
min_capability
=
cls
.
get_min_capability
())
has_zp
=
has_zp
)
class
AWQMarlinLinearMethod
(
LinearMethodBase
):
...
...
@@ -152,6 +151,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
)
->
None
:
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
...
...
@@ -165,59 +165,44 @@ class AWQMarlinLinearMethod(LinearMethodBase):
input_size
=
input_size
,
group_size
=
group_size
)
qweight
=
Parameter
(
torch
.
empty
(
qweight
=
PackedvLLM
Parameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
)
num_groups
=
input_size_per_partition
//
group_size
qzeros
=
Parameter
(
torch
.
empty
(
qzeros
=
PackedvLLM
Parameter
(
data
=
torch
.
empty
(
num_groups
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
scales
=
Parameter
(
torch
.
empty
(
num_groups
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
)
scales
=
GroupQuantScaleParameter
(
data
=
torch
.
empty
(
num_groups
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
0
,
output_dim
=
1
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
...
...
@@ -229,6 +214,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
# Here, we handle the repacking
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
layer
.
qweight
=
torch
.
nn
.
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
qzeros
=
torch
.
nn
.
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
scales
=
torch
.
nn
.
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
...
...
@@ -279,4 +270,4 @@ class AWQMarlinLinearMethod(LinearMethodBase):
quant_type
=
self
.
quant_config
.
quant_type
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
bias
=
bias
)
bias
=
bias
)
\ No newline at end of file
vllm/model_executor/layers/quantization/base_config.py
View file @
af7f4372
import
inspect
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
import
torch
from
torch
import
nn
...
...
@@ -23,6 +24,14 @@ class QuantizeMethodBase(ABC):
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
# Not required functions
def embedding(self, layer: torch.nn.Module, *args,
              **kwargs) -> torch.Tensor:
    """Gather embeddings in the layer based on indices in the input tensor.

    Optional hook: this base implementation always raises
    ``NotImplementedError``. Expects create_weights to have been called
    before on the layer.
    """
    raise NotImplementedError
def
process_weights_after_loading
(
self
,
layer
:
nn
.
Module
)
->
None
:
"""Process the weight after loading.
...
...
@@ -31,6 +40,21 @@ class QuantizeMethodBase(ABC):
return
def method_has_implemented_embedding(
        method_class: Type[QuantizeMethodBase]) -> bool:
    """
    Report whether ``method_class`` provides its own ``embedding``.

    Not all quant methods have embedding implemented, so we need to check
    that it exists for the given method. The attribute is looked up
    statically on both ``method_class`` and ``QuantizeMethodBase``; the
    result is True only when the class has an ``embedding`` attribute that
    is a different object from the base implementation.
    """
    candidate = inspect.getattr_static(method_class, "embedding", None)
    if candidate is None:
        return False

    inherited = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
    return candidate is not inherited
class
QuantizationConfig
(
ABC
):
"""Base class for quantization configs."""
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
af7f4372
...
...
@@ -19,6 +19,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
current_platform
__all__
=
[
"CompressedTensorsLinearMethod"
]
class
CompressedTensorsConfig
(
QuantizationConfig
):
...
...
@@ -146,18 +148,15 @@ class CompressedTensorsConfig(QuantizationConfig):
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
# Confirm we have floating points.
if
not
(
weight_quant
.
type
==
QuantizationType
.
FLOAT
and
input_quant
.
type
==
QuantizationType
.
FLOAT
):
return
False
# Confirm weight scheme is supported.
is_floating_point
=
(
weight_quant
.
type
==
QuantizationType
.
FLOAT
and
input_quant
.
type
==
QuantizationType
.
FLOAT
)
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
])
if
not
(
is_symmetric_weight
and
is_static_weight
if
not
(
is_floating_point
and
is_symmetric_weight
and
is_static_weight
and
is_per_tensor_or_channel_weight
):
return
False
...
...
@@ -169,11 +168,7 @@ class CompressedTensorsConfig(QuantizationConfig):
is_symmetric_activation
=
input_quant
.
symmetric
is_per_tensor_activation
=
(
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
if
not
(
is_symmetric_activation
and
is_per_tensor_activation
):
return
False
# All conditions satisfied.
return
True
return
is_symmetric_activation
and
is_per_tensor_activation
def
_is_fp8_w8a16
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
...
...
@@ -230,6 +225,7 @@ class CompressedTensorsConfig(QuantizationConfig):
group_size
=
weight_quant
.
group_size
)
# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
if
is_activation_quantization_format
(
self
.
quant_format
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
is_fp8_w8a8_supported
=
self
.
_check_scheme_supported
(
...
...
@@ -237,7 +233,8 @@ class CompressedTensorsConfig(QuantizationConfig):
if
is_fp8_w8a8_supported
:
return
CompressedTensorsW8A8Fp8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
(
not
input_quant
.
dynamic
))
is_static_input_scheme
=
(
input_quant
and
not
input_quant
.
dynamic
))
else
:
return
CompressedTensorsW8A16Fp8
(
strategy
=
weight_quant
.
strategy
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
View file @
af7f4372
...
...
@@ -2,11 +2,10 @@ from typing import Callable, List, Optional
import
torch
import
torch.nn.functional
as
F
from
torch.nn
import
Parameter
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.
utils
import
set_weight_attrs
from
vllm.model_executor.
parameter
import
ModelWeightParameter
__all__
=
[
"CompressedTensorsUnquantized"
]
...
...
@@ -24,7 +23,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
return
70
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """
    Re-wrap ``layer.weight`` as a plain ``torch.nn.Parameter``.

    The existing weight's ``.data`` is placed in a fresh, non-trainable
    ``torch.nn.Parameter`` that replaces ``layer.weight``.
    """
    # required by torch.compile to be torch.nn.Parameter
    layer.weight = torch.nn.Parameter(layer.weight.data,
                                      requires_grad=False)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
...
...
@@ -32,14 +33,15 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"weight_loader"
:
weight_loader
})
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
View file @
af7f4372
...
...
@@ -8,7 +8,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
__all__
=
[
"CompressedTensorsW4A16Sparse24"
]
...
...
@@ -45,7 +48,12 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
return
80
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """
    Re-wrap the layer's packed tensors as plain ``Parameter`` objects.

    ``weight_packed``, ``scale_packed`` and ``meta`` are each replaced, in
    that order, by a fresh non-trainable ``Parameter`` built from the
    existing parameter's ``.data``.
    """
    # required by torch.compile to be torch.nn.Parameter
    for attr in ("weight_packed", "scale_packed", "meta"):
        loaded = getattr(layer, attr)
        setattr(layer, attr, Parameter(loaded.data, requires_grad=False))
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
...
...
@@ -56,79 +64,65 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
pack_factor
=
32
//
self
.
quant_type
.
size_bits
output_size_per_partition
=
sum
(
output_partition_sizes
)
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
tile_size
//
2
,
output_size_per_partition
*
self
.
tile_size
//
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
pack_factor
,
"marlin_tile_size"
:
self
.
tile_size
,
"weight_loader"
:
weight_loader
},
)
layer
.
register_parameter
(
"weight_packed"
,
qweight
)
qweight
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
tile_size
//
2
,
output_size_per_partition
*
self
.
tile_size
//
pack_factor
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
pack_factor
,
marlin_tile_size
=
self
.
tile_size
,
weight_loader
=
weight_loader
)
input_groups
=
(
1
if
self
.
group_size
is
None
else
input_size_per_partition
//
self
.
group_size
)
scales
=
Parameter
(
weight_scale_args
=
{
"data"
:
torch
.
empty
(
input_groups
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"output_dim"
:
1
,
"input_dim"
:
None
if
input_groups
==
1
else
0
,
"weight_loader"
:
weight_loader
},
)
layer
.
register_parameter
(
"scale_packed"
,
scales
)
weight_shape
=
Parameter
(
torch
.
empty
(
2
,
dtype
=
torch
.
int64
),
requires_grad
=
False
)
"weight_loader"
:
weight_loader
}
if
self
.
group_size
is
not
None
:
scales
=
GroupQuantScaleParameter
(
output_dim
=
1
,
input_dim
=
0
,
**
weight_scale_args
)
else
:
scales
=
ChannelQuantScaleParameter
(
output_dim
=
1
,
**
weight_scale_args
)
weight_shape
=
BasevLLMParameter
(
data
=
torch
.
empty
(
2
,
dtype
=
torch
.
int64
),
weight_loader
=
weight_loader
)
meta
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
8
//
2
//
2
,
output_size_per_partition
*
2
,
dtype
=
torch
.
int16
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
1
,
marlin_tile_size
=
2
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_packed"
,
qweight
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
set_weight_attrs
(
weight_shape
,
{
"weight_loader"
:
weight_loader
})
meta
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
8
//
2
//
2
,
output_size_per_partition
*
2
,
dtype
=
torch
.
int16
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
meta
,
{
"input_dim"
:
0
,
"packed_dim"
:
1
,
"pack_factor"
:
1
,
"output_dim"
:
1
,
"marlin_tile_size"
:
2
,
"weight_loader"
:
weight_loader
},
)
layer
.
register_parameter
(
"scale_packed"
,
scales
)
layer
.
register_parameter
(
"meta"
,
meta
)
max_workspace_size
=
(
output_size_per_partition
//
GPTQ_MARLIN_24_MIN_THREAD_N
)
*
GPTQ_MARLIN_24_MAX_PARALLEL
workspace
=
Parameter
(
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
),
requires_grad
=
False
)
layer
.
workspace
=
workspace
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
View file @
af7f4372
...
...
@@ -9,9 +9,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
convert_to_channelwise
,
create_per_channel_scale_param
,
create_per_tensor_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
convert_to_channelwise
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
__all__
=
[
"CompressedTensorsW8A16Fp8"
]
...
...
@@ -40,11 +41,19 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer
.
logical_widths
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
ws_channelwise
,
requires_grad
=
False
)
else
:
# required by torch.compile to be torch.nn.Parameter
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
# Weights must be transposed for marlin
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
if
self
.
is_static_input_scheme
:
# required by torch.compile to be torch.nn.Parameter
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
prepare_fp8_layer_for_marlin
(
layer
,
strategy
=
"channel"
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
...
...
@@ -60,35 +69,39 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer
.
orig_dtype
=
params_dtype
# WEIGHT
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
requires_grad
=
False
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
"weight_loader"
:
weight_loader
,
})
# WEIGHT SCALE
layer_kwargs
=
{
"weight_loader"
:
weight_loader
}
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
elif
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
weight_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
else
:
raise
ValueError
(
f
"Unsupported weight strategy=
{
self
.
strategy
}
, "
f
"supported strategies are
{
SUPPORTED_STRATEGIES
}
"
)
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE (to deal with converted checkpoints)
if
self
.
is_static_input_scheme
:
input_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
def
apply_weights
(
self
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
af7f4372
...
...
@@ -8,10 +8,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
c
reate_per_channel_scale_param
,
create_per_tensor_s
cale
_p
aram
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
apply_fp8_linear
,
c
utlass_fp8_supported
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ChannelQuantS
cale
P
aram
eter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
...
...
@@ -46,6 +46,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
# required by torch.compile to be torch.nn.Parameter
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
else
:
raise
ValueError
(
f
"Unknown quantization strategy
{
self
.
strategy
}
"
)
...
...
@@ -66,32 +69,40 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer
.
logical_widths
=
output_partition_sizes
# WEIGHT
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
requires_grad
=
False
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
"weight_loader"
:
weight_loader
,
})
# WEIGHT SCALE
layer_kwargs
=
{
"weight_loader"
:
weight_loader
}
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
else
:
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
weight_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
# min requirement for fp8 kernels
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
input_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
input_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
def
apply_weights
(
self
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
af7f4372
...
...
@@ -8,9 +8,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_int8_linear
,
convert_to_channelwise
,
create_per_channel_scale_param
,
create_per_tensor_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
apply_int8_linear
,
convert_to_channelwise
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
class
CompressedTensorsW8A8Int8
(
CompressedTensorsScheme
):
...
...
@@ -39,7 +41,9 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
ws_channelwise
=
convert_to_channelwise
(
layer
.
weight_scale
,
self
.
logical_widths
)
layer
.
weight_scale
=
Parameter
(
ws_channelwise
,
requires_grad
=
False
)
else
:
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
...
...
@@ -55,32 +59,35 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self
.
logical_widths
=
output_partition_sizes
# WEIGHT
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
torch
.
int8
),
requires_grad
=
False
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
torch
.
int8
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
"weight_loader"
:
weight_loader
,
})
# WEIGHT SCALE
layer_kwargs
=
{
"weight_loader"
:
weight_loader
}
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
else
:
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
weight_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
input_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
af7f4372
from
typing
import
Callable
,
List
,
Optional
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
...
...
@@ -10,7 +9,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
__all__
=
[
"CompressedTensorsWNA16"
]
...
...
@@ -30,17 +32,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self
.
pack_factor
=
32
//
num_bits
self
.
strategy
=
strategy
self
.
group_size
=
-
1
if
group_size
is
None
else
group_size
self
.
group_size
:
int
if
group_size
is
None
:
if
self
.
strategy
!=
"channel"
:
raise
ValueError
(
"Marlin kernels require group quantization or "
"channelwise quantization, but found no group "
"size and strategy is not channelwise."
)
self
.
group_size
=
-
1
else
:
self
.
group_size
=
group_size
if
self
.
group_size
==
-
1
and
self
.
strategy
!=
"channel"
:
raise
ValueError
(
"Marlin kernels require group quantization or "
"channelwise quantization, but found no group "
"size and strategy is not channelwise."
)
if
num_bits
not
in
WNA16_SUPPORTED_TYPES_MAP
:
raise
ValueError
(
...
...
@@ -63,11 +60,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
output_size_per_partition
=
sum
(
output_partition_sizes
)
# If group_size is -1, we are in channelwise case.
channelwise
=
(
self
.
group_size
==
-
1
)
group_size
=
input_size
if
channelwise
else
self
.
group
_size
group_size
=
self
.
group_size
if
self
.
group_size
!=
-
1
else
input
_size
row_parallel
=
(
input_size
!=
input_size_per_partition
)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
...
...
@@ -79,60 +77,51 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
input_size
=
input_size
,
group_size
=
group_size
)
weight_scale_dim
=
None
scales_and_zp_size
=
input_size
//
group_size
if
partition_scales
:
assert
input_size_per_partition
%
group_size
==
0
weight_scale_dim
=
1
scales_and_zp_size
=
input_size_per_partition
//
group_size
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
//
self
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
pack_factor
,
"weight_loader"
:
weight_loader
})
layer
.
register_parameter
(
"weight_packed"
,
weight
)
weight_scale
=
Parameter
(
weight
=
PackedvLLMParameter
(
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
packed_factor
=
self
.
pack_factor
,
packed_dim
=
1
,
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
//
self
.
pack_factor
,
dtype
=
torch
.
int32
,
))
weight_scale_args
=
{
"weight_loader"
:
weight_loader
,
"data"
:
torch
.
empty
(
output_size_per_partition
,
scales_and_zp_size
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
weight_scale
,
{
"weight_loader"
:
weight_loader
,
"input_dim"
:
weight_scale_dim
,
"output_dim"
:
0
})
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
)
}
if
not
partition_scales
:
weight_scale
=
ChannelQuantScaleParameter
(
output_dim
=
0
,
**
weight_scale_args
)
else
:
weight_scale
=
GroupQuantScaleParameter
(
output_dim
=
0
,
input_dim
=
1
,
**
weight_scale_args
)
# A 2D array defining the original shape of the weights
# before packing
weight_shape
=
Parameter
(
torch
.
empty
(
2
,
dtype
=
torch
.
int64
),
requires_grad
=
False
)
weight_shape
=
BasevLLMParameter
(
data
=
torch
.
empty
(
2
,
dtype
=
torch
.
int64
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_packed"
,
weight
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
set_weight_attrs
(
weight_shape
,
{
"weight_loader"
:
weight_loader
,
"ignore_warning"
:
True
,
})
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
...
...
@@ -154,10 +143,15 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
# No zero-point
layer
.
weight_zp
=
marlin_make_empty_g_idx
(
device
)
# Update for kernel
layer
.
weight_packed
=
torch
.
nn
.
Parameter
(
layer
.
weight_packed
.
t
().
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
squeeze
().
t
().
contiguous
(),
requires_grad
=
False
)
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
weight_packed
.
t
().
contiguous
()
,
layer
.
weight_packed
,
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
...
...
@@ -166,7 +160,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
# Permute scales from compressed-tensors format to marlin format.
marlin_scales
=
marlin_permute_scales
(
layer
.
weight_scale
.
squeeze
().
t
().
contiguous
()
,
layer
.
weight_scale
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
group_size
=
layer
.
group_size
)
...
...
vllm/model_executor/layers/quantization/experts_int8.py
0 → 100644
View file @
af7f4372
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
class ExpertsInt8Config(QuantizationConfig):
    """Quantization config for the "experts_int8" method.

    The config has no options: the constructor stores nothing and
    ``from_config`` ignores the dictionary it is given.
    """

    def __init__(self) -> None:
        # Nothing to store; the method is parameter-free.
        pass

    @classmethod
    def get_name(cls) -> str:
        """Return the name under which this method is registered."""
        return "experts_int8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (80).

        NOTE(review): presumably compute capability 8.0 encoded as
        major * 10 + minor -- confirm against QuantizationConfig.
        """
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
        """Build a config instance; ``config`` is not read."""
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers get the unquantized method, FusedMoE layers get
        ``ExpertsInt8MoEMethod``; any other layer yields ``None``.
        ``prefix`` is not used.
        """
        # LinearBase is tested first, as in the original if/elif chain.
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        if isinstance(layer, FusedMoE):
            return ExpertsInt8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none)."""
        return []
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
    """FusedMoE method that quantizes expert weights to int8 while loading.

    ``create_weights`` registers int8 weight tensors plus float32 per-row
    scale tensors on the layer and wraps the layer's weight loader so that
    each incoming shard is quantized before the wrapped loader is called.
    """

    def __init__(self, quant_config: ExpertsInt8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register int8 weights and float32 scales on ``layer``.

        Args:
            layer: module that receives the parameters.
            num_experts: size of the leading dimension of every parameter.
            hidden_size: model hidden size.
            intermediate_size: expert intermediate size.
            params_dtype: not used; weights are int8, scales float32.
            **extra_weight_attrs: attributes attached to the two weight
                parameters. Must contain ``'weight_loader'``, which is
                replaced by its quantizing wrapper before being attached.
        """
        int8_dtype = torch.int8

        assert 'weight_loader' in extra_weight_attrs
        weight_loader = extra_weight_attrs['weight_loader']
        wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
            layer, weight_loader)
        extra_weight_attrs['weight_loader'] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=int8_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=int8_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # One scale per output row of each expert; these are filled in by
        # the quantizing loader, so no weight attributes are attached.
        w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
                                                   2 * intermediate_size,
                                                   dtype=torch.float32),
                                       requires_grad=False)
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
                                                  hidden_size,
                                                  dtype=torch.float32),
                                      requires_grad=False)
        layer.register_parameter("w2_scale", w2_scale)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True,
              use_grouped_topk: bool = False,
              num_expert_group: Optional[int] = None,
              topk_group: Optional[int] = None) -> torch.Tensor:
        """Route ``x`` through the experts using the int8 weights.

        Expert selection is delegated to ``FusedMoE.select_experts`` and the
        computation to ``fused_experts`` with ``use_int8_w8a16=True`` and
        the scales registered by ``create_weights``.

        NOTE(review): ``inplace=True`` is passed to ``fused_experts``;
        presumably ``x`` may be overwritten -- confirm against that helper.
        """
        # Imported here rather than at module level, as in the original.
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group)

        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_int8_w8a16=True,
                             w1_scale=layer.w13_scale,
                             w2_scale=layer.w2_scale)

    @staticmethod
    def quantizing_weight_loader(layer, weight_loader):
        """Wrap ``weight_loader`` so shards are quantized before loading.

        Args:
            layer: module holding ``w13_scale``, ``w2_scale`` and
                ``intermediate_size_per_partition``.
            weight_loader: loader to call once the shard is quantized.

        Returns:
            A loader with the same five-argument call signature.
        """

        def quantize_and_call_weight_loader(param: torch.nn.Parameter,
                                            loaded_weight: torch.Tensor,
                                            weight_name: str, shard_id: str,
                                            expert_id: int):
            """Quantize this rank's slice in place, store its scales, load.

            Raises:
                ValueError: if ``shard_id`` is not "w1", "w2" or "w3".
            """
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = layer.intermediate_size_per_partition
            # Slice of the intermediate dimension owned by this rank.
            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
            device = get_tp_group().device
            loaded_weight = loaded_weight.to(device)
            # w1, gate_proj case: Load into first shard of w13.
            if shard_id == "w1":
                scales = quantize_in_place_and_get_scales(
                    loaded_weight[shard, :])
                layer.w13_scale.data[expert_id,
                                     0:shard_size].copy_(scales[:, 0])
            # w3, up_proj case: Load into second shard of w13.
            elif shard_id == "w3":
                scales = quantize_in_place_and_get_scales(
                    loaded_weight[shard, :])
                layer.w13_scale.data[expert_id, shard_size:2 *
                                     shard_size].copy_(scales[:, 0])
            # w2, down_proj case: Load into only shard of w2.
            elif shard_id == "w2":
                scales = quantize_in_place_and_get_scales(
                    loaded_weight[:, shard])
                layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
            else:
                # Shard ids are the strings compared above, not integers.
                raise ValueError(
                    f"shard_id must be one of ['w1', 'w2', 'w3'] but got "
                    f"{shard_id}")
            # The full tensor is handed on; only this rank's slice of it
            # has been quantized.
            weight_loader(param, loaded_weight, weight_name, shard_id,
                          expert_id)

        return quantize_and_call_weight_loader
def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
    """Quantize ``weight`` row-wise to the int8 value range, in place.

    Each row (dim 1 is reduced) is divided by ``max(|row|) / 127``, rounded
    and clamped to ``[-127, 127]``. The tensor keeps its dtype; only its
    values change.

    Args:
        weight: 2-D floating-point tensor, modified in place.

    Returns:
        The per-row scales, with dim 1 kept at size 1. A row that is
        entirely zero gets scale 0 and is left as zeros.
    """
    vmax = torch.iinfo(torch.int8).max
    scales = torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax

    # An all-zero row has scale 0; dividing by it would turn the row into
    # NaN (0 / 0). Divide by 1 there instead, so the row stays zero. The
    # returned scales are left untouched.
    divisor = torch.where(scales == 0, torch.ones_like(scales), scales)

    weight.div_(divisor)
    weight.round_()
    weight.clamp_(-vmax, vmax)

    return scales
Prev
1
…
16
17
18
19
20
21
22
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment