Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6cd5e5b0
Unverified
Commit
6cd5e5b0
authored
Sep 09, 2024
by
Dipika Sikka
Committed by
GitHub
Sep 09, 2024
Browse files
[Misc] Fused MoE Marlin support for GPTQ (#8217)
parent
c7cb5c33
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
912 additions
and
204 deletions
+912
-204
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+12
-1
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+1
-1
csrc/moe/marlin_moe_ops.h
csrc/moe/marlin_moe_ops.h
+1
-1
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+0
-1
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+217
-4
tests/weight_loading/models-large.txt
tests/weight_loading/models-large.txt
+3
-0
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+0
-2
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+10
-4
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+219
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+20
-118
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+51
-24
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+28
-20
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+1
-1
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+297
-15
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+17
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+7
-4
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+13
-6
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+8
-0
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+7
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
6cd5e5b0
...
@@ -386,7 +386,18 @@ steps:
...
@@ -386,7 +386,18 @@ steps:
-
vllm/
-
vllm/
-
tests/weight_loading
-
tests/weight_loading
commands
:
commands
:
-
bash weight_loading/run_model_weight_loading_test.sh
-
bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
-
label
:
Weight Loading Multiple GPU Test - Large Models
# optional
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
2
gpu
:
a100
optional
:
true
source_file_dependencies
:
-
vllm/
-
tests/weight_loading
commands
:
-
bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
##### multi gpus test #####
##### multi gpus test #####
...
...
csrc/moe/marlin_moe_ops.cu
View file @
6cd5e5b0
...
@@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe(
...
@@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe(
moe_block_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
moe_block_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
return
c
;
return
c
;
}
}
\ No newline at end of file
csrc/moe/marlin_moe_ops.h
View file @
6cd5e5b0
...
@@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe(
...
@@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe(
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
);
bool
replicate_input
,
bool
apply_weights
);
\ No newline at end of file
csrc/moe/torch_bindings.cpp
View file @
6cd5e5b0
...
@@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor"
);
"bool replicate_input, bool apply_weights) -> Tensor"
);
m
.
impl
(
"marlin_gemm_moe"
,
torch
::
kCUDA
,
&
marlin_gemm_moe
);
m
.
impl
(
"marlin_gemm_moe"
,
torch
::
kCUDA
,
&
marlin_gemm_moe
);
#endif
#endif
}
}
...
...
tests/kernels/test_moe.py
View file @
6cd5e5b0
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
Run `pytest tests/kernels/test_moe.py`.
Run `pytest tests/kernels/test_moe.py`.
"""
"""
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
from
transformers
import
MixtralConfig
from
transformers
import
MixtralConfig
...
@@ -9,7 +11,13 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
...
@@ -9,7 +11,13 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
single_marlin_moe
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
marlin_quantize
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.scalar_type
import
scalar_types
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
...
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
...
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
def torch_moe_single(a, w, score, topk):
    """Pure-PyTorch reference for a single expert matmul with top-k routing.

    Every token in ``a`` is multiplied by the weight of each of its ``topk``
    selected experts and the ``topk`` products are summed. The routing
    weights themselves are NOT applied; only the expert selection is used.

    Parameters:
    - a: 2-D input, unpacked as (tokens, hidden).
    - w: stacked expert weights; ``w[i]`` is transposed before the matmul and
        ``w.shape[1]`` is the output width.
    - score: gating logits; softmax is taken over the last dimension.
    - topk: number of experts selected per token.

    Returns:
    - Tensor of shape (tokens, w.shape[1]).
    """
    num_tokens, hidden = a.shape
    num_experts, out_features = w.shape[0], w.shape[1]

    # Replicate each token once per selected expert: (tokens * topk, hidden).
    expanded = a.view(num_tokens, -1, hidden)
    expanded = expanded.repeat(1, topk, 1).reshape(-1, hidden)

    result = torch.zeros(num_tokens * topk,
                         out_features,
                         dtype=a.dtype,
                         device=a.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    expert_ids = torch.topk(probs, topk)[1].view(-1)

    for expert in range(num_experts):
        selected = expert_ids == expert
        if not selected.sum():
            # No token was routed to this expert.
            continue
        result[selected] = expanded[selected] @ w[expert].transpose(0, 1)

    # Fold the per-expert rows back to one row per token.
    return result.view(num_tokens, -1, out_features).sum(dim=1)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1024
*
128
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1024
*
128
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
...
@@ -43,11 +65,11 @@ def test_fused_moe(
...
@@ -43,11 +65,11 @@ def test_fused_moe(
topk
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
):
):
a
=
torch
.
randn
((
m
,
k
),
device
=
'
cuda
'
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"
cuda
"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
'
cuda
'
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"
cuda
"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
'
cuda
'
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"
cuda
"
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
'
cuda
'
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"
cuda
"
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
...
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
...
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states
,
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
atol
=
mixtral_moe_tol
[
dtype
])
def stack_and_dev(tensors: List[torch.Tensor]):
    """Stack ``tensors`` along a new leading dimension.

    The stacked result is moved to the device of the first tensor in the
    list, so the list must not be empty.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)
def compute_max_diff(output, output_ref):
    """Return the mean absolute error relative to the reference magnitude.

    Despite the name, this is not a maximum: it is
    ``mean(|output - output_ref|) / mean(|output_ref|)``, returned as a
    0-dim tensor.
    """
    mean_abs_error = torch.mean(torch.abs(output - output_ref))
    mean_abs_ref = torch.mean(torch.abs(output_ref))
    return mean_abs_error / mean_abs_ref
def _quantize_experts(weights, perm_size, quant_type, group_size, act_order):
    """Marlin-quantize every expert in ``weights`` and stack the results.

    Each expert weight is transposed, then passed to ``marlin_quantize``
    together with a fresh random permutation of length ``perm_size``.

    Returns (w_ref, qweight, scales, g_idx, sort_indices), each stacked over
    the expert dimension; ``qweight`` is made contiguous.
    """
    per_expert = []
    for expert_weight in weights:
        test_perm = torch.randperm(perm_size)
        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
            expert_weight.transpose(1, 0), quant_type, group_size, act_order,
            test_perm)
        per_expert.append((w_ref, qweight, scales, g_idx, sort_indices))

    w_ref, qweight, scales, g_idx, sort_indices = (
        stack_and_dev(list(column)) for column in zip(*per_expert))
    return w_ref, qweight.contiguous(), scales, g_idx, sort_indices


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
):
    """Compare ``fused_marlin_moe`` against ``fused_moe`` run on the
    dequantized reference weights; the relative mean error must stay
    below 4e-2."""
    torch.manual_seed(7)

    # Cannot select more experts than exist.
    if topk > e:
        return

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size in (k, n):
            return

    quant_type = scalar_types.uint4b8
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    # Only expert 0 of w2 is replaced by an identity-like matrix. The
    # original code performed this same assignment once per expert inside a
    # loop whose index was never used; a single assignment is equivalent.
    # NOTE(review): possibly `w2[i]` was intended for every expert - confirm.
    w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

    w_ref1, qweight1, scales1, g_idx1, sort_indices1 = _quantize_experts(
        w1, k, quant_type, group_size, act_order)
    w_ref2, qweight2, scales2, g_idx2, sort_indices2 = _quantize_experts(
        w2, n, quant_type, group_size, act_order)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    triton_output = fused_moe(
        a,
        w_ref1.transpose(1, 2).contiguous(),
        w_ref2.transpose(1, 2).contiguous(),
        score,
        topk,
        renormalize=False,
    )
    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        score,
        g_idx1,
        g_idx2,
        sort_indices1,
        sort_indices2,
        topk_weights,
        topk_ids,
        w1_scale=scales1,
        w2_scale=scales2,
    )

    assert compute_max_diff(marlin_output, triton_output) < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
):
    """Debug-only check of ``single_marlin_moe`` against the pure-PyTorch
    ``torch_moe_single`` reference; relative mean error must be below 1e-2."""
    # Cannot select more experts than exist.
    if topk > e:
        return

    # Filter act_order
    if act_order and group_size in (-1, k):
        return

    quant_type = scalar_types.uint4b8
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    # Quantize each expert with its own random permutation; the trailing
    # sixth return value of marlin_quantize is not needed here.
    quantized = []
    for expert_idx in range(w.shape[0]):
        test_perm = torch.randperm(k)
        quantized.append(
            marlin_quantize(w[expert_idx].transpose(1, 0), quant_type,
                            group_size, act_order, test_perm)[:5])
    w_ref_l, qweights_l, scales_l, g_idx_l, sort_indices_l = map(
        list, zip(*quantized))

    w_ref = stack_and_dev(w_ref_l)
    qweight = stack_and_dev(qweights_l).contiguous()
    scales = stack_and_dev(scales_l)
    g_idx = stack_and_dev(g_idx_l)
    sort_indices = stack_and_dev(sort_indices_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    marlin_output = single_marlin_moe(a,
                                      qweight,
                                      scales,
                                      score,
                                      g_idx,
                                      sort_indices,
                                      topk,
                                      renormalize=False)
    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2
tests/weight_loading/models-large.txt
0 → 100644
View file @
6cd5e5b0
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
\ No newline at end of file
tests/weight_loading/models.txt
View file @
6cd5e5b0
...
@@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
...
@@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
6cd5e5b0
...
@@ -2,16 +2,22 @@ from vllm.model_executor.layers.fused_moe.layer import (
...
@@ -2,16 +2,22 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
"FusedMoeWeightScaleSupported"
]
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
"FusedMoeWeightScaleSupported"
,
]
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
single_marlin_moe
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_
marlin_
moe
,
fused_
moe
,
fused_topk
,
fused_experts
,
fused_moe
,
fused_
topk
,
get_config_file_name
,
get_config_file_name
,
grouped_topk
)
grouped_topk
)
__all__
+=
[
__all__
+=
[
"fused_marlin_moe"
,
"fused_marlin_moe"
,
"single_marlin_moe"
,
"fused_moe"
,
"fused_moe"
,
"fused_topk"
,
"fused_topk"
,
"fused_experts"
,
"fused_experts"
,
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
0 → 100644
View file @
6cd5e5b0
"""Fused MoE utilities for GPTQ."""
import
functools
from
typing
import
Any
,
Dict
,
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
def single_marlin_moe(
        hidden_states: torch.Tensor,
        w: torch.Tensor,
        scales: torch.Tensor,
        gating_output: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        topk: int,
        renormalize: bool,
        override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
    """
    This function computes the multiplication of hidden_states with expert
    weights used in Marlin MoE, using weights w and top-k gating mechanism.
    Its purpose is testing and debugging the fused MoE kernel.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
        Must be contiguous float16 with two dimensions.
    - w (torch.Tensor): The set of expert weights. Must be contiguous, with
        w.shape[0] == number of experts and
        w.shape[1] * 16 == hidden_states.shape[1].
    - scales (torch.Tensor): The quantization scales.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - g_idx (torch.Tensor): The act_order indices.
    - perm (torch.Tensor): The act_order input permutation.
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.

    Returns:
    - torch.Tensor: The kernel output summed over dimension 1
        (presumably the top-k dimension - confirm against the kernel).
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
    assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w.is_contiguous(), "Expert weights must be contiguous"
    assert hidden_states.dtype == torch.float16

    M, K = hidden_states.shape
    E = w.shape[0]
    N = w.shape[2] // 2

    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                        renormalize)

    # This might not be an optimal config for a single MMM
    config = try_get_optimal_moe_config(w.shape,
                                        w.shape,
                                        topk_ids.shape[1],
                                        None,
                                        M,
                                        override_config=override_config,
                                        is_marlin=True)

    block_size_m = config['BLOCK_SIZE_M']

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    # Allocate the workspace on the same device as the input rather than on
    # whichever CUDA device happens to be current.
    max_workspace_size = (N // 64) * 16
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device=hidden_states.device,
                            requires_grad=False)

    intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states,
        w,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        scales,
        g_idx,
        perm,
        workspace,
        M,  # size_m
        N,  # size_n
        K,  # size_k
        True,  # is_k_full
        E,  # num_experts
        topk,
        block_size_m,  # moe_block_size
        True,  # replicate_input
        False,  # apply_weights
    )

    return torch.sum(intermediate_cache, dim=1)
def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    g_idx1: torch.Tensor,
    g_idx2: torch.Tensor,
    perm1: torch.Tensor,
    perm2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    override_config: Optional[Dict[str, Any]] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Must be contiguous float16 with two dimensions.
    - w1 (torch.Tensor): The first set of expert weights (contiguous).
    - w2 (torch.Tensor): The second set of expert weights (contiguous).
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax). Only used here for shape validation; routing is
        taken from topk_weights / topk_ids.
    - g_idx1 (torch.Tensor): The first set of act_order indices.
    - g_idx2 (torch.Tensor): The second set of act_order indices.
    - perm1 (torch.Tensor): The first act_order input permutation.
    - perm2 (torch.Tensor): The second act_order input permutation.
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of top-k elements; topk_ids.shape[1]
        is used as the number of experts per token.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[
        0], "Number of tokens mismatch"
    assert hidden_states.shape[
        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // 2, "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype == torch.float16

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]

    config = try_get_optimal_moe_config(
        w1.shape,
        w2.shape,
        topk,
        None,
        M,
        override_config=override_config,
        is_marlin=True,
    )

    block_size_m = config["BLOCK_SIZE_M"]

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    # Allocate the workspace on the input's device (as intermediate_cache2
    # below already does) instead of the current default CUDA device.
    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device=hidden_states.device,
                            requires_grad=False)

    # Output buffer of the activation between the two expert GEMMs.
    intermediate_cache2 = torch.empty(
        (M * topk, N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # First GEMM: output width 2 * N, input replicated, routing weights not
    # applied.
    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states,
        w1,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w1_scale,
        g_idx1,
        perm1,
        workspace,
        M,  # size_m
        2 * N,  # size_n
        K,  # size_k
        True,  # is_k_full
        E,  # num_experts
        topk,
        block_size_m,  # moe_block_size
        True,  # replicate_input
        False,  # apply_weights
    )

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    # Second GEMM: output width K, input not replicated, routing weights
    # applied.
    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
        intermediate_cache2,
        w2,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w2_scale,
        g_idx2,
        perm2,
        workspace,
        M,  # size_m
        K,  # size_n
        N,  # size_k
        True,  # is_k_full
        E,  # num_experts
        topk,
        block_size_m,  # moe_block_size
        False,  # replicate_input
        True,  # apply_weights
    )

    return torch.sum(intermediate_cache3, dim=1)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
6cd5e5b0
...
@@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int,
...
@@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int,
return
None
return
None
def
get_default_config
(
M
:
int
,
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
def
get_default_config
(
dtype
:
Optional
[
str
],
M
:
int
,
is_marlin
:
bool
)
->
Dict
[
str
,
int
]:
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
dtype
:
Optional
[
str
],
is_marlin
:
bool
,
)
->
Dict
[
str
,
int
]:
config
=
{
config
=
{
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
8
'GROUP_SIZE_M'
:
8
}
}
# A heuristic: fused marlin works faster with this config for small M
if
M
<=
E
or
(
is_marlin
and
M
<=
32
):
if
M
<=
E
or
(
is_marlin
and
M
<=
32
):
config
=
{
config
=
{
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_M'
:
16
,
...
@@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
...
@@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
return
config
return
config
def
try_get_optimal_moe_config
(
w1_shape
:
Tuple
[
int
,
...],
def
try_get_optimal_moe_config
(
w2_shape
:
Tuple
[
int
,
...],
w1_shape
:
Tuple
[
int
,
...],
top_k
:
int
,
w2_shape
:
Tuple
[
int
,
...],
dtype
:
Optional
[
str
],
top_k
:
int
,
M
:
int
,
dtype
:
Optional
[
str
],
override_config
:
Optional
[
Dict
[
str
,
M
:
int
,
Any
]]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_marlin
:
bool
=
False
):
is_marlin
:
bool
=
False
,
):
if
override_config
:
if
override_config
:
config
=
override_config
config
=
override_config
else
:
else
:
...
@@ -391,6 +399,7 @@ def fused_topk(
...
@@ -391,6 +399,7 @@ def fused_topk(
topk
,
topk
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
device
=
hidden_states
.
device
)
ops
.
topk_softmax
(
ops
.
topk_softmax
(
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
...
@@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor,
...
@@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor,
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
def fused_marlin_moe(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w2: torch.Tensor,
                     gating_output: torch.Tensor,
                     g_idx1: torch.Tensor,
                     g_idx2: torch.Tensor,
                     rand_perm1: torch.Tensor,
                     rand_perm2: torch.Tensor,
                     topk: int,
                     custom_routing_function: Optional[Callable] = None,
                     renormalize: bool = True,
                     override_config: Optional[Dict[str, Any]] = None,
                     use_fp8: bool = False,
                     w1_scale: Optional[torch.Tensor] = None,
                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Must be contiguous, two-dimensional, float32/float16/bfloat16.
    - w1 (torch.Tensor): The first set of expert weights (contiguous).
    - w2 (torch.Tensor): The second set of expert weights (contiguous).
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - g_idx1 (torch.Tensor): Passed to the first GEMM as its g_idx argument.
    - g_idx2 (torch.Tensor): Passed to the second GEMM as its g_idx argument.
    - rand_perm1 (torch.Tensor): Passed to the first GEMM as its perm
        argument.
    - rand_perm2 (torch.Tensor): Passed to the second GEMM as its perm
        argument.
    - topk (int): The number of top-k experts to select.
    - custom_routing_function (Optional[Callable]): If given, called as
        f(hidden_states, gating_output, topk, renormalize) in place of
        fused_topk; must return (topk_weights, topk_ids).
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - use_fp8 (bool): Must be False; fp8 is not implemented yet.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    assert hidden_states.shape[
        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[
        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]

    #TODO fp8 is not implemented yet
    assert not use_fp8

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16

    if custom_routing_function is None:
        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                            renormalize)
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize)

    config = try_get_optimal_moe_config(w1.shape,
                                        w2.shape,
                                        topk_ids.shape[1],
                                        "float8" if use_fp8 else None,
                                        M,
                                        override_config=override_config,
                                        is_marlin=True)

    block_size_m = config['BLOCK_SIZE_M']

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    # Allocate the workspace on the input's device (as intermediate_cache2
    # below already does) instead of the current default CUDA device.
    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device=hidden_states.device,
                            requires_grad=False)

    # Output buffer of the activation between the two expert GEMMs.
    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    # First GEMM; the two trailing booleans are passed positionally, as in
    # the original call.
    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
        g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
        block_size_m, True, False)

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    # Second GEMM; the trailing booleans are swapped relative to the first.
    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
        intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
        w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
        block_size_m, False, True)

    return torch.sum(intermediate_cache3, dim=1)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
6cd5e5b0
...
@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
...
@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
# Input scales can be loaded directly and should be equal.
# Input scales can be loaded directly and should be equal.
param_data
[
expert_id
]
=
loaded_weight
param_data
[
expert_id
]
=
loaded_weight
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                shard_dim: int, loaded_weight: torch.Tensor,
                tp_rank: int) -> None:
    """Load a ``g_idx`` tensor for one expert into ``expert_data``.

    For the "w2" shard the work is delegated to ``self._load_w2`` with the
    same arguments. For "w1" and "w3" the loaded tensor is copied into
    ``expert_data`` as-is; ``shard_dim`` and ``tp_rank`` are not used on
    that path.
    """
    if shard_id == "w2":
        self._load_w2(shard_id=shard_id,
                      shard_dim=shard_dim,
                      loaded_weight=loaded_weight,
                      expert_data=expert_data,
                      tp_rank=tp_rank)
    else:
        assert shard_id in ("w1", "w3")
        expert_data.copy_(loaded_weight)
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
shard_id
:
str
,
expert_id
:
int
)
->
None
:
# compressed-tensors represents weights on disk which are flipped
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
self
.
quant_method
.
__class__
.
__name__
==
"CompressedTensorsMoEMethod"
)
else
loaded_weight
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
f
"got
{
shard_id
}
."
)
f
"got
{
shard_id
}
."
)
...
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
...
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
expert_data
=
param
.
data
[
expert_id
]
expert_data
=
param
.
data
[
expert_id
]
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
# is_transposed:
whether or not the parameter is transposed on disk
# is_transposed:
if the dim to shard the weight
#
If transposed, the loaded weight will be transposed and the dim
#
should be flipped. Required by GPTQ, compressed-tensors
#
to shard the loaded weight will be flipped.
#
should be whatever dimension intermediate_size is
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
if
is_transposed
:
if
is_transposed
:
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
shard_dim
=
~
shard_dim
shard_dim
=
~
shard_dim
# Case weight_scales
# Case input scale: input_scale loading is only supported for fp8
if
"weight_scale"
in
weight_name
:
if
"input_scale"
in
weight_name
:
# load the weight scaling based on the quantization scheme
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
# supported weight scales can be found in
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param
.
data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case g_idx
if
"g_idx"
in
weight_name
:
self
.
_load_g_idx
(
shard_dim
=
0
,
shard_id
=
shard_id
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
return
# Case weight scales and zero_points
if
(
"scale"
in
weight_name
or
"zero"
in
weight_name
):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
# specific to each case
...
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
f
"quant method must be one of
{
WEIGHT_SCALE_SUPPORTED
}
"
)
f
"quant method must be one of
{
WEIGHT_SCALE_SUPPORTED
}
"
)
return
return
# Case weight_shape
if
"weight_shape"
in
weight_name
:
if
"weight_shape"
in
weight_name
:
self
.
_load_single_value
(
param
=
param
,
# only required by compressed-tensors
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case input scale
if
"input_scale"
in
weight_name
:
# Note: input_scale loading is only supported for fp8
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param
.
data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
self
.
_load_single_value
(
param
=
param
,
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
expert_id
=
expert_id
)
...
@@ -498,4 +525,4 @@ class FusedMoE(torch.nn.Module):
...
@@ -498,4 +525,4 @@ class FusedMoE(torch.nn.Module):
param_data
[
expert_id
][
idx
]
=
loaded_weight
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
# If we are in the row parallel case (down_proj)
else
:
else
:
param_data
[
expert_id
]
=
loaded_weight
param_data
[
expert_id
]
=
loaded_weight
\ No newline at end of file
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
6cd5e5b0
...
@@ -5,9 +5,7 @@ from typing import Callable, List, Optional
...
@@ -5,9 +5,7 @@ from typing import Callable, List, Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoEMethodBase
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
)
CompressionFormat
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if
not
(
self
.
quant_config
.
quant_format
if
not
(
self
.
quant_config
.
quant_format
==
CompressionFormat
.
pack_quantized
.
value
==
CompressionFormat
.
pack_quantized
.
value
and
self
.
num_bits
in
WNA16_SUPPORTED_BITS
):
and
self
.
num_bits
==
4
):
raise
ValueError
(
"For Fused MoE layers, only "
,
raise
ValueError
(
"For Fused MoE layers, only "
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
"is supported for the following bits: "
,
"is supported for 4 bits"
)
f
"
{
WNA16_SUPPORTED_BITS
}
"
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
...
@@ -269,19 +266,30 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -269,19 +266,30 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_
marlin_
moe
import
(
fused_marlin_moe
)
fused_marlin_moe
)
return
fused_marlin_moe
(
x
,
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
layer
.
w13_weight_packed
,
hidden_states
=
x
,
layer
.
w2_weight_packed
,
router_logits
=
router_logits
,
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
layer
.
w13_g_idx
,
top_k
=
top_k
,
layer
.
w2_g_idx
,
renormalize
=
renormalize
,
layer
.
w13_g_idx_sort_indices
,
topk_group
=
topk_group
,
layer
.
w2_g_idx_sort_indices
,
num_expert_group
=
num_expert_group
,
top_k
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
,
renormalize
=
renormalize
,
return
fused_marlin_moe
(
w1_scale
=
layer
.
w13_weight_scale
,
x
,
w2_scale
=
layer
.
w2_weight_scale
)
layer
.
w13_weight_packed
,
layer
.
w2_weight_packed
,
router_logits
,
layer
.
w13_g_idx
,
layer
.
w2_g_idx
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
topk_weights
,
topk_ids
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
6cd5e5b0
...
@@ -22,7 +22,7 @@ from vllm.scalar_type import scalar_types
...
@@ -22,7 +22,7 @@ from vllm.scalar_type import scalar_types
__all__
=
[
"CompressedTensorsWNA16"
]
__all__
=
[
"CompressedTensorsWNA16"
]
WNA16_SUPPORTED_TYPES_MAP
=
{
WNA16_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
uint4b8
,
4
:
scalar_types
.
uint4b8
,
8
:
scalar_types
.
uint8b128
,
8
:
scalar_types
.
uint8b128
}
}
WNA16_SUPPORTED_BITS
=
list
(
WNA16_SUPPORTED_TYPES_MAP
.
keys
())
WNA16_SUPPORTED_BITS
=
list
(
WNA16_SUPPORTED_TYPES_MAP
.
keys
())
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
6cd5e5b0
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
apply_gptq_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_moe_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
...
@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig):
(
8
,
True
):
scalar_types
.
uint8b128
,
(
8
,
True
):
scalar_types
.
uint8b128
,
}
}
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
def
__init__
(
is_sym
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
,
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
# (since we have only one group per output channel)
...
@@ -105,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -105,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference"
)
" faster inference"
)
return
None
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
prefix
:
str
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
if
(
isinstance
(
layer
,
LinearBase
)
or
)
->
Optional
[
Union
[
"GPTQMarlinLinearMethod"
,
"GPTQMarlinMoEMethod"
]]:
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
if
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
):
return
GPTQMarlinLinearMethod
(
self
)
return
GPTQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
GPTQMarlinMoEMethod
(
self
)
return
None
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
...
@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition
=
output_size_per_partition
,
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
input_size
=
input_size
,
group_size
=
group_size
)
group_size
=
group_size
,
)
# Determine sharding
# Determine sharding
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
...
@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm
=
layer
.
g_idx_sort_indices
,
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
)
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from autogptq format to marlin format.
# Permute scales from autogptq format to marlin format.
...
@@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
size_k
=
(
layer
.
input_size
if
self
.
quant_config
.
desc_act
else
size_k
=
(
layer
.
input_size
if
self
.
quant_config
.
desc_act
else
layer
.
input_size_per_partition
),
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
)
group_size
=
self
.
quant_config
.
group_size
,
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
def
apply
(
def
apply
(
...
@@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition
=
layer
.
output_size_per_partition
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
layer
.
is_k_full
,
is_k_full
=
layer
.
is_k_full
,
bias
=
bias
)
bias
=
bias
,
)
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
    """MoE Marlin method with quantization.

    Creates the GPTQ parameters of a fused MoE layer (packed weights,
    scales, zero points and ``g_idx`` tensors), repacks them for the Marlin
    MoE kernel after loading, and runs the layer through
    ``fused_marlin_moe``.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the (still empty) GPTQ parameters on ``layer``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden size used for the parameter shapes.
            intermediate_size: Intermediate size used for the shapes.
            params_dtype: Dtype used for the ``qzeros`` parameters.
            **extra_weight_attrs: Attributes attached to every parameter;
                ``quant_method`` and ``is_transposed`` are added here.
        """
        # Currently assuming is_k_full is always True
        # (input size per partition is the same as full input size)
        # Supports only sym for now (no zp)
        group_size = self.quant_config.group_size
        pack_factor = self.quant_config.pack_factor
        if group_size != -1:
            scales_size13 = hidden_size // group_size
            scales_size2 = intermediate_size // group_size
            strategy = FusedMoeWeightScaleSupported.GROUP.value
        else:
            scales_size13 = 1
            scales_size2 = 1
            strategy = FusedMoeWeightScaleSupported.CHANNEL.value

        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": True
        })

        def _register(name: str, *shape: int, dtype: torch.dtype) -> None:
            # Create an uninitialized frozen parameter, attach it to the
            # layer and tag it with the shared weight attributes.
            param = torch.nn.Parameter(
                torch.empty(*shape, dtype=dtype),
                requires_grad=False,
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # Fused gate_up_proj (column parallel)
        _register("w13_qweight",
                  num_experts,
                  hidden_size // pack_factor,
                  2 * intermediate_size,
                  dtype=torch.int32)
        # down_proj (row parallel)
        _register("w2_qweight",
                  num_experts,
                  intermediate_size // pack_factor,
                  hidden_size,
                  dtype=torch.int32)
        # up_proj scales
        _register("w13_scales",
                  num_experts,
                  scales_size13,
                  2 * intermediate_size,
                  dtype=torch.half)
        # down_proj scales
        _register("w2_scales",
                  num_experts,
                  scales_size2,
                  hidden_size,
                  dtype=torch.half)
        # up_proj zero points
        _register("w13_qzeros",
                  num_experts,
                  scales_size13,
                  2 * intermediate_size // pack_factor,
                  dtype=params_dtype)
        # down_proj zero points
        _register("w2_qzeros",
                  num_experts,
                  scales_size2,
                  hidden_size // pack_factor,
                  dtype=params_dtype)
        # Group indices and their sort permutations.
        _register("w13_g_idx", num_experts, hidden_size, dtype=torch.int32)
        _register("w2_g_idx",
                  num_experts,
                  intermediate_size,
                  dtype=torch.int32)
        _register("w13_g_idx_sort_indices",
                  num_experts,
                  hidden_size,
                  dtype=torch.int32)
        _register("w2_g_idx_sort_indices",
                  num_experts,
                  intermediate_size,
                  dtype=torch.int32)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Sort ``g_idx`` (act_order) and repack weights/scales for Marlin."""
        num_experts = layer.w13_g_idx.shape[0]

        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(
                    layer.w13_g_idx[e]).to(torch.int32)
                w2_g_idx_sort_indices[e] = torch.argsort(
                    layer.w2_g_idx[e]).to(torch.int32)
                w13_sorted_g_idx[e] = layer.w13_g_idx[e][
                    w13_g_idx_sort_indices[e]]
                w2_sorted_g_idx[e] = layer.w2_g_idx[e][
                    w2_g_idx_sort_indices[e]]
            replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
            replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
            replace_tensor(layer, "w13_g_idx_sort_indices",
                           w13_g_idx_sort_indices)
            replace_tensor(layer, "w2_g_idx_sort_indices",
                           w2_g_idx_sort_indices)
        else:
            # Reset g_idx related tensors to empty (num_experts, 0) tensors
            device = layer.w13_g_idx.device
            for name in ("w13_g_idx", "w2_g_idx", "w13_g_idx_sort_indices",
                         "w2_g_idx_sort_indices"):
                setattr(
                    layer,
                    name,
                    torch.nn.Parameter(
                        torch.empty((num_experts, 0),
                                    dtype=torch.int32,
                                    device=device),
                        requires_grad=False,
                    ),
                )

        # Repack weights
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_tensor(layer, "w2_qweight", marlin_w2_qweight)

        # Repack scales
        # NOTE(review): the size_k values below are kept as found. The scale
        # tensors were sized from hidden_size / group_size in create_weights,
        # so intermediate_size_per_partition (w13) and shape[1] * pack_factor
        # (w2) may not be the intended K sizes -- confirm against
        # marlin_permute_scales.
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_tensor(layer, "w13_scales", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_tensor(layer, "w2_scales", marlin_w2_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Route ``x`` to experts and run the fused Marlin MoE kernel.

        The input is cast to float16 for the kernel and the result is cast
        back to the dtype ``x`` had on entry.
        """
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            fused_marlin_moe)

        # The input must currently be float16
        orig_dtype = x.dtype
        x = x.half()

        # Forward the caller's routing function instead of discarding it,
        # matching the compressed-tensors MoE method.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            router_logits,
            layer.w13_g_idx,
            layer.w2_g_idx,
            layer.w13_g_idx_sort_indices,
            layer.w2_g_idx_sort_indices,
            topk_weights,
            topk_ids,
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
        ).to(orig_dtype)
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
6cd5e5b0
...
@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
...
@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return
s
return
s
def marlin_moe_permute_scales(
    s: torch.Tensor,
    size_k: int,
    size_n: int,
    group_size: int,
):
    """Apply ``marlin_permute_scales`` to every expert slice of ``s``.

    ``s`` is indexed by expert along dim 0; each ``s[e]`` is permuted
    independently and written to the matching slot of a freshly allocated
    tensor that has the first three dimensions, device and dtype of ``s``.
    """
    expert_count = s.shape[0]
    result_shape = (expert_count, s.shape[1], s.shape[2])
    permuted = torch.empty(result_shape, device=s.device, dtype=s.dtype)

    expert_idx = 0
    while expert_idx < expert_count:
        permuted[expert_idx] = marlin_permute_scales(s[expert_idx], size_k,
                                                     size_n, group_size)
        expert_idx += 1
    return permuted
def
marlin_zero_points
(
zp
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
def
marlin_zero_points
(
zp
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
num_bits
:
int
)
->
torch
.
Tensor
:
# Permute zero-points in a similar way to scales, but do not use the
# Permute zero-points in a similar way to scales, but do not use the
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
View file @
6cd5e5b0
"""Utility functions used for tests and benchmarks"""
"""Utility functions used for tests and benchmarks"""
from
typing
import
List
from
typing
import
List
,
Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
...
@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
return
perm
return
perm
def
marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
def
marlin_quantize
(
w
:
torch
.
Tensor
,
act_order
:
bool
):
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
,
test_perm
:
Optional
[
torch
.
Tensor
]
=
None
):
size_k
,
size_n
=
w
.
shape
size_k
,
size_n
=
w
.
shape
num_bits
=
quant_type
.
size_bits
num_bits
=
quant_type
.
size_bits
...
@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
...
@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
# Quantize (and apply act_order if provided)
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_quantize_weights
(
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_quantize_weights
(
w
,
quant_type
,
group_size
,
act_order
)
w
,
quant_type
,
group_size
,
act_order
,
test_perm
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
# increasing
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
6cd5e5b0
"""This file is used for /tests and /benchmarks"""
"""This file is used for /tests and /benchmarks"""
from
typing
import
List
from
typing
import
List
,
Optional
import
numpy
import
numpy
import
torch
import
torch
...
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
...
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
return
32
//
num_bits
return
32
//
num_bits
def
permute_rows
(
q_w
:
torch
.
Tensor
,
w_ref
:
torch
.
Tensor
,
group_size
:
int
):
def
permute_rows
(
q_w
:
torch
.
Tensor
,
w_ref
:
torch
.
Tensor
,
group_size
:
int
,
test_perm
:
Optional
[
torch
.
Tensor
]
=
None
):
assert
q_w
.
shape
==
w_ref
.
shape
assert
q_w
.
shape
==
w_ref
.
shape
orig_device
=
q_w
.
device
orig_device
=
q_w
.
device
...
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
...
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx
[
i
]
=
i
//
group_size
g_idx
[
i
]
=
i
//
group_size
# Simulate act_order by doing a random permutation on K
# Simulate act_order by doing a random permutation on K
rand_perm
=
torch
.
randperm
(
k_size
)
rand_perm
=
test_perm
if
test_perm
is
not
None
else
torch
.
randperm
(
k_size
)
g_idx
=
g_idx
[
rand_perm
].
contiguous
()
g_idx
=
g_idx
[
rand_perm
].
contiguous
()
q_w
=
q_w
[
rand_perm
,
:].
contiguous
()
q_w
=
q_w
[
rand_perm
,
:].
contiguous
()
...
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
...
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
)
)
def
gptq_quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
def
gptq_quantize_weights
(
w
:
torch
.
Tensor
,
group_size
:
int
,
act_order
:
bool
):
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
,
test_perm
:
Optional
[
torch
.
Tensor
]
=
None
):
size_k
,
_
=
w
.
shape
size_k
,
_
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
w
.
is_floating_point
(),
"w must be float"
...
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
...
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
group_size
,
size_k
)
group_size
,
size_k
)
w_ref
,
w_q
,
g_idx
,
rand_perm
=
permute_rows
(
w_q
,
w_ref
,
group_size
)
w_ref
,
w_q
,
g_idx
,
rand_perm
=
permute_rows
(
w_q
,
w_ref
,
group_size
,
test_perm
)
return
w_ref
,
w_q
,
w_s
,
g_idx
,
rand_perm
return
w_ref
,
w_q
,
w_s
,
g_idx
,
rand_perm
...
...
vllm/model_executor/model_loader/utils.py
View file @
6cd5e5b0
...
@@ -24,10 +24,18 @@ def get_model_architecture(
...
@@ -24,10 +24,18 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
]
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
]
# for gptq_marlin, only run fused MoE for int4
if
model_config
.
quantization
==
"gptq_marlin"
:
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
if
hf_quant_config
and
hf_quant_config
.
get
(
"bits"
)
==
4
:
mixtral_supported
.
append
(
"gptq_marlin"
)
if
(
model_config
.
quantization
is
not
None
if
(
model_config
.
quantization
is
not
None
and
model_config
.
quantization
not
in
mixtral_supported
and
model_config
.
quantization
not
in
mixtral_supported
and
"MixtralForCausalLM"
in
architectures
):
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
architectures
=
[
"QuantMixtralForCausalLM"
]
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
...
...
vllm/model_executor/models/mixtral.py
View file @
6cd5e5b0
...
@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
continue
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
...
@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight_loader
(
param
,
...
@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
continue
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment