Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aae74ef9
Unverified
Commit
aae74ef9
authored
Aug 21, 2024
by
Michael Goin
Committed by
GitHub
Aug 22, 2024
Browse files
Revert "[Kernel] Expand MoE weight loading + Add Fused Marlin MoE Kernel (#7527)" (#7764)
parent
cde9183b
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
84 additions
and
2374 deletions
+84
-2374
CMakeLists.txt
CMakeLists.txt
+1
-2
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+0
-1740
csrc/moe/marlin_moe_ops.h
csrc/moe/marlin_moe_ops.h
+0
-12
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+0
-9
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+0
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-14
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+8
-6
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+18
-116
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+36
-170
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+0
-5
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+0
-283
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+18
-11
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+2
-2
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+1
-1
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+0
-1
No files found.
CMakeLists.txt
View file @
aae74ef9
...
@@ -286,8 +286,7 @@ define_gpu_extension_target(
...
@@ -286,8 +286,7 @@ define_gpu_extension_target(
set
(
VLLM_MOE_EXT_SRC
set
(
VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu"
)
"csrc/moe/marlin_moe_ops.cu"
)
define_gpu_extension_target
(
define_gpu_extension_target
(
_moe_C
_moe_C
...
...
csrc/moe/marlin_moe_ops.cu
deleted
100644 → 0
View file @
cde9183b
This diff is collapsed.
Click to expand it.
csrc/moe/marlin_moe_ops.h
deleted
100644 → 0
View file @
cde9183b
#pragma once
#include <torch/all.h>
torch
::
Tensor
marlin_gemm_moe
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
);
\ No newline at end of file
csrc/moe/torch_bindings.cpp
View file @
aae74ef9
#include "core/registration.h"
#include "core/registration.h"
#include "moe_ops.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
// Apply topk softmax to the gating outputs.
// Apply topk softmax to the gating outputs.
...
@@ -8,14 +7,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -8,14 +7,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"
);
"token_expert_indices, Tensor gating_output) -> ()"
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
impl
(
"topk_softmax"
,
torch
::
kCUDA
,
&
topk_softmax
);
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor"
);
m
.
impl
(
"marlin_gemm_moe"
,
torch
::
kCUDA
,
&
marlin_gemm_moe
);
}
}
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
tests/weight_loading/models.txt
View file @
aae74ef9
...
@@ -13,7 +13,5 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
...
@@ -13,7 +13,5 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
awq, casperhansen/mixtral-instruct-awq, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
vllm/_custom_ops.py
View file @
aae74ef9
...
@@ -300,20 +300,6 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
...
@@ -300,20 +300,6 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
return
torch
.
ops
.
_C
.
awq_marlin_repack
(
b_q_weight
,
size_k
,
size_n
,
num_bits
)
return
torch
.
ops
.
_C
.
awq_marlin_repack
(
b_q_weight
,
size_k
,
size_n
,
num_bits
)
def
gptq_marlin_moe_repack
(
b_q_weight
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
num_experts
=
b_q_weight
.
shape
[
0
]
assert
size_k
%
16
==
0
output
=
torch
.
empty
((
num_experts
,
size_k
//
16
,
size_n
*
2
),
device
=
b_q_weight
.
device
,
dtype
=
b_q_weight
.
dtype
)
for
e
in
range
(
num_experts
):
output
[
e
]
=
torch
.
ops
.
_C
.
gptq_marlin_repack
(
b_q_weight
[
e
],
perm
[
e
],
size_k
,
size_n
,
num_bits
)
return
output
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
aae74ef9
from
vllm.model_executor.layers.fused_moe.layer
import
(
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoEMethodBase
)
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
"FusedMoeWeightScaleSupported"
]
__all__
=
[
"FusedMoE"
,
"FusedMoEMethodBase"
,
]
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_
marlin_
moe
,
fused_
moe
,
fused_topk
,
fused_experts
,
fused_moe
,
fused_
topk
,
get_config_file_name
,
get_config_file_name
,
grouped_topk
)
grouped_topk
)
__all__
+=
[
__all__
+=
[
"fused_marlin_moe"
,
"fused_moe"
,
"fused_moe"
,
"fused_topk"
,
"fused_topk"
,
"fused_experts"
,
"fused_experts"
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
aae74ef9
...
@@ -323,16 +323,21 @@ def get_moe_configs(E: int, N: int,
...
@@ -323,16 +323,21 @@ def get_moe_configs(E: int, N: int,
return
None
return
None
def
get_default_config
(
M
:
int
,
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
def
get_default_config
(
dtype
:
Optional
[
str
],
M
:
int
,
is_marlin
:
bool
)
->
Dict
[
str
,
int
]:
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
dtype
:
Optional
[
str
],
)
->
Dict
[
str
,
int
]:
config
=
{
config
=
{
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
8
'GROUP_SIZE_M'
:
8
}
}
if
M
<=
E
or
(
is_marlin
and
M
<=
32
)
:
if
M
<=
E
:
config
=
{
config
=
{
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_N'
:
32
,
'BLOCK_SIZE_N'
:
32
,
...
@@ -342,14 +347,14 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
...
@@ -342,14 +347,14 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
return
config
return
config
def
try_get_optimal_moe_config
(
w1_shape
:
Tuple
[
int
,
...],
def
try_get_optimal_moe_config
(
w
2
_shape
:
Tuple
[
int
,
...],
w
1
_shape
:
Tuple
[
int
,
...],
top_k
:
int
,
w2_shape
:
Tuple
[
int
,
...]
,
dtype
:
Optional
[
str
]
,
top_k
:
int
,
M
:
int
,
dtype
:
Optional
[
str
]
,
override_config
:
Optional
[
Dict
[
str
,
M
:
int
,
Any
]]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
is_marlin
:
bool
=
False
):
):
if
override_config
:
if
override_config
:
config
=
override_config
config
=
override_config
else
:
else
:
...
@@ -363,8 +368,7 @@ def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
...
@@ -363,8 +368,7 @@ def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
else
:
# Else use the default config
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
,
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
)
is_marlin
)
return
config
return
config
...
@@ -437,108 +441,6 @@ def grouped_topk(hidden_states: torch.Tensor,
...
@@ -437,108 +441,6 @@ def grouped_topk(hidden_states: torch.Tensor,
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
def
fused_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
g_idx1
:
torch
.
Tensor
,
g_idx2
:
torch
.
Tensor
,
rand_perm1
:
torch
.
Tensor
,
rand_perm2
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
1
]
*
16
,
"Hidden size mismatch w1"
assert
hidden_states
.
shape
[
1
]
==
w2
.
shape
[
2
]
//
2
,
"Hidden size mismatch w2"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
assert
w2
.
is_contiguous
(),
"Expert weights2 must be contiguous"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
#TODO fp8 is not implemented yet
assert
not
use_fp8
M
,
K
=
hidden_states
.
shape
E
=
w1
.
shape
[
0
]
N
=
w2
.
shape
[
1
]
*
16
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
None
,
override_config
=
override_config
,
is_marlin
=
True
)
config
=
get_config_func
(
M
)
block_size_m
=
config
[
'BLOCK_SIZE_M'
]
sorted_token_ids
,
_
,
_
=
moe_align_block_size
(
topk_ids
,
block_size_m
,
E
)
max_workspace_size
=
((
M
+
255
)
//
256
)
*
(
max
(
2
*
N
,
K
)
//
64
)
*
16
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
,
requires_grad
=
False
)
intermediate_cache2
=
torch
.
empty
((
M
*
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
intermediate_cache1
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
hidden_states
,
w1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
w1_scale
,
g_idx1
,
rand_perm1
,
workspace
,
M
,
2
*
N
,
K
,
True
,
E
,
topk
,
block_size_m
,
True
,
False
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
intermediate_cache3
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
intermediate_cache2
,
w2
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
w2_scale
,
g_idx2
,
rand_perm2
,
workspace
,
M
,
K
,
N
,
True
,
E
,
topk
,
block_size_m
,
False
,
True
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
aae74ef9
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
enum
import
Enum
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -16,12 +15,6 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -16,12 +15,6 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
FusedMoeWeightScaleSupported
(
Enum
):
TENSOR
=
"tensor"
CHANNEL
=
"channel"
GROUP
=
"group"
class
FusedMoEMethodBase
(
QuantizeMethodBase
):
class
FusedMoEMethodBase
(
QuantizeMethodBase
):
@
abstractmethod
@
abstractmethod
...
@@ -206,182 +199,55 @@ class FusedMoE(torch.nn.Module):
...
@@ -206,182 +199,55 @@ class FusedMoE(torch.nn.Module):
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
)
def
_load_per_tensor_weight_scale
(
self
,
shard_id
:
str
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_id
:
int
):
param_data
=
param
.
data
# for per tensor weight quantization
if
shard_id
in
(
"w1"
,
"w3"
):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx
=
0
if
shard_id
==
"w1"
else
1
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
elif
shard_id
==
"w2"
:
param_data
[
expert_id
]
=
loaded_weight
def
_load_model_weight_or_group_weight_scale
(
self
,
shard_dim
:
int
,
expert_data
:
torch
.
Tensor
,
shard_id
:
str
,
loaded_weight
:
torch
.
tensor
,
tp_rank
:
int
):
# Load grouped weight scales for group quantization
# or model weights
if
shard_id
==
"w2"
:
self
.
_load_w2
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
elif
shard_id
in
(
"w1"
,
"w3"
):
self
.
_load_w13
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
def
_load_per_channel_weight_scale
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
tensor
,
tp_rank
:
int
):
# for per channel weight quantization
if
shard_id
==
"w2"
:
expert_data
.
copy_
(
loaded_weight
)
elif
shard_id
in
(
"w1"
,
"w3"
):
self
.
_load_w13
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
def
_load_w13
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
tensor
,
tp_rank
:
int
):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size
=
expert_data
.
shape
[
shard_dim
]
//
2
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if
shard_id
==
"w1"
:
expert_data
=
expert_data
.
narrow
(
shard_dim
,
0
,
shard_size
)
# w3, up_proj: Load into second logical weight of w13.
else
:
assert
shard_id
==
"w3"
expert_data
=
expert_data
.
narrow
(
shard_dim
,
shard_size
,
shard_size
)
expert_data
.
copy_
(
loaded_weight
)
def
_load_w2
(
self
,
expert_data
:
torch
.
Tensor
,
shard_dim
:
int
,
shard_id
:
str
,
loaded_weight
:
torch
.
tensor
,
tp_rank
:
int
):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size
=
expert_data
.
shape
[
shard_dim
]
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
# w2, down_proj: Load into only logical weight of w2.
expert_data
.
copy_
(
loaded_weight
)
def
_load_single_value
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_id
:
int
):
param_data
=
param
.
data
# Input scales can be loaded directly and should be equal.
param_data
[
expert_id
]
=
loaded_weight
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
shard_id
:
str
,
expert_id
:
int
)
->
None
:
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
f
"got
{
shard_id
}
."
)
f
"got
{
shard_id
}
."
)
WEIGHT_SCALE_SUPPORTED
=
[
# Special case for fp8 scales.
e
.
value
for
e
in
FusedMoeWeightScaleSupported
if
getattr
(
param
,
"is_fp8_scale"
,
False
):
]
self
.
_load_fp8_scale
(
param
.
data
,
loaded_weight
,
weight_name
,
# Fetch the dim to shard the parameter/loaded weight
shard_id
,
expert_id
)
# based on the shard id. This will be whatever
return
# dimension intermediate_size is used.
SHARD_ID_TO_SHARDED_DIM
=
{
"w1"
:
0
,
"w2"
:
1
,
"w3"
:
0
}
expert_data
=
param
.
data
[
expert_id
]
expert_data
=
param
.
data
[
expert_id
]
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
# is_transposed: whether or not the parameter is transposed on disk
# If transposed, weight is saved as [input_dim, output_dim]
# If transposed, the loaded weight will be transposed and the dim
# Otherwise, weight is saved as [output_dim, input_dim]
# to shard the loaded weight will be flipped.
# Default is not transposed/input dim is dim 1
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
input_dim
=
getattr
(
param
,
"input_dim"
,
1
)
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
output_dim
=
getattr
(
param
,
"output_dim"
,
0
)
if
is_transposed
:
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
shard_dim
=
~
shard_dim
# Case weight_scales
if
"weight_scale"
in
weight_name
:
# load the weight scaling based on the quantization scheme
# supported weight scales can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method
=
getattr
(
param
,
"quant_method"
,
None
)
if
quant_method
==
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
:
self
.
_load_per_channel_weight_scale
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
elif
quant_method
==
FusedMoeWeightScaleSupported
.
GROUP
.
value
:
self
.
_load_model_weight_or_group_weight_scale
(
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
expert_data
=
expert_data
,
tp_rank
=
tp_rank
)
elif
quant_method
==
FusedMoeWeightScaleSupported
.
TENSOR
.
value
:
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
else
:
raise
ValueError
(
f
"quant method must be one of
{
WEIGHT_SCALE_SUPPORTED
}
"
)
return
if
"weight_shape"
in
weight_name
:
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case input scale
# Index the loaded weight for tp sharding.
if
"input_scale"
in
weight_name
:
# down_proj: "RowParallel" so tp sharding on input_dim
# Note: input_scale loading is only supported for fp8
if
shard_id
==
"w2"
:
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
shard_dim
=
input_dim
loaded_weight
).
abs
()
>
1e-5
:
shard_size
=
expert_data
.
shape
[
shard_dim
]
raise
ValueError
(
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
"input_scales of w1 and w3 of a layer "
elif
shard_id
in
(
"w1"
,
"w3"
):
f
"must be equal. But got
{
param
.
data
[
expert_id
]
}
"
shard_dim
=
output_dim
f
"vs.
{
loaded_weight
}
"
)
shard_size
=
expert_data
.
shape
[
output_dim
]
//
2
offset
=
shard_size
*
tp_rank
self
.
_load_single_value
(
param
=
param
,
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
offset
,
shard_size
)
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
return
# Case model weights
# Narrow parameter and load.
if
"weight"
in
weight_name
:
# w1, gate_proj: Load into first logical weight of w13.
self
.
_load_model_weight_or_group_weight_scale
(
if
shard_id
==
"w1"
:
shard_id
=
shard_id
,
expert_data
=
expert_data
.
narrow
(
shard_dim
,
0
,
shard_size
)
shard_dim
=
shard_dim
,
expert_data
.
copy_
(
loaded_weight
)
loaded_weight
=
loaded_weight
,
# w3, up_proj: Load into second logical weight of w13.
expert_data
=
expert_data
,
elif
shard_id
==
"w3"
:
tp_rank
=
tp_rank
)
expert_data
=
expert_data
.
narrow
(
shard_dim
,
shard_size
,
shard_size
)
return
expert_data
.
copy_
(
loaded_weight
)
# w2, down_proj: Load into only logical weight of w2.
elif
shard_id
==
"w2"
:
expert_data
.
copy_
(
loaded_weight
)
else
:
raise
ValueError
(
f
"Expected shard_id w1,w2 or w3 but got
{
shard_id
}
"
)
@
staticmethod
@
staticmethod
def
select_experts
(
hidden_states
:
torch
.
Tensor
,
def
select_experts
(
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
aae74ef9
...
@@ -3,12 +3,9 @@ from typing import Any, Dict, List, Optional
...
@@ -3,12 +3,9 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe
import
(
# noqa: E501
CompressedTensorsMoEMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsScheme
,
CompressedTensorsUnquantized
,
CompressedTensorsScheme
,
CompressedTensorsUnquantized
,
...
@@ -67,8 +64,6 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -67,8 +64,6 @@ class CompressedTensorsConfig(QuantizationConfig):
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
if
isinstance
(
layer
,
Attention
):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
return
CompressedTensorsMoEMethod
(
self
)
return
None
return
None
@
classmethod
@
classmethod
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
deleted
100644 → 0
View file @
cde9183b
import
enum
from
enum
import
Enum
from
typing
import
List
,
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoEMethodBase
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
GPTQMarlinState
(
Enum
):
REPACK
=
enum
.
auto
()
READY
=
enum
.
auto
()
__all__
=
[
"CompressedTensorsMoEMethod"
]
class
CompressedTensorsMoEMethod
(
FusedMoEMethodBase
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
):
self
.
quant_config
=
quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
config
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
self
.
num_bits
=
config
.
num_bits
self
.
packed_factor
=
32
//
config
.
num_bits
self
.
strategy
=
config
.
strategy
.
value
self
.
group_size
=
config
.
group_size
assert
config
.
symmetric
,
(
"Only symmetric quantization is supported for MoE"
)
if
not
(
self
.
quant_config
.
quant_format
==
CompressionFormat
.
pack_quantized
.
value
and
self
.
num_bits
in
WNA16_SUPPORTED_BITS
):
raise
ValueError
(
"For Fused MoE layers, only "
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
"is supported for the following bits: "
,
f
"
{
WNA16_SUPPORTED_BITS
}
"
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
extra_weight_attrs
.
update
({
"is_transposed"
:
True
,
"quant_method"
:
self
.
strategy
})
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
//
self
.
packed_factor
,
2
*
intermediate_size
,
dtype
=
torch
.
int32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_packed"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
intermediate_size
//
self
.
packed_factor
,
hidden_size
,
dtype
=
torch
.
int32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_packed"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
if
self
.
strategy
==
"channel"
:
num_groups_w2
=
num_groups_w13
=
1
self
.
group_size
=
-
1
else
:
num_groups_w2
=
intermediate_size
//
self
.
group_size
num_groups_w13
=
hidden_size
//
self
.
group_size
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
num_groups_w13
,
2
*
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_scale
)
set_weight_attrs
(
w13_scale
,
extra_weight_attrs
)
w2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
num_groups_w2
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_scale
)
set_weight_attrs
(
w2_scale
,
extra_weight_attrs
)
w2_weight_shape
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_shape"
,
w2_weight_shape
)
set_weight_attrs
(
w2_weight_shape
,
extra_weight_attrs
)
w13_weight_shape
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_shape"
,
w13_weight_shape
)
set_weight_attrs
(
w13_weight_shape
,
extra_weight_attrs
)
w13_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_g_idx"
,
w13_g_idx
)
set_weight_attrs
(
w13_g_idx
,
extra_weight_attrs
)
w2_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
intermediate_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_g_idx"
,
w2_g_idx
)
set_weight_attrs
(
w2_g_idx
,
extra_weight_attrs
)
w13_g_idx_sort_indices
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_g_idx_sort_indices"
,
w13_g_idx_sort_indices
)
set_weight_attrs
(
w13_g_idx_sort_indices
,
extra_weight_attrs
)
w2_g_idx_sort_indices
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
intermediate_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_g_idx_sort_indices"
,
w2_g_idx_sort_indices
)
set_weight_attrs
(
w2_g_idx_sort_indices
,
extra_weight_attrs
)
layer
.
a13_scale
=
None
layer
.
a2_scale
=
None
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
replace_tensor
(
name
,
new_t
):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
del
new_t
def
get_scale_perms
(
num_bits
:
int
):
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
:
List
[
int
]
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
scale_perm
,
scale_perm_single
def
marlin_permute_scales
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
group_size
:
int
,
num_bits
:
int
):
scale_perm
,
scale_perm_single
=
get_scale_perms
(
num_bits
)
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
def
marlin_moe_permute_scales
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
group_size
:
int
,
num_bits
:
int
):
num_experts
=
s
.
shape
[
0
]
output
=
torch
.
empty
((
num_experts
,
s
.
shape
[
1
],
s
.
shape
[
2
]),
device
=
s
.
device
,
dtype
=
s
.
dtype
)
for
e
in
range
(
num_experts
):
output
[
e
]
=
marlin_permute_scales
(
s
[
e
],
size_k
,
size_n
,
group_size
,
num_bits
)
return
output
size_k2
=
layer
.
w2_weight_packed
.
shape
[
2
]
size_k13
=
layer
.
w13_weight_packed
.
shape
[
2
]
num_experts
=
layer
.
w13_g_idx
.
shape
[
0
]
device
=
layer
.
w13_g_idx
.
device
layer
.
w13_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
0
),
dtype
=
torch
.
int32
,
device
=
device
),
requires_grad
=
False
,
)
layer
.
w2_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
0
),
dtype
=
torch
.
int32
,
device
=
device
),
requires_grad
=
False
,
)
layer
.
w13_g_idx_sort_indices
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
0
),
dtype
=
torch
.
int32
,
device
=
device
),
requires_grad
=
False
,
)
layer
.
w2_g_idx_sort_indices
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
0
),
dtype
=
torch
.
int32
,
device
=
device
),
requires_grad
=
False
,
)
marlin_w13_qweight
=
ops
.
gptq_marlin_moe_repack
(
layer
.
w13_weight_packed
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w13_weight_packed
.
shape
[
1
]
*
self
.
packed_factor
,
layer
.
w13_weight_packed
.
shape
[
2
],
self
.
num_bits
,
)
replace_tensor
(
"w13_weight_packed"
,
marlin_w13_qweight
)
marlin_w2_qweight
=
ops
.
gptq_marlin_moe_repack
(
layer
.
w2_weight_packed
,
layer
.
w2_g_idx_sort_indices
,
layer
.
w2_weight_packed
.
shape
[
1
]
*
self
.
packed_factor
,
layer
.
w2_weight_packed
.
shape
[
2
],
self
.
num_bits
,
)
replace_tensor
(
"w2_weight_packed"
,
marlin_w2_qweight
)
# Repack scales
marlin_w13_scales
=
marlin_moe_permute_scales
(
layer
.
w13_weight_scale
,
size_k13
,
layer
.
w13_weight_scale
.
shape
[
2
],
self
.
group_size
,
self
.
num_bits
,
)
replace_tensor
(
"w13_weight_scale"
,
marlin_w13_scales
)
marlin_w2_scales
=
marlin_moe_permute_scales
(
layer
.
w2_weight_scale
,
layer
.
w2_weight_scale
.
shape
[
1
]
*
self
.
packed_factor
,
size_k2
,
self
.
group_size
,
self
.
num_bits
,
)
replace_tensor
(
"w2_weight_scale"
,
marlin_w2_scales
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_marlin_moe
)
return
fused_marlin_moe
(
x
,
layer
.
w13_weight_packed
,
layer
.
w2_weight_packed
,
router_logits
,
layer
.
w13_g_idx
,
layer
.
w2_g_idx
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
top_k
,
renormalize
=
renormalize
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
)
vllm/model_executor/layers/quantization/fp8.py
View file @
aae74ef9
...
@@ -7,8 +7,7 @@ from torch.nn.parameter import Parameter
...
@@ -7,8 +7,7 @@ from torch.nn.parameter import Parameter
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
@@ -319,16 +318,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -319,16 +318,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
dtype
=
torch
.
float32
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
# process_weights_after_loading()
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight_scale
,
{
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
set_weight_attrs
(
w2_weight_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
# INPUT_SCALES
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
...
@@ -341,14 +343,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -341,14 +343,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_experts
,
dtype
=
torch
.
float32
),
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
set_weight_attrs
(
w13_input_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w13_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
set_weight_attrs
(
w2_input_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
else
:
else
:
layer
.
w13_input_scale
=
None
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
layer
.
w2_input_scale
=
None
...
...
vllm/model_executor/model_loader/utils.py
View file @
aae74ef9
...
@@ -23,11 +23,11 @@ def get_model_architecture(
...
@@ -23,11 +23,11 @@ def get_model_architecture(
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
# Special handling for quantized Mixtral.
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
]
if
(
model_config
.
quantization
is
not
None
if
(
model_config
.
quantization
is
not
None
and
model_config
.
quantization
not
in
mixtral_supported
and
model_config
.
quantization
!=
"fp8"
and
"MixtralForCausalLM"
in
architectures
):
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
architectures
=
[
"QuantMixtralForCausalLM"
]
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
...
...
vllm/model_executor/models/jamba.py
View file @
aae74ef9
...
@@ -920,7 +920,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
...
@@ -920,7 +920,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight_loader
(
param
,
loaded_weight
,
loaded_weight
,
name
,
weight_
name
,
shard_id
=
shard_id
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
expert_id
=
expert_id
)
break
break
...
...
vllm/model_executor/models/mixtral.py
View file @
aae74ef9
...
@@ -73,7 +73,6 @@ class MixtralMoE(nn.Module):
...
@@ -73,7 +73,6 @@ class MixtralMoE(nn.Module):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
# Gate always runs at half / full precision for now.
# Gate always runs at half / full precision for now.
self
.
gate
=
ReplicatedLinear
(
hidden_size
,
self
.
gate
=
ReplicatedLinear
(
hidden_size
,
num_experts
,
num_experts
,
bias
=
False
,
bias
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment