Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7fc23be8
Unverified
Commit
7fc23be8
authored
Aug 16, 2024
by
Mor Zusman
Committed by
GitHub
Aug 16, 2024
Browse files
[Kernel] W8A16 Int8 inside FusedMoE (#7415)
parent
e837b624
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
412 additions
and
136 deletions
+412
-136
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+70
-38
tests/models/test_jamba.py
tests/models/test_jamba.py
+7
-6
tests/quantization/test_experts_int8.py
tests/quantization/test_experts_int8.py
+28
-0
vllm/config.py
vllm/config.py
+2
-1
vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
...336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
+0
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
...048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
+0
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
...584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
+0
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
...096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
+0
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
...168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
+0
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
...192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
+0
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+91
-53
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-0
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+175
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-1
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+35
-37
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
7fc23be8
...
...
@@ -30,11 +30,28 @@ def benchmark_config(
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8
else
dtype
init_dtype
=
torch
.
float16
if
use_fp8
_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
if
use_int8_w8a16
:
w1
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
),
dtype
=
torch
.
int8
)
w2
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
),
dtype
=
torch
.
int8
)
else
:
w1
=
torch
.
randn
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
...
...
@@ -52,7 +69,11 @@ def benchmark_config(
w2_scale
=
None
a1_scale
=
None
a2_scale
=
None
if
use_fp8
:
if
use_int8_w8a16
:
w1_scale
=
torch
.
randn
((
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
if
use_fp8_w8a8
:
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
...
...
@@ -76,7 +97,8 @@ def benchmark_config(
renormalize
=
True
,
inplace
=
True
,
override_config
=
config
,
use_fp8
=
use_fp8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
...
...
@@ -155,11 +177,13 @@ class BenchmarkWorker:
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
)
->
Tuple
[
Dict
[
str
,
int
],
float
]:
torch
.
cuda
.
manual_seed_all
(
self
.
seed
)
dtype_str
=
"float8"
if
use_fp8
else
None
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config
=
get_moe_configs
(
num_experts
,
shard_intermediate_size
//
2
,
...
...
@@ -173,7 +197,8 @@ class BenchmarkWorker:
key
=
lambda
x
:
abs
(
x
-
num_tokens
))]
kernel_time
=
benchmark_config
(
config
,
num_tokens
,
num_experts
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
)
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
)
return
config
,
kernel_time
def
tune
(
...
...
@@ -184,9 +209,10 @@ class BenchmarkWorker:
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
search_space
:
List
[
BenchmarkConfig
],
)
->
BenchmarkConfig
:
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
search_space
:
List
[
Dict
[
str
,
int
]],
)
->
Dict
[
str
,
int
]:
best_config
=
None
best_time
=
float
(
"inf"
)
for
config
in
tqdm
(
search_space
):
...
...
@@ -198,7 +224,8 @@ class BenchmarkWorker:
hidden_size
,
topk
,
dtype
,
use_fp8
,
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
10
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
# Some configurations may be invalid and fail to compile.
...
...
@@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
}
def
save_configs
(
configs
:
Dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
)
->
None
:
dtype_str
=
"float8"
if
use_fp8
else
None
def
save_configs
(
configs
:
Dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
)
->
None
:
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename
=
get_config_file_name
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
)
print
(
f
"Writing best config to
{
filename
}
..."
)
with
open
(
filename
,
"w"
)
as
f
:
json
.
dump
(
configs
,
f
,
indent
=
4
)
...
...
@@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
topk
=
config
.
ffn_config
.
moe_top_k
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
else
:
# Default: Mixtral.
E
=
config
.
num_local_experts
...
...
@@ -262,7 +293,8 @@ def main(args: argparse.Namespace):
hidden_size
=
config
.
hidden_size
dtype
=
config
.
torch_dtype
use_fp8
=
args
.
dtype
==
"fp8"
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
if
args
.
batch_size
is
None
:
batch_sizes
=
[
...
...
@@ -294,20 +326,20 @@ def main(args: argparse.Namespace):
start
=
time
.
time
()
configs
=
_distribute
(
"tune"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
,
search_space
)
topk
,
dtype
,
use_fp8
_w8a8
,
use_int8_w8a16
,
search_space
)
for
batch_size
in
batch_sizes
])
best_configs
=
{
M
:
sort_config
(
config
)
for
M
,
config
in
zip
(
batch_sizes
,
configs
)
}
save_configs
(
best_configs
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
)
topk
,
dtype
,
use_fp8
_w8a8
,
use_int8_w8a16
)
end
=
time
.
time
()
print
(
f
"Tuning took
{
end
-
start
:.
2
f
}
seconds"
)
else
:
outputs
=
_distribute
(
"benchmark"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
)
outputs
=
_distribute
(
"benchmark"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
_w8a8
,
use_int8_w8a16
)
for
batch_size
in
batch_sizes
])
for
batch_size
,
(
config
,
kernel_time
)
in
zip
(
batch_sizes
,
outputs
):
...
...
@@ -323,7 +355,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--tp-size"
,
"-tp"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8"
],
choices
=
[
"auto"
,
"fp8
_w8a8"
,
"int8_w8a16
"
],
default
=
"auto"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
...
...
tests/models/test_jamba.py
View file @
7fc23be8
...
...
@@ -6,9 +6,12 @@ from vllm.worker.model_runner import _get_graph_batch_size
MODELS
=
[
"ai21labs/Jamba-tiny-random"
]
# Fails due to usage of MoE as MLP(E=1_, which is different than the HF impl
# TODO: Fix this with trained model
@
pytest
.
mark
.
skip
()
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
2
0
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"
b
float
16
"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
1
0
])
def
test_models
(
hf_runner
,
vllm_runner
,
...
...
@@ -17,8 +20,6 @@ def test_models(
dtype
:
str
,
max_tokens
:
int
,
)
->
None
:
# To pass the small model tests, we need full precision.
assert
dtype
==
"float"
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
...
...
@@ -36,8 +37,8 @@ def test_models(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"
float
"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
20
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"
half
"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
def
test_batching
(
vllm_runner
,
example_prompts
,
...
...
tests/quantization/test_experts_int8.py
0 → 100644
View file @
7fc23be8
# flake8: noqa
"""Tests experts_int8 quantization startup and generation,
doesn't test correctness
"""
import
pytest
from
tests.quantization.utils
import
is_quant_method_supported
MODELS
=
[
"ai21labs/Jamba-tiny-random"
]
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"experts_int8"
),
reason
=
"ExpertsInt8 is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
10
])
def
test_model_experts_int8_startup
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
quantization
=
"experts_int8"
)
as
vllm_model
:
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm/config.py
View file @
7fc23be8
...
...
@@ -243,7 +243,8 @@ class ModelConfig:
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
,
"fp8"
]
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"experts_int8"
]
tpu_supported_quantization
=
[
"tpu_int8"
]
if
self
.
quantization
is
not
None
:
...
...
vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
loat
8.json
→
vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
p8_w8a
8.json
View file @
7fc23be8
File moved
vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
loat
8.json
→
vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
p8_w8a
8.json
View file @
7fc23be8
File moved
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
loat
8.json
→
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
p8_w8a
8.json
View file @
7fc23be8
File moved
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
loat
8.json
→
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
p8_w8a
8.json
View file @
7fc23be8
File moved
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
loat
8.json
→
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
p8_w8a
8.json
View file @
7fc23be8
File moved
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
loat
8.json
→
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=f
p8_w8a
8.json
View file @
7fc23be8
File moved
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
7fc23be8
...
...
@@ -43,6 +43,8 @@ def fused_moe_kernel(
stride_bn
,
stride_cm
,
stride_cn
,
stride_bse
,
stride_bsn
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
...
...
@@ -51,8 +53,8 @@ def fused_moe_kernel(
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8
:
tl
.
constexpr
,
):
use_fp8
_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
...
...
@@ -113,8 +115,12 @@ def fused_moe_kernel(
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
if
use_int8_w8a16
:
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bn
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8
:
if
use_fp8
_w8a8
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
...
...
@@ -136,7 +142,9 @@ def fused_moe_kernel(
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
# We accumulate along the K dimension.
if
use_fp8
:
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
:
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
...
...
@@ -149,8 +157,9 @@ def fused_moe_kernel(
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
if
use_fp8
:
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
:
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
else
:
accumulator
=
accumulator
.
to
(
compute_type
)
...
...
@@ -229,16 +238,18 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_tokens_post_padded
:
torch
.
Tensor
,
mul_routed_weight
:
bool
,
top_k
:
int
,
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8
:
bool
)
->
None
:
use_fp8
_w8a8
:
bool
,
use_int8_w8a16
:
bool
)
->
None
:
assert
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
not
use_fp8
:
assert
A_scale
is
None
assert
B_scale
is
None
else
:
if
use_fp8_w8a8
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
assert
B_scale
is
not
None
elif
use_int8_w8a16
:
assert
B_scale
is
not
None
else
:
assert
A_scale
is
None
assert
B_scale
is
None
grid
=
lambda
META
:
(
triton
.
cdiv
(
sorted_token_ids
.
shape
[
0
],
META
[
'BLOCK_SIZE_M'
])
*
triton
.
cdiv
(
B
.
shape
[
1
],
META
[
'BLOCK_SIZE_N'
]),
)
...
...
@@ -264,10 +275,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
use_int8_w8a16
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
use_int8_w8a16
else
0
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
...
...
@@ -426,6 +440,20 @@ def grouped_topk(hidden_states: torch.Tensor,
return
topk_weights
,
topk_ids
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_fp8_w8a8
:
return
"fp8_w8a8"
elif
use_int8_w8a16
:
return
"int8_w8a16"
elif
dtype
==
torch
.
float
:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
return
"float32"
return
None
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -433,7 +461,8 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_fp8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -454,13 +483,16 @@ def fused_experts(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
dtype
=
hidden_states
.
dtype
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
Non
e
,
config_dtyp
e
,
override_config
=
override_config
,
)
...
...
@@ -524,7 +556,8 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids
.
shape
[
1
],
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
...
...
@@ -542,7 +575,8 @@ def fused_experts(hidden_states: torch.Tensor,
1
,
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
)
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
,
...
...
@@ -562,7 +596,8 @@ def fused_moe(
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
use_fp8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -588,7 +623,9 @@ def fused_moe(
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
...
...
@@ -617,7 +654,8 @@ def fused_moe(
topk_ids
,
inplace
=
inplace
,
override_config
=
override_config
,
use_fp8
=
use_fp8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
7fc23be8
...
...
@@ -11,6 +11,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsConfig
)
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
DeepSpeedFPConfig
)
from
vllm.model_executor.layers.quantization.experts_int8
import
(
ExpertsInt8Config
)
from
vllm.model_executor.layers.quantization.fbgemm_fp8
import
FBGEMMFp8Config
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.gguf
import
GGUFConfig
...
...
@@ -43,6 +45,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
"experts_int8"
:
ExpertsInt8Config
,
}
...
...
vllm/model_executor/layers/quantization/experts_int8.py
0 → 100644
View file @
7fc23be8
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
ExpertsInt8Config
(
QuantizationConfig
):
"""Config class for Int8 experts quantization."""
def
__init__
(
self
)
->
None
:
pass
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"experts_int8"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
,
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"ExpertsInt8Config"
:
return
cls
()
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
return
UnquantizedLinearMethod
()
elif
isinstance
(
layer
,
FusedMoE
):
return
ExpertsInt8MoEMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
ExpertsInt8MoEMethod
(
FusedMoEMethodBase
):
def
__init__
(
self
,
quant_config
:
ExpertsInt8Config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
int8_dtype
=
torch
.
int8
assert
'weight_loader'
in
extra_weight_attrs
weight_loader
=
extra_weight_attrs
[
'weight_loader'
]
wrapped_weight_loader
=
ExpertsInt8MoEMethod
.
quantizing_weight_loader
(
layer
,
weight_loader
)
extra_weight_attrs
[
'weight_loader'
]
=
wrapped_weight_loader
# Fused gate_up_proj (column parallel)
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
int8_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
# down_proj (row parallel)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
int8_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_scale"
,
w13_scale
)
w2_scale
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_scale"
,
w2_scale
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_int8_w8a16
=
True
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
)
@
staticmethod
def
quantizing_weight_loader
(
layer
,
weight_loader
):
def
quantize_and_call_weight_loader
(
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
int
,
expert_id
:
int
):
tp_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
layer
.
intermediate_size_per_partition
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
device
=
get_tp_group
().
device
loaded_weight
=
loaded_weight
.
to
(
device
)
# w1, gate_proj case: Load into first shard of w13.
if
shard_id
==
"w1"
:
scales
=
quantize_in_place_and_get_scales
(
loaded_weight
[
shard
,
:])
layer
.
w13_scale
.
data
[
expert_id
,
0
:
shard_size
].
copy_
(
scales
[:,
0
])
# w3, up_proj case: Load into second shard of w13.
elif
shard_id
==
"w3"
:
scales
=
quantize_in_place_and_get_scales
(
loaded_weight
[
shard
,
:])
layer
.
w13_scale
.
data
[
expert_id
,
shard_size
:
2
*
shard_size
].
copy_
(
scales
[:,
0
])
# w2, down_proj case: Load into only shard of w2.
elif
shard_id
==
"w2"
:
scales
=
quantize_in_place_and_get_scales
(
loaded_weight
[:,
shard
])
layer
.
w2_scale
.
data
[
expert_id
,
:].
copy_
(
scales
[:,
0
])
else
:
raise
ValueError
(
f
"Shard id must be in [0,1,2] but got
{
shard_id
}
"
)
weight_loader
(
param
,
loaded_weight
,
weight_name
,
shard_id
,
expert_id
)
return
quantize_and_call_weight_loader
def
quantize_in_place_and_get_scales
(
weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
vmax
=
torch
.
iinfo
(
torch
.
int8
).
max
scales
=
(
torch
.
max
(
torch
.
abs
(
weight
),
dim
=
1
,
keepdim
=
True
)[
0
]
/
vmax
)
weight
.
div_
(
scales
)
weight
.
round_
()
weight
.
clamp_
(
-
vmax
,
vmax
)
return
scales
vllm/model_executor/layers/quantization/fp8.py
View file @
7fc23be8
...
...
@@ -488,7 +488,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_fp8
=
True
,
use_fp8
_w8a8
=
True
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
...
...
vllm/model_executor/models/jamba.py
View file @
7fc23be8
...
...
@@ -16,7 +16,6 @@ from vllm.attention.layer import Attention
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
SchedulerConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module):
return
hidden_states
class
JambaMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
JambaConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
hidden_size
=
config
.
hidden_size
intermediate_size
=
config
.
intermediate_size
hidden_act
=
config
.
hidden_act
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
JambaMoE
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -327,6 +295,21 @@ class JambaMoE(nn.Module):
return
hidden_states
.
view
(
orig_shape
)
class
JambaMLP
(
JambaMoE
):
def
__init__
(
self
,
config
:
JambaConfig
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
(
config
,
num_experts
=
1
,
top_k
=
1
,
params_dtype
=
params_dtype
,
tp_size
=
tp_size
,
quant_config
=
quant_config
)
class
JambaMambaDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -884,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# Params for weights, fp8 weight scales, fp8 activation scales
...
...
@@ -907,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState):
if
".self_attn."
in
name
:
name
=
name
.
replace
(
".self_attn"
,
""
)
if
"feed_forward"
in
name
and
not
_is_moe_layer
(
name
):
## map MLP layers to expert with ID=0
name
=
name
.
replace
(
"feed_forward"
,
"feed_forward.experts.0"
)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
...
...
@@ -921,16 +906,21 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
for
(
param_name
,
weight_name
,
expert_id
,
shard_id
,
)
in
expert_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
name
,
weight_
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
...
...
@@ -943,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def
_is_moe_layer
(
name
:
str
):
return
any
(
[
experts_name
in
name
for
experts_name
in
[
"experts"
,
"router"
,
]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment