Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcfc474d
Commit
fcfc474d
authored
Apr 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-dev
parents
bb94d2e5
296c6572
Changes
503
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2480 additions
and
448 deletions
+2480
-448
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+3
-0
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json
.../configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json
+200
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+153
-0
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+294
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+190
-335
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+66
-58
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+251
-0
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+158
-0
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+48
-0
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+5
-2
vllm/model_executor/layers/lightning_attn.py
vllm/model_executor/layers/lightning_attn.py
+651
-0
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+5
-0
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+45
-16
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+23
-8
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+252
-15
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+2
-0
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+15
-12
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+2
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+109
-0
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+8
-2
No files found.
vllm/model_executor/layers/fused_moe/__init__.py
View file @
fcfc474d
...
...
@@ -35,6 +35,8 @@ if HAS_TRITON:
# import to register the custom ops
import
vllm.model_executor.layers.fused_moe.fused_marlin_moe
# noqa
import
vllm.model_executor.layers.fused_moe.fused_moe
# noqa
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
,
grouped_topk
)
...
...
@@ -45,4 +47,5 @@ if HAS_TRITON:
"fused_experts"
,
"get_config_file_name"
,
"grouped_topk"
,
"cutlass_moe_fp8"
,
]
vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json
0 → 100644
View file @
fcfc474d
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
}
vllm/model_executor/layers/fused_moe/cutlass_moe.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
from
typing
import
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def
cutlass_moe_fp8
(
a
:
torch
.
Tensor
,
w1_q
:
torch
.
Tensor
,
w2_q
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
ab_strides1
:
torch
.
Tensor
,
c_strides1
:
torch
.
Tensor
,
ab_strides2
:
torch
.
Tensor
,
c_strides2
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
out_dtype
:
torch
.
dtype
=
torch
.
half
,
apply_router_weight_on_input
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- ab_strides1 (torch.Tensor): The input and weights strides of the first
grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- ab_strides2 (torch.Tensor): The input and weights strides of the second
grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
w1_q
.
dtype
==
torch
.
float8_e4m3fn
assert
w2_q
.
dtype
==
torch
.
float8_e4m3fn
assert
a
.
shape
[
1
]
==
w1_q
.
shape
[
1
],
"Hidden size mismatch w1"
assert
w1_q
.
shape
[
2
]
==
w2_q
.
shape
[
1
]
*
2
,
"Hidden size mismatch w2"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Expert number mismatch"
assert
a1_scale
is
None
or
a1_scale
.
dim
(
)
==
0
or
a1_scale
.
shape
[
0
]
==
1
or
a1_scale
.
shape
[
0
]
==
a
.
shape
[
0
],
"Input scale shape mismatch"
assert
w1_scale
.
dim
()
==
1
or
w1_scale
.
shape
[
1
]
==
1
or
w1_scale
.
shape
[
1
]
==
w1_q
.
shape
[
2
],
"W1 scale shape mismatch"
assert
w2_scale
.
dim
()
==
1
or
w2_scale
.
shape
[
1
]
==
1
or
w2_scale
.
shape
[
1
]
==
w2_q
.
shape
[
2
],
"W2 scale shape mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"Weights expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
assert
w1_q
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
assert
a2_scale
is
None
or
a1_scale
is
None
or
a2_scale
.
shape
==
a1_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
assert
ab_strides1
.
shape
[
0
]
==
w1_q
.
shape
[
0
],
"AB Strides 1 expert number mismatch"
assert
c_strides1
.
shape
[
0
]
==
w1_q
.
shape
[
0
],
"C Strides 1 expert number mismatch"
assert
ab_strides2
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"AB Strides 2 expert number mismatch"
assert
c_strides2
.
shape
[
0
]
==
w2_q
.
shape
[
0
],
"C Strides 2 expert number mismatch"
assert
out_dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid output dtype"
num_experts
=
w1_q
.
size
(
0
)
m
=
a
.
size
(
0
)
k
=
w1_q
.
size
(
1
)
n
=
w2_q
.
size
(
1
)
topk
=
topk_ids
.
size
(
1
)
per_act_token
=
a1_scale
.
numel
()
!=
1
if
a1_scale
is
not
None
else
(
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
if
apply_router_weight_on_input
:
assert
topk
==
1
,
\
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a
=
a
*
topk_weights
.
to
(
out_dtype
)
a_q
,
a1_scale
=
ops
.
scaled_fp8_quant
(
a
,
a1_scale
,
use_per_token_if_dynamic
=
per_act_token
)
device
=
a_q
.
device
expert_offsets
=
torch
.
empty
((
num_experts
+
1
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes1
=
torch
.
empty
((
num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
problem_sizes2
=
torch
.
empty
((
num_experts
,
3
),
dtype
=
torch
.
int32
,
device
=
device
)
a_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
c_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
ops
.
get_cutlass_moe_mm_data
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
a_map
,
c_map
,
num_experts
,
n
,
k
)
rep_a_q
=
a_q
.
view
(
dtype
=
torch
.
uint8
)[
a_map
].
view
(
dtype
=
a_q
.
dtype
)
rep_a1_scales
=
a1_scale
[
a_map
]
if
per_act_token
else
a1_scale
c1
=
torch
.
empty
((
m
*
topk
,
n
*
2
),
device
=
device
,
dtype
=
out_dtype
)
c2
=
torch
.
empty
((
m
*
topk
,
k
),
device
=
device
,
dtype
=
out_dtype
)
ops
.
cutlass_moe_mm
(
c1
,
rep_a_q
,
w1_q
,
rep_a1_scales
,
w1_scale
,
expert_offsets
[:
-
1
],
problem_sizes1
,
ab_strides1
,
ab_strides1
,
c_strides1
)
intermediate
=
torch
.
empty
((
m
*
topk
,
n
),
device
=
device
,
dtype
=
out_dtype
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate
,
c1
)
intemediate_q
,
a2_scale
=
ops
.
scaled_fp8_quant
(
intermediate
,
a2_scale
,
use_per_token_if_dynamic
=
per_act_token
)
ops
.
cutlass_moe_mm
(
c2
,
intemediate_q
,
w2_q
,
a2_scale
,
w2_scale
,
expert_offsets
[:
-
1
],
problem_sizes2
,
ab_strides2
,
ab_strides2
,
c_strides2
)
# Gather tokens
c2
=
c2
[
c_map
].
view
(
m
,
topk
,
k
)
if
not
apply_router_weight_on_input
:
c2
=
c2
*
topk_weights
.
view
(
m
,
topk
,
1
).
to
(
out_dtype
)
return
c2
.
sum
(
dim
=
1
)
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
importlib.util
from
typing
import
Optional
,
Tuple
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_fp8_perm
,
_fp8_quantize
,
_resize_cache
)
from
vllm.utils
import
round_up
logger
=
init_logger
(
__name__
)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
def
_valid_deep_gemm
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
)
->
bool
:
"""
Check if the given problem size is supported by the DeepGemm grouped
gemm kernel. All of M, N, K and the quantization block_shape must be
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
if
not
has_deep_gemm
:
return
False
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
# Expert maps not supported yet.
if
expert_map
is
not
None
:
return
False
align
=
dg
.
get_m_alignment_for_contiguous_layout
()
M
=
hidden_states
.
shape
[
0
]
_
,
K
,
N
=
w2
.
shape
# For now, disable DeepGemm for small N until better permute/unpermute
# ops are available.
if
N
<=
512
:
return
False
if
align
>
M
or
N
%
align
!=
0
or
K
%
align
!=
0
:
return
False
return
(
hidden_states
.
is_contiguous
()
and
w1
.
is_contiguous
()
and
w2
.
is_contiguous
())
def
_moe_permute
(
curr_hidden_states
:
torch
.
Tensor
,
a1q_scale
:
Optional
[
torch
.
Tensor
],
curr_topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
block_m
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num
=
curr_topk_ids
.
shape
[
1
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
block_m
,
global_num_experts
,
expert_map
,
pad_sorted_ids
=
True
))
inv_perm
:
Optional
[
torch
.
Tensor
]
=
None
num_tokens
=
top_k_num
*
tokens_in_chunk
sorted_token_ids
=
sorted_token_ids
.
clamp
(
max
=
num_tokens
-
1
)
expert_ids
=
torch
.
repeat_interleave
(
expert_ids
,
block_m
,
dim
=
0
)
inv_perm
=
torch
.
argsort
(
sorted_token_ids
)[:
num_tokens
]
# Permute according to sorted token ids.
curr_hidden_states
=
_fp8_perm
(
curr_hidden_states
,
sorted_token_ids
//
top_k_num
)
if
a1q_scale
is
not
None
:
a1q_scale
=
a1q_scale
[
sorted_token_ids
//
top_k_num
]
return
(
curr_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
)
def
_moe_unpermute_and_reduce
(
out
:
torch
.
Tensor
,
curr_hidden
:
torch
.
Tensor
,
inv_perm
:
Optional
[
torch
.
Tensor
],
topk_weight
:
torch
.
Tensor
,
)
->
None
:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M
,
topk
=
topk_weight
.
shape
K
=
curr_hidden
.
shape
[
1
]
curr_hidden
=
curr_hidden
[
inv_perm
,
...]
curr_hidden
=
curr_hidden
.
view
(
-
1
,
topk
,
K
)
curr_hidden
.
mul_
(
topk_weight
.
view
(
M
,
-
1
,
1
))
ops
.
moe_sum
(
curr_hidden
,
out
)
def
deep_gemm_moe_fp8
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with DeepGemm
grouped gemm.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
assert
expert_map
is
None
,
"Expert maps not supported yet"
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
assert
w1
.
dtype
==
torch
.
float8_e4m3fn
assert
w2
.
dtype
==
torch
.
float8_e4m3fn
assert
w1
.
shape
[
0
]
==
w2
.
shape
[
0
],
"Expert number mismatch"
assert
w1
.
shape
[
0
]
==
w1_scale
.
shape
[
0
],
"w1 scales expert number mismatch"
assert
w1
.
shape
[
0
]
==
w2_scale
.
shape
[
0
],
"w2 scales expert number mismatch"
assert
a1_scale
is
None
or
a1_scale
.
dim
(
)
==
0
or
a1_scale
.
shape
[
0
]
==
1
or
a1_scale
.
shape
[
0
]
==
hidden_states
.
shape
[
0
],
"Input scale shape mismatch"
assert
a2_scale
is
None
or
a1_scale
is
None
or
a2_scale
.
shape
==
a1_scale
.
shape
,
"Intermediate scale shape mismatch"
# noqa: E501
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
K
=
w2
.
shape
[
1
]
if
global_num_experts
==
-
1
:
global_num_experts
=
E
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
assert
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)
if
inplace
:
out_hidden_states
=
hidden_states
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
block_m
=
dg
.
get_m_alignment_for_contiguous_layout
()
block_shape
=
[
block_m
,
block_m
]
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale
=
dg
.
get_col_major_tma_aligned_tensor
(
w1_scale
).
contiguous
()
w2_scale
=
dg
.
get_col_major_tma_aligned_tensor
(
w2_scale
).
contiguous
()
M_sum
=
topk_ids
.
numel
()
+
global_num_experts
*
(
block_m
-
1
)
M_sum
=
round_up
(
M_sum
,
block_m
)
num_chunks
=
(
num_tokens
//
CHUNK_SIZE
)
+
1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13
=
torch
.
empty
(
M_sum
*
max
(
N
,
K
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
workspace1
=
workspace13
[:
M_sum
*
N
].
view
(
M_sum
,
N
)
workspace2
=
torch
.
empty
((
M_sum
,
N
//
2
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
workspace3
=
workspace13
[:
M_sum
*
K
].
view
(
M_sum
,
K
)
for
chunk
in
range
(
num_chunks
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
if
tokens_in_chunk
==
0
:
break
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
curr_hidden_states
,
a1_scale
,
block_shape
)
(
qcurr_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
)
=
_moe_permute
(
qcurr_hidden_states
,
a1q_scale
,
curr_topk_ids
,
global_num_experts
,
expert_map
,
block_m
)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
curr_M
=
sorted_token_ids
.
numel
()
workspace1
=
_resize_cache
(
workspace1
,
(
curr_M
,
N
))
workspace2
=
_resize_cache
(
workspace2
,
(
curr_M
,
N
//
2
))
workspace3
=
_resize_cache
(
workspace3
,
(
curr_M
,
K
))
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
qcurr_hidden_states
,
a1q_scale
),
(
w1
,
w1_scale
),
workspace1
,
expert_ids
)
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
workspace2
,
workspace1
.
view
(
-
1
,
N
))
elif
activation
==
"gelu"
:
torch
.
ops
.
_C
.
gelu_and_mul
(
workspace2
,
workspace1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qworkspace2
,
a2q_scale
=
_fp8_quantize
(
workspace2
,
a2_scale
,
block_shape
)
dg
.
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
(
(
qworkspace2
,
a2q_scale
),
(
w2
,
w2_scale
),
workspace3
,
expert_ids
)
_moe_unpermute_and_reduce
(
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
workspace3
.
view
(
*
workspace3
.
shape
),
inv_perm
,
curr_topk_weights
)
return
out_hidden_states
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
fcfc474d
...
...
@@ -12,14 +12,18 @@ import triton.language as tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.
quantization.utils.fp8_utils
import
(
per_token_group_quant
_fp8
)
from
vllm.model_executor.layers.
quantization.utils.int8_utils
import
(
per_token_group_quant_int8
)
from
vllm.model_executor.layers.
fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm
,
deep_gemm_moe
_fp8
)
from
vllm.model_executor.layers.
fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
,
rocm_aiter_fused_experts
,
rocm_aiter_topk_softmax
)
logger
=
init_logger
(
__name__
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
...
...
@@ -671,248 +675,13 @@ def fused_moe_kernel(
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
def
ceil_div
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
@
triton
.
jit
def
moe_align_block_size_stage1
(
topk_ids_ptr
,
tokens_cnts_ptr
,
num_experts
:
tl
.
constexpr
,
numel
:
tl
.
constexpr
,
tokens_per_thread
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
start_idx
=
pid
*
tokens_per_thread
off_c
=
(
pid
+
1
)
*
num_experts
for
i
in
range
(
tokens_per_thread
):
if
start_idx
+
i
<
numel
:
idx
=
tl
.
load
(
topk_ids_ptr
+
start_idx
+
i
)
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_c
+
idx
)
tl
.
store
(
tokens_cnts_ptr
+
off_c
+
idx
,
token_cnt
+
1
)
@
triton
.
jit
def
moe_align_block_size_stage2
(
tokens_cnts_ptr
,
num_experts
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
last_cnt
=
0
for
i
in
range
(
1
,
num_experts
+
1
):
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
i
*
num_experts
+
pid
)
last_cnt
=
last_cnt
+
token_cnt
tl
.
store
(
tokens_cnts_ptr
+
i
*
num_experts
+
pid
,
last_cnt
)
@
triton
.
jit
def
moe_align_block_size_stage3
(
total_tokens_post_pad_ptr
,
tokens_cnts_ptr
,
cumsum_ptr
,
num_experts
:
tl
.
constexpr
,
block_size
:
tl
.
constexpr
,
):
last_cumsum
=
0
off_cnt
=
num_experts
*
num_experts
for
i
in
range
(
1
,
num_experts
+
1
):
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_cnt
+
i
-
1
)
last_cumsum
=
last_cumsum
+
tl
.
cdiv
(
token_cnt
,
block_size
)
*
block_size
tl
.
store
(
cumsum_ptr
+
i
,
last_cumsum
)
tl
.
store
(
total_tokens_post_pad_ptr
,
last_cumsum
)
@
triton
.
jit
def
moe_align_block_size_stage4
(
topk_ids_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
tokens_cnts_ptr
,
cumsum_ptr
,
num_experts
:
tl
.
constexpr
,
block_size
:
tl
.
constexpr
,
numel
:
tl
.
constexpr
,
tokens_per_thread
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cumsum_ptr
+
pid
)
end_idx
=
tl
.
load
(
cumsum_ptr
+
pid
+
1
)
for
i
in
range
(
start_idx
,
end_idx
,
block_size
):
tl
.
store
(
expert_ids_ptr
+
i
//
block_size
,
pid
)
start_idx
=
pid
*
tokens_per_thread
off_t
=
pid
*
num_experts
for
i
in
range
(
start_idx
,
tl
.
minimum
(
start_idx
+
tokens_per_thread
,
numel
)):
expert_id
=
tl
.
load
(
topk_ids_ptr
+
i
)
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_t
+
expert_id
)
rank_post_pad
=
token_cnt
+
tl
.
load
(
cumsum_ptr
+
expert_id
)
tl
.
store
(
sorted_token_ids_ptr
+
rank_post_pad
,
i
)
tl
.
store
(
tokens_cnts_ptr
+
off_t
+
expert_id
,
token_cnt
+
1
)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def
moe_align_block_size_triton
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
)
->
None
:
numel
=
topk_ids
.
numel
()
grid
=
(
num_experts
,
)
tokens_cnts
=
torch
.
zeros
((
num_experts
+
1
,
num_experts
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
cumsum
=
torch
.
zeros
((
num_experts
+
1
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
tokens_per_thread
=
ceil_div
(
numel
,
num_experts
)
moe_align_block_size_stage1
[
grid
](
topk_ids
,
tokens_cnts
,
num_experts
,
numel
,
tokens_per_thread
,
)
moe_align_block_size_stage2
[
grid
](
tokens_cnts
,
num_experts
,
)
moe_align_block_size_stage3
[(
1
,
)](
num_tokens_post_pad
,
tokens_cnts
,
cumsum
,
num_experts
,
block_size
,
)
moe_align_block_size_stage4
[
grid
](
topk_ids
,
sorted_token_ids
,
expert_ids
,
tokens_cnts
,
cumsum
,
num_experts
,
block_size
,
numel
,
tokens_per_thread
,
)
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
block_size
:
int
,
num_experts
:
int
,
expert_map
:
torch
.
Tensor
=
None
,
num_token
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
if
num_token
:
if
num_token
<
block_size
:
max_num_tokens_padded
=
min
(
topk_ids
.
numel
()
*
block_size
,
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
))
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
full
((
max_num_tokens_padded
,),
fill_value
=
topk_ids
.
numel
(),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
if
num_experts
>=
224
:
if
envs
.
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
or
num_experts
!=
256
:
moe_align_block_size_triton
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
)
else
:
# Currently requires num_experts=256
ops
.
sgl_moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
)
else
:
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
)
if
expert_map
is
not
None
:
expert_ids
=
expert_map
[
expert_ids
]
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
B_zp
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
Optional
[
torch
.
Tensor
],
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
...
...
@@ -926,33 +695,24 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int4_w4a16
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
assert
topk_weights
.
stride
(
1
)
==
1
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
block_n
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_int8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
block_n
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_shape
[
0
])
==
B_scale
.
shape
[
-
2
])
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_shape
[
1
])
==
B_scale
.
shape
[
-
1
])
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
...
...
@@ -960,6 +720,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
A_scale
is
None
assert
B_scale
is
None
M
=
A
.
shape
[
0
]
num_tokens
=
M
*
top_k
EM
=
sorted_token_ids
.
shape
[
0
]
if
A
.
shape
[
0
]
<
config
[
"BLOCK_SIZE_M"
]:
# optimize for small batch_size.
...
...
@@ -977,20 +740,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
B_zp
is
None
or
B_zp
.
ndim
==
3
if
os
.
environ
.
get
(
'moe_wna16_use_cuda'
)
==
'1'
:
use_moe_wna16_cuda
=
should_moe_wna16_use_cuda
(
num_valid_tokens
=
topk_ids
.
numel
()
,
group_size
=
block_shape
[
1
],
num_experts
=
B
.
shape
[
0
],
bit
=
4
if
use_int4_w4a16
else
8
)
num_valid_tokens
=
num_tokens
,
group_size
=
block_shape
[
1
],
num_experts
=
B
.
shape
[
0
],
bit
=
4
if
use_int4_w4a16
else
8
)
config
=
config
.
copy
()
config
.
update
(
get_moe_wna16_block_config
(
config
=
config
,
use_moe_wna16_cuda
=
use_moe_wna16_cuda
,
num_valid_tokens
=
topk_ids
.
numel
()
,
num_valid_tokens
=
num_tokens
,
size_k
=
A
.
shape
[
1
],
size_n
=
B
.
shape
[
1
],
num_experts
=
B
.
shape
[
1
],
group_size
=
block_shape
[
1
],
real_top_k
=
topk
_ids
.
shape
[
1
]
,
real_top_k
=
top
_
k
,
block_size_m
=
config
[
"BLOCK_SIZE_M"
]))
if
use_moe_wna16_cuda
:
...
...
@@ -1055,7 +818,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B
.
shape
[
1
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
()
,
num_tokens
,
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
...
...
@@ -1079,7 +842,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
else
:
# config = config.copy()
# BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
...
...
@@ -1099,7 +861,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B
.
shape
[
1
]
if
not
use_nn_moe
else
B
.
shape
[
2
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
()
,
num_tokens
,
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
...
...
@@ -1352,12 +1114,34 @@ def try_get_optimal_moe_config(
return
config
def
vllm_topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
renormalize
:
bool
)
->
tuple
[
torch
.
Tensor
,
...]:
ops
.
topk_softmax
(
topk_weights
,
topk_indices
,
token_expert_indices
,
gating_output
,
)
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_indices
def
dispatch_topk_func
()
->
Callable
[...,
tuple
[
torch
.
Tensor
,
...]]:
if
is_rocm_aiter_moe_enabled
():
return
rocm_aiter_topk_softmax
return
vllm_topk_softmax
def
fused_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
):
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
...
...
@@ -1376,30 +1160,29 @@ def fused_topk(
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
ops
.
topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
gating_output
.
float
(),
# TODO(woosuk): Optimize this.
)
del
token_expert_indicies
# Not used. Will be used in the future.
gating_output_float
=
gating_output
.
float
()
# TODO(woosuk): Optimize this.
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_func
=
dispatch_topk_func
()
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
gating_output_float
,
renormalize
)
del
token_expert_indicies
# Not used. Will be used in the future.
return
topk_weights
,
topk_ids
# This is used by the Deepseek-V2 and Deepseek-V3 model
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
):
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
...
...
@@ -1448,11 +1231,12 @@ def grouped_topk(hidden_states: torch.Tensor,
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
):
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
)
->
Optional
[
str
]:
if
use_fp8_w8a8
:
return
"fp8_w8a8"
elif
use_int8_w8a8
:
...
...
@@ -1474,6 +1258,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
...
...
@@ -1489,10 +1274,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_
int
8_w8a
16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
activation
,
apply_router_weight_on_input
,
use_
fp
8_w8a
8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
def
inplace_fused_experts_fake
(
...
...
@@ -1502,6 +1287,7 @@ def inplace_fused_experts_fake(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
...
...
@@ -1534,6 +1320,7 @@ def outplace_fused_experts(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
...
...
@@ -1549,10 +1336,11 @@ def outplace_fused_experts(
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
def
outplace_fused_experts_fake
(
...
...
@@ -1587,6 +1375,24 @@ direct_register_custom_op(
)
def
torch_vllm_inplace_fused_experts
(
**
kwargs
)
->
torch
.
Tensor
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
**
kwargs
)
hidden_states
=
kwargs
[
'hidden_states'
]
return
hidden_states
def
torch_vllm_outplace_fused_experts
(
**
kwargs
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
**
kwargs
)
def
dispatch_fused_experts_func
(
inplace
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
if
is_rocm_aiter_moe_enabled
():
return
rocm_aiter_fused_experts
if
inplace
:
return
torch_vllm_inplace_fused_experts
return
torch_vllm_outplace_fused_experts
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -1594,6 +1400,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
...
...
@@ -1607,21 +1414,50 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
if
inplace
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
return
hidden_states
if
(
allow_deep_gemm
and
use_fp8_w8a8
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
,
expert_map
)):
assert
apply_router_weight_on_input
is
False
return
deep_gemm_moe_fp8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
)
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
return
dispatch_fused_experts_func
(
inplace
)(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
...
...
@@ -1631,6 +1467,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
...
...
@@ -1663,10 +1500,13 @@ def fused_experts_impl(hidden_states: torch.Tensor,
]
num_tokens
,
_
=
hidden_states
.
shape
if
use_nn_moe
:
E
,
_
,
N
=
w1
.
shape
else
:
E
,
N
,
_
=
w1
.
shape
K
=
w2
.
shape
[
1
]
if
global_num_experts
==
-
1
:
global_num_experts
=
E
top_k_num
=
topk_ids
.
shape
[
1
]
...
...
@@ -1695,13 +1535,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
w2
.
shape
[
1
]
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
intermediate_cache1
=
cache13
[:
M
*
top_k_num
*
N
].
view
(
(
M
,
topk_ids
.
shape
[
1
],
N
))
intermediate_cache3
=
cache13
[:
M
*
top_k_num
*
(
w2
.
shape
[
1
]
if
not
use_nn_moe
else
w2
.
shape
[
2
])].
view
(
(
M
,
topk_ids
.
shape
[
1
],
w2
.
shape
[
1
]
if
not
use_nn_moe
else
w2
.
shape
[
2
]))
intermediate_cache1
=
cache13
[:
M
*
top_k_num
*
N
].
view
(
M
,
top_k_num
,
N
)
intermediate_cache3
=
cache13
[:
M
*
top_k_num
*
(
K
if
not
use_nn_moe
else
w2
.
shape
[
2
])].
view
(
M
,
top_k_num
,
K
)
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2
=
torch
.
empty
((
M
*
top_k_num
,
N
//
2
),
...
...
@@ -1745,6 +1583,16 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
a1q_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
use_fp8_w8a8
:
qcurr_hidden_states
,
a1q_scale
=
_fp8_quantize
(
curr_hidden_states
,
a1_scale
,
block_shape
)
else
:
qcurr_hidden_states
=
curr_hidden_states
a1q_scale
=
a1_scale
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
if
m
<=
16
:
...
...
@@ -1771,30 +1619,27 @@ def fused_experts_impl(hidden_states: torch.Tensor,
"num_stages"
:
0
,
"num_warps"
:
4
}
# sorted_token_ids, expert_ids, num_tokens_post_padded = (
# moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
# global_num_experts, expert_map))
if
use_int4_w4a16
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
,
curr_hidden_states
.
shape
[
0
]))
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
,
curr_hidden_states
.
shape
[
0
]))
else
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
invoke_fused_moe_kernel
(
curr_hidden_states
,
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
intermediate_cache1
,
a1_scale
,
a1
q
_scale
,
w1_scale
,
w1_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
apply_router_weight_on_input
,
top_k_num
,
config
,
compute_type
=
compute_type
,
...
...
@@ -1813,6 +1658,16 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
if
use_fp8_w8a8
:
qintermediate_cache2
,
a2q_scale
=
_fp8_quantize
(
intermediate_cache2
,
a2_scale
,
block_shape
)
else
:
qintermediate_cache2
=
intermediate_cache2
a2q_scale
=
a2_scale
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
if
m
<=
16
:
...
...
@@ -1840,18 +1695,17 @@ def fused_experts_impl(hidden_states: torch.Tensor,
"num_warps"
:
4
}
invoke_fused_moe_kernel
(
intermediate_cache2
,
invoke_fused_moe_kernel
(
q
intermediate_cache2
,
w2
,
intermediate_cache3
,
a2_scale
,
a2
q
_scale
,
w2_scale
,
w2_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
True
,
not
apply_router_weight_on_input
,
1
,
config
,
compute_type
=
compute_type
,
...
...
@@ -1864,6 +1718,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
fcfc474d
...
...
@@ -9,7 +9,7 @@ import torch
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
UninitializedParameter
from
vllm
import
envs
import
vllm.envs
as
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
(
get_dp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
...
...
@@ -67,6 +67,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -135,6 +137,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
self
.
_maybe_pad_weight
(
layer
.
w2_weight
.
data
),
requires_grad
=
False
)
# Lazy import to avoid importing triton.
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
if
is_rocm_aiter_moe_enabled
():
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
if
current_platform
.
is_cpu
():
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
X86
:
...
...
@@ -162,24 +176,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
router_logits
=
router_logits
,
top_k
=
top_k
,
renormalize
=
renormalize
,
use_grouped_topk
=
use_grouped_topk
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
,
use_nn_moe
=
use_nn_moe
,)
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
router_logits
=
router_logits
,
top_k
=
top_k
,
renormalize
=
renormalize
,
use_grouped_topk
=
use_grouped_topk
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_nn_moe
=
use_nn_moe
)
def
forward_cuda
(
self
,
...
...
@@ -196,6 +213,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
...
...
@@ -211,16 +229,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_nn_moe
=
use_nn_moe
,)
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_nn_moe
=
use_nn_moe
)
def
forward_cpu
(
self
,
...
...
@@ -238,10 +258,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
**
kwargs
,
):
assert
activation
==
"silu"
,
f
"
{
activation
}
is not supported."
assert
apply_router_weight_on_input
is
False
return
layer
.
ipex_fusion
(
x
,
use_grouped_topk
,
...
...
@@ -265,16 +287,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
assert
custom_routing_function
is
None
assert
layer
is
not
None
assert
apply_router_weight_on_input
is
False
if
scoring_func
!=
"softmax"
:
raise
NotImplementedError
(
"Only softmax scoring function is supported for HPU."
)
...
...
@@ -299,12 +326,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
assert
custom_routing_function
is
None
assert
apply_router_weight_on_input
is
False
if
scoring_func
!=
"softmax"
:
raise
NotImplementedError
(
"Only softmax scoring function is supported for TPU."
)
...
...
@@ -321,7 +350,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map
=
expert_map
,
renormalize
=
renormalize
)
forward_native
=
forward_cuda
forward_native
=
forward_tpu
if
current_platform
.
is_tpu
()
else
forward_cuda
def
determine_expert_map
(
...
...
@@ -410,6 +439,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
):
super
().
__init__
()
...
...
@@ -484,7 +514,7 @@ class FusedMoE(torch.nn.Module):
"non-grouped topk."
)
if
current_platform
.
is_hpu
():
from
vllm_hpu_extension.ops
import
DynamicFusedMOE
self
.
hpu_fused_moe
=
DynamicFusedMOE
(
self
.
num_experts
)
self
.
hpu_fused_moe
=
DynamicFusedMOE
(
self
.
global_
num_experts
)
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
...
...
@@ -500,7 +530,9 @@ class FusedMoE(torch.nn.Module):
self
.
use_nn_moe
=
int
(
os
.
environ
.
get
(
'MOE_NN'
,
1
))
==
1
else
:
self
.
use_nn_moe
=
False
self
.
apply_router_weight_on_input
=
apply_router_weight_on_input
moe_quant_params
=
{
"num_experts"
:
self
.
local_num_experts
,
"hidden_size"
:
hidden_size
,
...
...
@@ -736,8 +768,9 @@ class FusedMoE(torch.nn.Module):
tp_rank
=
self
.
tp_rank
)
return
# Case weight scales and zero_points
if
(
"scale"
in
weight_name
or
"zero"
in
weight_name
):
# Case weight scales, zero_points and offset
if
(
"scale"
in
weight_name
or
"zero"
in
weight_name
or
"offset"
in
weight_name
):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
...
...
@@ -886,6 +919,7 @@ class FusedMoE(torch.nn.Module):
scoring_func
=
self
.
scoring_func
,
e_score_correction_bias
=
self
.
e_score_correction_bias
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
)
...
...
@@ -923,32 +957,6 @@ class FusedMoE(torch.nn.Module):
]
def
_load_fp8_scale
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
param_data
=
param
.
data
# Input scales can be loaded directly and should be equal.
if
"input_scale"
in
weight_name
:
if
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"input_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
param_data
[
expert_id
]
=
loaded_weight
# Weight scales
elif
"weight_scale"
in
weight_name
:
# If we are in merged column case (gate_up_proj)
if
shard_id
in
(
"w1"
,
"w3"
):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx
=
0
if
shard_id
==
"w1"
else
1
param_data
[
expert_id
][
idx
]
=
loaded_weight
# If we are in the row parallel case (down_proj)
else
:
param_data
[
expert_id
]
=
loaded_weight
def
extra_repr
(
self
)
->
str
:
s
=
(
...
...
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
import
torch
import
triton
import
triton.language
as
tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
round_up
def
ceil_div
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
@
triton
.
jit
def
moe_align_block_size_stage1
(
topk_ids_ptr
,
tokens_cnts_ptr
,
num_experts
:
tl
.
constexpr
,
numel
:
tl
.
constexpr
,
tokens_per_thread
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
start_idx
=
pid
*
tokens_per_thread
off_c
=
(
pid
+
1
)
*
num_experts
for
i
in
range
(
tokens_per_thread
):
if
start_idx
+
i
<
numel
:
idx
=
tl
.
load
(
topk_ids_ptr
+
start_idx
+
i
)
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_c
+
idx
)
tl
.
store
(
tokens_cnts_ptr
+
off_c
+
idx
,
token_cnt
+
1
)
@
triton
.
jit
def
moe_align_block_size_stage2
(
tokens_cnts_ptr
,
num_experts
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
last_cnt
=
0
for
i
in
range
(
1
,
num_experts
+
1
):
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
i
*
num_experts
+
pid
)
last_cnt
=
last_cnt
+
token_cnt
tl
.
store
(
tokens_cnts_ptr
+
i
*
num_experts
+
pid
,
last_cnt
)
@
triton
.
jit
def
moe_align_block_size_stage3
(
total_tokens_post_pad_ptr
,
tokens_cnts_ptr
,
cumsum_ptr
,
num_experts
:
tl
.
constexpr
,
block_size
:
tl
.
constexpr
,
):
last_cumsum
=
0
off_cnt
=
num_experts
*
num_experts
for
i
in
range
(
1
,
num_experts
+
1
):
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_cnt
+
i
-
1
)
last_cumsum
=
last_cumsum
+
tl
.
cdiv
(
token_cnt
,
block_size
)
*
block_size
tl
.
store
(
cumsum_ptr
+
i
,
last_cumsum
)
tl
.
store
(
total_tokens_post_pad_ptr
,
last_cumsum
)
@
triton
.
jit
def
moe_align_block_size_stage4
(
topk_ids_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
tokens_cnts_ptr
,
cumsum_ptr
,
num_experts
:
tl
.
constexpr
,
block_size
:
tl
.
constexpr
,
numel
:
tl
.
constexpr
,
tokens_per_thread
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cumsum_ptr
+
pid
)
end_idx
=
tl
.
load
(
cumsum_ptr
+
pid
+
1
)
for
i
in
range
(
start_idx
,
end_idx
,
block_size
):
tl
.
store
(
expert_ids_ptr
+
i
//
block_size
,
pid
)
start_idx
=
pid
*
tokens_per_thread
off_t
=
pid
*
num_experts
for
i
in
range
(
start_idx
,
tl
.
minimum
(
start_idx
+
tokens_per_thread
,
numel
)):
expert_id
=
tl
.
load
(
topk_ids_ptr
+
i
)
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_t
+
expert_id
)
rank_post_pad
=
token_cnt
+
tl
.
load
(
cumsum_ptr
+
expert_id
)
tl
.
store
(
sorted_token_ids_ptr
+
rank_post_pad
,
i
)
tl
.
store
(
tokens_cnts_ptr
+
off_t
+
expert_id
,
token_cnt
+
1
)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def
moe_align_block_size_triton
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
)
->
None
:
numel
=
topk_ids
.
numel
()
grid
=
(
num_experts
,
)
tokens_cnts
=
torch
.
zeros
((
num_experts
+
1
,
num_experts
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
cumsum
=
torch
.
zeros
((
num_experts
+
1
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
tokens_per_thread
=
ceil_div
(
numel
,
num_experts
)
moe_align_block_size_stage1
[
grid
](
topk_ids
,
tokens_cnts
,
num_experts
,
numel
,
tokens_per_thread
,
)
moe_align_block_size_stage2
[
grid
](
tokens_cnts
,
num_experts
,
)
moe_align_block_size_stage3
[(
1
,
)](
num_tokens_post_pad
,
tokens_cnts
,
cumsum
,
num_experts
,
block_size
,
)
moe_align_block_size_stage4
[
grid
](
topk_ids
,
sorted_token_ids
,
expert_ids
,
tokens_cnts
,
cumsum
,
num_experts
,
block_size
,
numel
,
tokens_per_thread
,
)
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
block_size
:
int
,
num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_sorted_ids
:
bool
=
False
,
num_token
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
should be padded to a multiple of block_size,
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
if
num_token
:
if
num_token
<
block_size
:
max_num_tokens_padded
=
min
(
topk_ids
.
numel
()
*
block_size
,
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
))
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
full
((
max_num_tokens_padded
,),
fill_value
=
topk_ids
.
numel
(),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
if
pad_sorted_ids
:
max_num_tokens_padded
=
round_up
(
max_num_tokens_padded
,
block_size
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
if
num_experts
>=
224
:
if
envs
.
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
or
num_experts
!=
256
:
moe_align_block_size_triton
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
)
else
:
# Currently requires num_experts=256
ops
.
sgl_moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
)
else
:
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
)
if
expert_map
is
not
None
:
expert_ids
=
expert_map
[
expert_ids
]
return
sorted_ids
,
expert_ids
,
num_tokens_post_pad
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
import
torch
import
vllm.envs
as
envs
from
vllm.platforms
import
current_platform
def
is_rocm_aiter_moe_enabled
()
->
bool
:
return
current_platform
.
is_rocm
()
\
and
envs
.
VLLM_ROCM_USE_AITER_MOE
\
and
envs
.
VLLM_ROCM_USE_AITER
\
def
is_rocm_aiter_block_scaled_moe_enabled
()
->
bool
:
return
is_rocm_aiter_moe_enabled
()
and
\
envs
.
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
def
rocm_aiter_fused_experts
(
*
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
expert_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwagrs
# Ignore additional keyword arguments
)
->
torch
.
Tensor
:
import
aiter
as
rocm_aiter
import
aiter.fused_moe_bf16_asm
as
rocm_aiter_asm_fmoe
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
if
envs
.
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
and
use_fp8_w8a8
:
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
local_E
=
E
=
w1
.
shape
[
0
]
if
expert_mask
is
not
None
:
E
=
expert_mask
.
numel
()
topk
=
topk_ids
.
shape
[
1
]
model_dim
=
w1
.
shape
[
-
1
]
dtype
=
hidden_states
.
dtype
# The default block sizes are 128 in AITER.
if
block_shape
is
None
:
block_shape
=
[
128
,
128
]
scale_blk_k
=
block_shape
[
1
]
(
sorted_token_ids
,
sorted_weight_buf
,
sorted_expert_ids
,
num_valid_ids
,
out_asm
,
)
=
rocm_aiter_asm_fmoe
.
moe_sorting_ck
(
topk_ids
,
topk_weights
,
E
,
model_dim
,
dtype
,
expert_mask
=
expert_mask
)
a1
,
a1_scale
=
per_token_group_quant_fp8
(
hidden_states
,
scale_blk_k
)
rocm_aiter
.
fmoe_fp8_blockscale_g1u1
(
out_asm
,
a1
,
w1
,
w2
,
sorted_token_ids
,
sorted_weight_buf
,
sorted_expert_ids
,
num_valid_ids
,
topk
,
w1_scale
.
view
(
local_E
,
-
1
),
w2_scale
.
view
(
local_E
,
-
1
),
a1_scale
.
t
().
contiguous
(),
block_shape
[
0
],
block_shape
[
1
],
None
,
)
return
out_asm
elif
use_fp8_w8a8
:
return
rocm_aiter_asm_fmoe
.
asm_moe
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weight
=
topk_weights
,
topk_ids
=
topk_ids
,
fc1_scale
=
w1_scale
,
fc2_scale
=
w2_scale
,
fc1_smooth_scale
=
None
,
fc2_smooth_scale
=
None
,
a16
=
False
)
return
rocm_aiter
.
ck_moe
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
)
def
rocm_aiter_topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
renormalize
:
bool
)
->
tuple
[
torch
.
Tensor
,
...]:
import
aiter
as
rocm_aiter
rocm_aiter
.
topk_softmax
(
topk_weights
,
topk_indices
,
token_expert_indices
,
gating_output
,
renormalize
)
return
topk_weights
,
topk_indices
def
shuffle_weights
(
*
tensors
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Args:
*tensors: Variable number of torch.Tensor objects.
Returns:
A tuple of shuffled tensors.
"""
from
aiter.ops.shuffle
import
shuffle_weight
return
tuple
(
shuffle_weight
(
tensor
)
for
tensor
in
tensors
)
def
expand_weights
(
*
tensors
:
torch
.
Tensor
,
expansion_dims
:
list
[
int
])
->
tuple
[
torch
.
Tensor
,
...]:
"""
Expands the dimensions of input tensors.
Args:
*tensors: A variable number of torch.Tensor objects.
expansion_dims: A list of expansion dimensions
corresponding to each tensor.
Returns:
A tuple of tensors with expanded dimensions.
"""
assert
len
(
tensors
)
==
len
(
expansion_dims
),
\
"Number of tensors must match the number of expansion dimensions."
return
tuple
(
tensor
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
).
expand
((
-
1
,
dim
,
-
1
))
for
tensor
,
dim
in
zip
(
tensors
,
expansion_dims
))
vllm/model_executor/layers/fused_moe/utils.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
math
import
prod
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.utils
import
cdiv
def
_resize_cache
(
x
:
torch
.
Tensor
,
v
:
Tuple
[
int
,
...])
->
torch
.
Tensor
:
"""
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert
prod
(
v
)
<=
x
.
numel
()
return
x
.
flatten
()[:
prod
(
v
)].
view
(
*
v
)
def
_fp8_quantize
(
A
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
block_shape
:
Optional
[
List
[
int
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_fp8
(
A
,
block_k
)
assert
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
return
A
,
A_scale
def
_fp8_perm
(
m
:
torch
.
Tensor
,
idx
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
A permutation routine that works on fp8 types.
"""
if
torch
.
is_floating_point
(
m
)
and
torch
.
finfo
(
m
.
dtype
).
bits
==
8
:
return
m
.
view
(
dtype
=
torch
.
uint8
)[
idx
,
...].
view
(
dtype
=
m
.
dtype
)
else
:
return
m
[
idx
,
...]
vllm/model_executor/layers/layernorm.py
View file @
fcfc474d
...
...
@@ -109,6 +109,7 @@ class RMSNorm(CustomOp):
eps
:
float
=
1e-6
,
var_hidden_size
:
Optional
[
int
]
=
None
,
has_weight
:
bool
=
True
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -117,8 +118,10 @@ class RMSNorm(CustomOp):
self
.
variance_size_override
=
(
None
if
var_hidden_size
==
hidden_size
else
var_hidden_size
)
self
.
has_weight
=
has_weight
self
.
weight
=
torch
.
ones
(
hidden_size
)
if
dtype
is
not
None
:
self
.
weight
=
torch
.
ones
(
hidden_size
,
dtype
=
dtype
)
else
:
self
.
weight
=
torch
.
ones
(
hidden_size
)
if
self
.
has_weight
:
self
.
weight
=
nn
.
Parameter
(
self
.
weight
)
...
...
vllm/model_executor/layers/lightning_attn.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
torch
import
triton
import
triton.language
as
tl
from
einops
import
rearrange
@
triton
.
jit
def
_fwd_diag_kernel
(
Q
,
K
,
V
,
Out
,
S
,
b
:
tl
.
constexpr
,
h
:
tl
.
constexpr
,
n
,
d
:
tl
.
constexpr
,
e
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
NUM_BLOCK
,
CBLOCK
:
tl
.
constexpr
):
# This kernel computes the diagonal blocks of the attention matrix
# Each diagonal block represents attention
# where queries attend to keys in the same block
off
=
tl
.
program_id
(
0
)
off_bh
=
off
//
NUM_BLOCK
# batch-head index
off_block
=
off
%
NUM_BLOCK
# block index within the sequence
off_cblock
=
tl
.
program_id
(
1
)
# sub-block index within a block
off_h
=
off_bh
%
h
# head index
# Calculate base offsets for the current batch and head
qk_offset
=
off_bh
*
n
*
d
v_offset
=
off_bh
*
n
*
e
o_offset
=
off_bh
*
n
*
e
# Calculate offsets for the current block
block_offset
=
off_block
*
BLOCK
qk_block_offset
=
block_offset
*
d
v_block_offset
=
block_offset
*
e
o_block_offset
=
block_offset
*
e
# Calculate offsets for the current sub-block
cblock_offset
=
off_cblock
*
CBLOCK
q_cblock_offset
=
cblock_offset
*
d
o_cblock_offset
=
cblock_offset
*
e
# Calculate pointers to the query, key, value, and output tensors
Q_block_ptr
=
(
Q
+
qk_offset
+
qk_block_offset
+
q_cblock_offset
+
tl
.
arange
(
0
,
CBLOCK
)[:,
None
]
*
d
+
tl
.
arange
(
0
,
d
)[
None
,
:])
K_trans_block_ptr
=
(
K
+
qk_offset
+
qk_block_offset
+
tl
.
arange
(
0
,
CBLOCK
)[
None
,
:]
*
d
+
tl
.
arange
(
0
,
d
)[:,
None
])
V_block_ptr
=
(
V
+
v_offset
+
v_block_offset
+
tl
.
arange
(
0
,
CBLOCK
)[:,
None
]
*
e
+
tl
.
arange
(
0
,
e
)[
None
,
:])
O_block_ptr
=
(
Out
+
o_offset
+
o_block_offset
+
o_cblock_offset
+
tl
.
arange
(
0
,
CBLOCK
)[:,
None
]
*
e
+
tl
.
arange
(
0
,
e
)[
None
,
:])
# Load the decay rate for the current head
S_block_ptr
=
S
+
off_h
s
=
tl
.
load
(
S_block_ptr
)
i
=
off_cblock
q_index
=
tl
.
arange
(
0
,
CBLOCK
)
+
i
*
CBLOCK
# Load query values
q
=
tl
.
load
(
Q_block_ptr
,
mask
=
block_offset
+
q_index
[:,
None
]
<
n
,
other
=
0.0
).
to
(
tl
.
float32
)
# Initialize output accumulator
qkv
=
tl
.
zeros
([
CBLOCK
,
e
],
dtype
=
tl
.
float32
)
# Process all sub-blocks up to and
# including the current one (causal attention)
for
j
in
range
(
i
+
1
):
kv_index
=
tl
.
arange
(
0
,
CBLOCK
)
+
j
*
CBLOCK
diff
=
q_index
[:,
None
]
-
kv_index
[
None
,
:]
s_index
=
s
*
diff
# Apply causal mask: only attend to positions before the current one
s_index
=
tl
.
where
(
diff
>=
0
,
-
s_index
,
float
(
"-inf"
))
decay
=
tl
.
exp
(
s_index
)
# Load key and value
k_trans
=
tl
.
load
(
K_trans_block_ptr
,
mask
=
block_offset
+
kv_index
[
None
,
:]
<
n
,
other
=
0.0
,
).
to
(
tl
.
float32
)
v
=
tl
.
load
(
V_block_ptr
,
mask
=
block_offset
+
kv_index
[:,
None
]
<
n
,
other
=
0.0
,
).
to
(
tl
.
float32
)
# Compute attention scores and apply decay
qk
=
tl
.
dot
(
q
,
k_trans
)
*
decay
# Compute weighted values and accumulate
qkv
+=
tl
.
dot
(
qk
,
v
)
# Move to the next sub-block
K_trans_block_ptr
+=
CBLOCK
*
d
V_block_ptr
+=
CBLOCK
*
e
# Store the result
tl
.
store
(
O_block_ptr
,
qkv
.
to
(
O_block_ptr
.
dtype
.
element_ty
),
mask
=
block_offset
+
q_index
[:,
None
]
<
n
,
)
@
triton
.
jit
def
_fwd_kv_parallel
(
K
,
V
,
K_decay
,
KV
,
b
:
tl
.
constexpr
,
h
:
tl
.
constexpr
,
n
,
d
:
tl
.
constexpr
,
e
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
NUM_BLOCK
,
D_FBLOCK
:
tl
.
constexpr
,
E_FBLOCK
:
tl
.
constexpr
,
NUM_FBLOCK
:
tl
.
constexpr
,
CBLOCK
:
tl
.
constexpr
,
NUM_CBLOCK
:
tl
.
constexpr
,
):
# This kernel computes the key-value outer
# products for each block in parallel
off_bh
=
tl
.
program_id
(
0
)
# batch-head index
off_block
=
tl
.
program_id
(
1
)
# block index
off_h
=
off_bh
%
h
# head index
block_offset
=
off_block
*
BLOCK
# Calculate offsets for the current block
k_block_offset
=
block_offset
*
d
v_block_offset
=
block_offset
*
e
kv_block_offset
=
off_block
*
d
*
e
# Calculate base offsets for the current batch and head
k_offset
=
off_bh
*
n
*
d
v_offset
=
off_bh
*
n
*
e
kv_offset
=
off_bh
*
NUM_BLOCK
*
d
*
e
# Calculate pointers to the key, value, and key-value tensors
K_trans_block_ptr
=
(
K
+
k_offset
+
k_block_offset
+
tl
.
arange
(
0
,
CBLOCK
)[
None
,
:]
*
d
+
tl
.
arange
(
0
,
D_FBLOCK
)[:,
None
])
V_block_ptr
=
(
V
+
v_offset
+
v_block_offset
+
tl
.
arange
(
0
,
CBLOCK
)[:,
None
]
*
e
+
tl
.
arange
(
0
,
E_FBLOCK
)[
None
,
:])
KV_block_ptr
=
(
KV
+
kv_offset
+
kv_block_offset
+
tl
.
arange
(
0
,
D_FBLOCK
)[:,
None
]
*
e
+
tl
.
arange
(
0
,
E_FBLOCK
)[
None
,
:])
# Load the decay factors for the current head and block
k_decay_ptr
=
(
K_decay
+
off_h
*
BLOCK
+
tl
.
arange
(
0
,
CBLOCK
)[
None
,
:])
kv_index
=
tl
.
arange
(
0
,
CBLOCK
)
# Initialize the key-value outer product accumulator
kv
=
tl
.
zeros
([
D_FBLOCK
,
E_FBLOCK
],
dtype
=
tl
.
float32
)
# Handle the last block which might be smaller than BLOCK
if
off_block
==
NUM_BLOCK
-
1
:
split_n
=
n
-
(
NUM_BLOCK
-
1
)
*
BLOCK
else
:
split_n
=
BLOCK
left_shift
=
tl
.
cdiv
(
split_n
,
CBLOCK
)
*
CBLOCK
-
split_n
num_blocks
=
min
(
tl
.
cdiv
(
split_n
,
CBLOCK
),
NUM_CBLOCK
)
k_decay_ptr
+=
(
NUM_CBLOCK
-
num_blocks
)
*
CBLOCK
# Process all sub-blocks in the current block
for
j
in
range
(
num_blocks
):
left_bound
=
(
1
-
j
)
*
left_shift
# Load key and value, handling boundary conditions
k_trans
=
tl
.
load
(
K_trans_block_ptr
-
left_shift
*
d
,
mask
=
kv_index
[
None
,
:]
>=
left_bound
,
other
=
0.0
)
v
=
tl
.
load
(
V_block_ptr
-
left_shift
*
e
,
mask
=
kv_index
[:,
None
]
>=
left_bound
,
other
=
0.0
)
# Load decay factor and compute weighted key-value outer product
k_decay
=
tl
.
load
(
k_decay_ptr
)
kv
+=
tl
.
dot
(
k_trans
*
k_decay
,
v
)
# Move to the next sub-block
K_trans_block_ptr
+=
CBLOCK
*
d
V_block_ptr
+=
CBLOCK
*
e
k_decay_ptr
+=
CBLOCK
# Store the result
tl
.
store
(
KV_block_ptr
,
kv
.
to
(
KV_block_ptr
.
dtype
.
element_ty
))
@
triton
.
jit
def
_fwd_kv_reduce
(
S
,
KV
,
KV_HISTORY
,
b
:
tl
.
constexpr
,
h
:
tl
.
constexpr
,
n
,
d
:
tl
.
constexpr
,
e
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
NUM_BLOCK
,
D_FBLOCK
:
tl
.
constexpr
,
E_FBLOCK
:
tl
.
constexpr
):
# This kernel reduces the key-value outer products
# across blocks and updates the KV history
off_bh
=
tl
.
program_id
(
0
)
# batch-head index
off_h
=
off_bh
%
h
# head index
kv_offset
=
off_bh
*
NUM_BLOCK
*
d
*
e
# Calculate pointer to the key-value tensor
KV_block_ptr
=
(
KV
+
kv_offset
+
tl
.
arange
(
0
,
D_FBLOCK
)[:,
None
]
*
e
+
tl
.
arange
(
0
,
E_FBLOCK
)[
None
,
:])
# Load the decay rate for the current head
s_ptrs
=
S
+
off_h
s
=
tl
.
load
(
s_ptrs
)
# Calculate pointer to the key-value history tensor
kv_history_offset
=
off_bh
*
d
*
e
KV_HISTORY_block_ptr
=
(
KV_HISTORY
+
kv_history_offset
+
tl
.
arange
(
0
,
D_FBLOCK
)[:,
None
]
*
e
+
tl
.
arange
(
0
,
E_FBLOCK
)[
None
,
:])
# Load the previous key-value history
kv_pre
=
tl
.
load
(
KV_HISTORY_block_ptr
).
to
(
tl
.
float32
)
# Process all blocks in reverse order to compute the prefix sum
for
i
in
range
(
NUM_BLOCK
):
block_size
=
min
(
n
-
i
*
BLOCK
,
BLOCK
)
# Compute decay factor for the current block
block_decay
=
tl
.
exp
(
-
s
.
to
(
tl
.
float32
)
*
block_size
)
# Load the current key-value outer product
kv_cur
=
tl
.
load
(
KV_block_ptr
).
to
(
tl
.
float32
)
# Store the previous key-value history to the current block
tl
.
store
(
KV_block_ptr
,
kv_pre
.
to
(
KV_block_ptr
.
dtype
.
element_ty
))
# Update the key-value history with the current block
kv_pre
=
block_decay
*
kv_pre
+
kv_cur
KV_block_ptr
+=
d
*
e
# Store the updated key-value history
tl
.
store
(
KV_HISTORY_block_ptr
,
kv_pre
)
@
triton
.
jit
def
_fwd_none_diag_kernel
(
Q
,
Out
,
S
,
KV
,
b
:
tl
.
constexpr
,
h
:
tl
.
constexpr
,
n
,
d
:
tl
.
constexpr
,
e
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
NUM_BLOCK
,
E_FBLOCK
:
tl
.
constexpr
,
CBLOCK
:
tl
.
constexpr
,
NUM_CBLOCK
:
tl
.
constexpr
,
):
# This kernel computes the non-diagonal blocks of the attention matrix
# Each non-diagonal block represents attention
# where queries attend to keys in different blocks
off_bh
=
tl
.
program_id
(
0
)
# batch-head index
off_h
=
off_bh
%
h
# head index
off_nc
=
tl
.
program_id
(
1
)
off_n
=
off_nc
//
NUM_CBLOCK
# block index
off_c
=
off_nc
%
NUM_CBLOCK
# sub-block index
off_e
=
tl
.
program_id
(
2
)
# output feature block index
n_offset
=
off_n
*
BLOCK
c_offset
=
off_c
*
CBLOCK
e_offset
=
off_e
*
E_FBLOCK
block_offset
=
n_offset
+
c_offset
# Calculate offsets for the current batch, head, and block
q_offset
=
off_bh
*
n
*
d
+
(
n_offset
+
c_offset
)
*
d
o_offset
=
off_bh
*
n
*
e
+
(
n_offset
+
c_offset
)
*
e
+
e_offset
kv_offset
=
off_bh
*
NUM_BLOCK
*
d
*
e
+
off_n
*
d
*
e
+
e_offset
# Calculate pointers to the query, output, and key-value tensors
Q_block_ptr
=
(
Q
+
q_offset
+
tl
.
arange
(
0
,
CBLOCK
)[:,
None
]
*
d
+
tl
.
arange
(
0
,
d
)[
None
,
:])
O_block_ptr
=
(
Out
+
o_offset
+
tl
.
arange
(
0
,
CBLOCK
)[:,
None
]
*
e
+
tl
.
arange
(
0
,
E_FBLOCK
)[
None
,
:])
KV_block_ptr
=
(
KV
+
kv_offset
+
tl
.
arange
(
0
,
d
)[:,
None
]
*
e
+
tl
.
arange
(
0
,
E_FBLOCK
)[
None
,
:])
# Load the decay rate for the current head
S_block_ptr
=
S
+
off_h
s
=
tl
.
load
(
S_block_ptr
)
c_array
=
tl
.
arange
(
0
,
CBLOCK
)
# Load the key-value outer product for the current block
kv
=
tl
.
load
(
KV_block_ptr
).
to
(
tl
.
float32
)
q_index
=
block_offset
+
tl
.
arange
(
0
,
CBLOCK
)
# Load query values
q
=
tl
.
load
(
Q_block_ptr
,
mask
=
q_index
[:,
None
]
<
n
,
other
=
0.
).
to
(
tl
.
float32
)
# Compute decay factors for the current sub-block
q_decay
=
tl
.
exp
(
-
s
.
to
(
tl
.
float32
)
*
(
off_c
*
CBLOCK
+
c_array
[:,
None
]))
# Compute non-diagonal attention output
qkv_none_diag
=
tl
.
dot
(
q
,
kv
)
*
q_decay
# Load diagonal attention output (computed by _fwd_diag_kernel)
qkv_diag
=
tl
.
load
(
O_block_ptr
,
mask
=
q_index
[:,
None
]
<
n
,
other
=
0.
).
to
(
tl
.
float32
)
# Combine diagonal and non-diagonal attention outputs
qkv
=
qkv_diag
+
qkv_none_diag
# Store the result
tl
.
store
(
O_block_ptr
,
qkv
.
to
(
O_block_ptr
.
dtype
.
element_ty
),
mask
=
q_index
[:,
None
]
<
n
)
class
_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
s
,
kv_history
):
# Forward pass of the lightning attention algorithm
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
v
=
v
.
contiguous
()
s
=
s
.
contiguous
()
# Check CUDA compute capability
capability
=
torch
.
cuda
.
get_device_capability
()
if
capability
[
0
]
<
8
:
raise
RuntimeError
(
"Flash attention currently only supported"
,
"for compute capability >= 80"
)
# Get input dimensions
b
,
h
,
n
,
d
=
q
.
shape
e
=
v
.
shape
[
-
1
]
# Initialize output tensor
o
=
torch
.
empty
((
b
,
h
,
n
,
e
),
dtype
=
q
.
dtype
,
device
=
q
.
device
)
# Set block sizes
BLOCK
=
256
NUM_BLOCK
=
triton
.
cdiv
(
n
,
BLOCK
)
CBLOCK
=
32
NUM_CBLOCK
=
BLOCK
//
CBLOCK
assert
BLOCK
%
CBLOCK
==
0
,
"BLOCK must be a multiple of CBLOCK"
# Compute decay factors for keys
array
=
torch
.
arange
(
0
,
BLOCK
,
device
=
q
.
device
)
+
1
k_decay
=
torch
.
exp
(
-
s
*
(
BLOCK
-
array
.
reshape
(
1
,
-
1
)))
# Step 1: Compute diagonal blocks of attention
grid
=
(
b
*
h
*
NUM_BLOCK
,
NUM_CBLOCK
)
_fwd_diag_kernel
[
grid
](
q
,
k
,
v
,
o
,
s
,
b
,
h
,
n
,
d
,
e
,
BLOCK
=
BLOCK
,
NUM_BLOCK
=
NUM_BLOCK
,
CBLOCK
=
CBLOCK
)
# Set feature block sizes
NUM_FBLOCK
=
1
D_FBLOCK
=
d
//
NUM_FBLOCK
assert
d
%
NUM_FBLOCK
==
0
E_FBLOCK
=
e
//
NUM_FBLOCK
assert
e
%
NUM_FBLOCK
==
0
CBLOCK
=
64
NUM_CBLOCK
=
BLOCK
//
CBLOCK
assert
BLOCK
%
CBLOCK
==
0
,
"BLOCK must be a multiple of CBLOCK"
# Step 2: Compute key-value outer products for each block in parallel
kv
=
torch
.
empty
((
b
,
h
,
NUM_BLOCK
,
d
,
e
),
dtype
=
torch
.
float32
,
device
=
q
.
device
)
grid
=
(
b
*
h
,
NUM_BLOCK
)
_fwd_kv_parallel
[
grid
](
k
,
v
,
k_decay
,
kv
,
b
,
h
,
n
,
d
,
e
,
BLOCK
=
BLOCK
,
NUM_BLOCK
=
NUM_BLOCK
,
D_FBLOCK
=
D_FBLOCK
,
E_FBLOCK
=
E_FBLOCK
,
NUM_FBLOCK
=
NUM_FBLOCK
,
CBLOCK
=
CBLOCK
,
NUM_CBLOCK
=
NUM_CBLOCK
,
)
# Step 3: Reduce key-value outer products
# across blocks and update KV history
grid
=
(
b
*
h
,
NUM_FBLOCK
)
_fwd_kv_reduce
[
grid
](
s
,
kv
,
kv_history
,
b
,
h
,
n
,
d
,
e
,
BLOCK
=
BLOCK
,
NUM_BLOCK
=
NUM_BLOCK
,
D_FBLOCK
=
D_FBLOCK
,
E_FBLOCK
=
E_FBLOCK
)
# Step 4: Compute non-diagonal blocks of attention
grid
=
(
b
*
h
,
NUM_BLOCK
*
NUM_CBLOCK
)
_fwd_none_diag_kernel
[
grid
](
q
,
o
,
s
,
kv
,
b
,
h
,
n
,
d
,
e
,
BLOCK
=
BLOCK
,
NUM_BLOCK
=
NUM_BLOCK
,
E_FBLOCK
=
E_FBLOCK
,
CBLOCK
=
CBLOCK
,
NUM_CBLOCK
=
NUM_CBLOCK
,
)
# Save tensors for backward pass
ctx
.
save_for_backward
(
q
,
k
,
v
,
s
,
kv
)
ctx
.
BLOCK
=
BLOCK
return
o
,
torch
.
cat
([
kv
,
kv_history
.
unsqueeze
(
2
)],
dim
=
2
)
# Apply the lightning attention function
lightning_attention_
=
_attention
.
apply
def
lightning_attention
(
q
,
k
,
v
,
ed
,
block_size
=
256
,
kv_history
=
None
):
"""
Apply lightning attention algorithm
to compute attention efficiently.
Args:
q: Query tensor of shape [batch, heads, seq_len, dim]
k: Key tensor of shape [batch, heads, seq_len, dim]
v: Value tensor of shape [batch, heads, seq_len, dim_v]
ed: Decay rate tensor of shape [heads]
block_size: Size of blocks for block-sparse attention
kv_history: Optional key-value history from previous computations
Returns:
output: Attention output
kv: Updated key-value history
"""
d
=
q
.
shape
[
-
1
]
e
=
v
.
shape
[
-
1
]
if
ed
.
dim
()
==
1
:
ed
=
ed
.
view
(
1
,
-
1
,
1
,
1
)
# Split the computation into chunks for better parallelism
m
=
128
if
d
>=
128
else
64
assert
d
%
m
==
0
,
f
"Dimension d (
{
d
}
) must be divisible by m (
{
m
}
)"
arr
=
[
m
*
i
for
i
in
range
(
d
//
m
+
1
)]
if
arr
[
-
1
]
!=
d
:
arr
.
append
(
d
)
n
=
len
(
arr
)
output
=
0
# Initialize or clone key-value history
if
kv_history
is
None
:
kv_history
=
torch
.
zeros
((
q
.
shape
[
0
],
q
.
shape
[
1
],
d
,
e
),
dtype
=
torch
.
float32
,
device
=
q
.
device
)
else
:
kv_history
=
kv_history
.
clone
().
contiguous
()
# Process each chunk and accumulate results
for
i
in
range
(
n
-
1
):
s
=
arr
[
i
]
e
=
arr
[
i
+
1
]
q1
=
q
[...,
s
:
e
]
k1
=
k
[...,
s
:
e
]
o
,
kv
=
lightning_attention_
(
q1
,
k1
,
v
,
ed
,
kv_history
)
output
=
output
+
o
return
output
,
kv
@
triton
.
jit
def
_linear_attn_decode_kernel
(
q_ptr
,
k_ptr
,
v_ptr
,
kv_cache_ptr
,
slope_rate
,
slot_idx
,
output_ptr
,
D
:
tl
.
constexpr
,
qkv_b_stride
,
qkv_h_stride
,
cache_b_stride
,
cache_h_stride
,
cache_d0_stride
,
cache_d1_stride
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
"""
Kernel for linear attention decoding with KV cache.
This kernel computes attention for a single token using the KV cache.
"""
pid_b
=
tl
.
program_id
(
0
)
# batch index
pid_h
=
tl
.
program_id
(
1
)
# head index
pid_d
=
tl
.
program_id
(
2
)
# dimension block index
# Load slot index for the current batch
slot_id
=
tl
.
load
(
slot_idx
+
pid_b
)
# Skip if slot_id is -1 (padding)
if
slot_id
==
-
1
:
return
batch_id
=
pid_b
head_id
=
pid_h
# Load decay rate for the current head
ratio
=
tl
.
load
(
slope_rate
+
pid_h
)
# Calculate offsets for dimensions
qk_d_offsets
=
tl
.
arange
(
0
,
D
)
v_d_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
+
pid_d
*
BLOCK_SIZE
cache_d_offsets
=
qk_d_offsets
[:,
None
]
*
cache_d0_stride
+
v_d_offsets
[
None
,
:]
*
cache_d1_stride
# Calculate offsets for the current batch and head
q_offset
=
batch_id
*
qkv_b_stride
+
head_id
*
qkv_h_stride
k_offset
=
batch_id
*
qkv_b_stride
+
head_id
*
qkv_h_stride
v_offset
=
batch_id
*
qkv_b_stride
+
head_id
*
qkv_h_stride
cache_offset
=
slot_id
*
cache_b_stride
+
head_id
*
cache_h_stride
# Create masks for loading tensors
qk_mask
=
qk_d_offsets
<
D
v_mask
=
v_d_offsets
<
D
# Load query, key, and value tensors
q
=
tl
.
load
(
q_ptr
+
q_offset
+
qk_d_offsets
,
mask
=
qk_mask
,
other
=
0.0
)
k
=
tl
.
load
(
k_ptr
+
k_offset
+
qk_d_offsets
,
mask
=
qk_mask
,
other
=
0.0
)
v
=
tl
.
load
(
v_ptr
+
v_offset
+
v_d_offsets
,
mask
=
v_mask
,
other
=
0.0
)
# Compute key-value outer product
kv_outer
=
k
[:,
None
]
*
v
[
None
,
:]
kv_mask
=
qk_mask
[:,
None
]
&
v_mask
[
None
,
:]
# Apply decay to previous KV cache
ratio
=
tl
.
exp
(
-
ratio
)
kv_ptr
=
kv_cache_ptr
+
cache_offset
+
cache_d_offsets
kv_cache_old
=
tl
.
load
(
kv_ptr
,
mask
=
kv_mask
,
other
=
0.0
)
kv_outer
=
kv_outer
+
ratio
*
kv_cache_old
# Compute attention output
output
=
q
[:,
None
].
to
(
tl
.
float32
)
*
kv_outer
output
=
tl
.
sum
(
output
,
axis
=
0
)
# Update KV cache and store output
tl
.
store
(
kv_ptr
,
kv_outer
,
mask
=
kv_mask
)
tl
.
store
(
output_ptr
+
q_offset
+
v_d_offsets
,
output
,
mask
=
v_mask
)
def
linear_decode_forward_triton
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
kv_caches
:
torch
.
Tensor
,
slope_rate
:
torch
.
Tensor
,
slot_idx
:
torch
.
Tensor
,
BLOCK_SIZE
:
int
=
32
,
)
->
torch
.
Tensor
:
"""
Perform linear attention decoding using Triton kernels.
Args:
q: Query tensor of shape [B, H, 1, D]
k: Key tensor of shape [B, H, 1, D]
v: Value tensor of shape [B, H, 1, D]
kv_caches: Key-value cache tensor
slope_rate: Decay rate tensor
slot_idx: Slot indices for batches
BLOCK_SIZE: Size of blocks for processing
Returns:
output: Attention output tensor
"""
B
,
H
,
_
,
D
=
q
.
shape
assert
k
.
shape
==
(
B
,
H
,
1
,
D
)
assert
v
.
shape
==
(
B
,
H
,
1
,
D
)
# Initialize output tensor
output
=
torch
.
empty_like
(
q
)
# Set grid dimensions for the kernel
grid
=
(
B
,
H
,
D
//
BLOCK_SIZE
)
# Calculate strides for tensors
qkv_b_stride
=
q
.
stride
(
0
)
qkv_h_stride
=
q
.
stride
(
1
)
cache_b_stride
=
kv_caches
.
stride
(
0
)
cache_h_stride
=
kv_caches
.
stride
(
1
)
cache_d0_stride
=
kv_caches
.
stride
(
2
)
cache_d1_stride
=
kv_caches
.
stride
(
3
)
# Launch the kernel
_linear_attn_decode_kernel
[
grid
](
q
,
k
,
v
,
kv_caches
,
slope_rate
,
slot_idx
,
output
,
D
,
qkv_b_stride
,
qkv_h_stride
,
cache_b_stride
,
cache_h_stride
,
cache_d0_stride
,
cache_d1_stride
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
# Reshape output and return
output
=
rearrange
(
output
,
"b h n d -> b n (h d)"
)
return
output
.
squeeze
(
1
).
contiguous
()
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
fcfc474d
...
...
@@ -469,6 +469,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
...
...
@@ -476,6 +477,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"fused Marlin MoE method."
)
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for"
"fused Marlin MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
fcfc474d
...
...
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.utils
import
direct_register_custom_op
class
BitsAndBytesConfig
(
QuantizationConfig
):
...
...
@@ -321,9 +322,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# only load the bitsandbytes module when needed
from
bitsandbytes
import
matmul_4bit
original_type
=
x
.
dtype
original_shape
=
x
.
shape
reshape_after_matmul
=
False
...
...
@@ -343,19 +341,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out_dim_1
,
dtype
=
torch
.
bfloat16
,
device
=
x
.
device
)
current_index
=
0
for
i
in
range
(
len
(
quant_states
)):
output_size
=
quant_states
[
i
].
shape
[
0
]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out
[:,
current_index
:
current_index
+
output_size
]
=
matmul_4bit
(
bf_x
,
qweight
[
offsets
[
i
]:
offsets
[
i
+
1
]].
t
(),
quant_states
[
i
])
current_index
+=
output_size
apply_bnb_4bit
(
bf_x
,
qweight
,
offsets
,
out
)
out
=
out
.
to
(
original_type
)
if
reshape_after_matmul
:
...
...
@@ -365,3 +351,46 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out
+=
bias
return
out
def
_apply_bnb_4bit
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
offsets
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
)
->
None
:
# only load the bitsandbytes module when needed
from
bitsandbytes
import
matmul_4bit
quant_states
=
weight
.
bnb_quant_state
current_index
=
0
for
i
in
range
(
len
(
quant_states
)):
output_size
=
quant_states
[
i
].
shape
[
0
]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out
[:,
current_index
:
current_index
+
output_size
]
=
matmul_4bit
(
x
,
weight
[
offsets
[
i
]:
offsets
[
i
+
1
]].
t
(),
quant_states
[
i
])
current_index
+=
output_size
def
_apply_bnb_4bit_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
offsets
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
)
->
None
:
return
try
:
direct_register_custom_op
(
op_name
=
"apply_bnb_4bit"
,
op_func
=
_apply_bnb_4bit
,
mutates_args
=
[
"out"
],
fake_impl
=
_apply_bnb_4bit_fake
,
)
apply_bnb_4bit
=
torch
.
ops
.
vllm
.
apply_bnb_4bit
except
AttributeError
as
error
:
raise
error
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
fcfc474d
...
...
@@ -97,7 +97,8 @@ class CompressedTensorsConfig(QuantizationConfig):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
return
CompressedTensorsMoEMethod
.
get_moe_method
(
self
)
return
CompressedTensorsMoEMethod
.
get_moe_method
(
self
,
layer
.
activation
,
layer
.
expert_map
)
return
None
@
classmethod
...
...
@@ -192,17 +193,26 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_check_scheme_supported
(
self
,
min_capability
:
int
,
error
:
bool
=
True
)
->
bool
:
error
:
bool
=
True
,
match_exact
:
bool
=
False
)
->
bool
:
capability_tuple
=
current_platform
.
get_device_capability
()
if
capability_tuple
is
not
None
:
capability
=
capability_tuple
.
to_int
()
supported
=
capability
>=
min_capability
if
error
and
not
supported
:
raise
RuntimeError
(
"Quantization scheme is not supported for "
,
f
"the current GPU. Min capability:
{
min_capability
}
. "
,
f
"Current capability:
{
capability
}
."
)
if
match_exact
:
supported
=
capability
==
min_capability
if
error
and
not
supported
:
raise
RuntimeError
(
"Quantization scheme is not supported for "
,
"the current GPU. Required capability: "
,
f
"
{
min_capability
}
. Current capability:
{
capability
}
."
)
else
:
supported
=
capability
>=
min_capability
if
error
and
not
supported
:
raise
RuntimeError
(
"Quantization scheme is not supported for "
,
f
"the current GPU. Min capability:
{
min_capability
}
. "
,
f
"Current capability:
{
capability
}
."
)
return
supported
else
:
return
False
...
...
@@ -263,6 +273,11 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
return
is_symmetric_activation
and
is_per_tensor_activation
def
_is_fp8_w8a8_sm90
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
return
(
self
.
_check_scheme_supported
(
90
,
error
=
False
,
match_exact
=
True
)
and
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
))
def
_is_fp8_w8a16
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
# Confirm weights quantized.
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
fcfc474d
...
...
@@ -31,6 +31,7 @@ class GPTQMarlinState(Enum):
__all__
=
[
"CompressedTensorsMoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsW8A8Fp8MoECutlassMethod"
,
"CompressedTensorsWNA16MoEMethod"
]
...
...
@@ -39,7 +40,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@
staticmethod
def
get_moe_method
(
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
quant_config
:
"CompressedTensorsConfig"
,
# type: ignore # noqa E501
activation
:
str
,
expert_map
:
Optional
[
torch
.
Tensor
],
)
->
"CompressedTensorsMoEMethod"
:
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
...
...
@@ -49,6 +52,9 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
return
CompressedTensorsWNA16MoEMethod
(
quant_config
)
elif
(
quant_config
.
_is_fp8_w8a8_sm90
(
weight_quant
,
input_quant
)
and
activation
==
"silu"
and
expert_map
is
None
):
return
CompressedTensorsW8A8Fp8MoECutlassMethod
(
quant_config
)
elif
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
else
:
...
...
@@ -218,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -234,20 +241,245 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
class
CompressedTensorsW8A8Fp8MoECutlassMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
):
self
.
quant_config
=
quant_config
self
.
weight_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
self
.
input_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
per_tensor
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
per_channel
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
)
if
not
(
per_tensor
or
per_channel
):
raise
ValueError
(
"For FP8 Fused MoE layers, we require per tensor "
"or channelwise, dynamic per token quantization. Found "
f
"
{
self
.
weight_quant
}
,
{
self
.
input_quant
}
"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
dynamic
if
self
.
static_input_scales
and
per_channel
:
raise
ValueError
(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
:
# Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
elif
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
static_input_scales
:
w13_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
set_weight_attrs
(
w13_input_scale
,
extra_weight_attrs
)
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
set_weight_attrs
(
w2_input_scale
,
extra_weight_attrs
)
else
:
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
device
=
w13_weight
.
device
# TODO strides can be shared across multiple layers
self
.
ab_strides1
=
torch
.
full
((
num_experts
,
),
hidden_size
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
c_strides1
=
torch
.
full
((
num_experts
,
),
2
*
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
ab_strides2
=
torch
.
full
((
num_experts
,
),
intermediate_size_per_partition
,
device
=
device
,
dtype
=
torch
.
int64
)
self
.
c_strides2
=
torch
.
full
((
num_experts
,
),
hidden_size
,
device
=
device
,
dtype
=
torch
.
int64
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
static_input_scales
:
assert
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
if
(
layer
.
w13_input_scale
is
None
or
layer
.
w2_input_scale
is
None
):
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
(
not
all_close_1d
(
layer
.
w13_input_scale
)
or
not
all_close_1d
(
layer
.
w2_input_scale
)):
logger
.
warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# for w13 per expert. Use max then dequant and requant each expert.
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
:
assert
layer
.
w13_weight_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_weight_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
local_num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_weight_scale
[
expert_id
][
shard_id
])
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
assert
global_num_experts
==
layer
.
w13_weight
.
shape
[
0
]
assert
expert_map
is
None
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
from
vllm.model_executor.layers.fused_moe
import
cutlass_moe_fp8
return
cutlass_moe_fp8
(
x
,
layer
.
w13_weight
.
transpose
(
1
,
2
),
layer
.
w2_weight
.
transpose
(
1
,
2
),
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
topk_weights
,
topk_ids
,
self
.
ab_strides1
,
self
.
c_strides1
,
self
.
ab_strides2
,
self
.
c_strides2
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
out_dtype
=
x
.
dtype
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
class
CompressedTensorsWNA16MoEMethod
(
CompressedTensorsMoEMethod
):
...
...
@@ -551,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
...
...
@@ -558,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"fused Marlin MoE method."
)
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for "
"fused Marlin MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
fcfc474d
...
...
@@ -23,6 +23,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
self
.
strategy
=
strategy
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
fp8_linear
=
Fp8LinearOp
(
use_per_token_if_dynamic
=
True
)
...
...
@@ -143,5 +144,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
vllm/model_executor/layers/quantization/experts_int8.py
View file @
fcfc474d
...
...
@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -129,18 +130,20 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
use_int8_w8a16
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
use_int8_w8a16
=
True
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
)
@
staticmethod
def
quantizing_weight_loader
(
layer
,
weight_loader
):
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
fcfc474d
...
...
@@ -73,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
FBGEMMFp8Config
):
self
.
quant_config
=
quant_config
self
.
fp8_linear
=
Fp8LinearOp
(
use_per_token_if_dynamic
=
True
)
self
.
out_dtype
=
torch
.
get_default_dtype
()
def
create_weights
(
self
,
...
...
@@ -161,6 +162,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
None
,
input_scale_ub
=
layer
.
input_scale_ub
,
bias
=
bias
)
vllm/model_executor/layers/quantization/fp8.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
importlib.util
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
...
...
@@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger
=
init_logger
(
__name__
)
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
def
_is_col_major
(
x
:
torch
.
Tensor
)
->
bool
:
assert
x
.
dim
()
==
3
b
,
m
,
n
=
x
.
shape
return
x
.
stride
(
0
)
==
m
*
n
and
x
.
stride
(
1
)
==
1
and
x
.
stride
(
2
)
==
m
class
Fp8Config
(
QuantizationConfig
):
"""Config class for FP8."""
...
...
@@ -116,6 +125,21 @@ class Fp8Config(QuantizationConfig):
return
Fp8KVCacheMethod
(
self
)
return
None
def
get_cache_scale
(
self
,
name
:
str
)
->
Optional
[
str
]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if
name
.
endswith
(
".output_scale"
)
and
".k_proj"
in
name
:
return
name
.
replace
(
".k_proj.output_scale"
,
".attn.k_scale"
)
if
name
.
endswith
(
".output_scale"
)
and
".v_proj"
in
name
:
return
name
.
replace
(
".v_proj.output_scale"
,
".attn.v_scale"
)
return
None
class
Fp8LinearMethod
(
LinearMethodBase
):
"""Linear method for FP8.
...
...
@@ -138,6 +162,7 @@ class Fp8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
()
self
.
out_dtype
=
torch
.
get_default_dtype
()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
...
...
@@ -386,6 +411,7 @@ class Fp8LinearMethod(LinearMethodBase):
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
...
...
@@ -407,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
.
quant_config
=
quant_config
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
# Check for DeepGemm support.
self
.
allow_deep_gemm
=
False
if
envs
.
VLLM_USE_DEEP_GEMM
:
if
not
has_deep_gemm
:
logger
.
warning_once
(
"Failed to import DeepGemm kernels."
)
elif
(
current_platform
.
is_cuda
()
and
current_platform
.
has_device_capability
(
90
)):
logger
.
info_once
(
"Using DeepGemm kernels for Fp8MoEMethod."
)
self
.
allow_deep_gemm
=
True
else
:
logger
.
warning_once
(
"DeepGemm not supported on the current platform."
)
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
...
...
@@ -529,6 +568,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Lazy import to avoid importing triton too early.
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
expand_weights
,
is_rocm_aiter_block_scaled_moe_enabled
,
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
# TODO (rob): refactor block quant into separate class.
if
self
.
block_quant
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
...
...
@@ -554,6 +598,28 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w2_weight
=
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale_inv
=
Parameter
(
w2_weight_scale_inv
,
requires_grad
=
False
)
if
is_rocm_aiter_block_scaled_moe_enabled
():
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if
self
.
allow_deep_gemm
:
# Lazy import to avoid CUDA initialization problems.
import
deep_gemm
as
dg
if
_is_col_major
(
layer
.
w13_weight_scale_inv
):
layer
.
w13_weight_scale_inv
=
\
dg
.
get_col_major_tma_aligned_tensor
(
layer
.
w13_weight_scale_inv
).
contiguous
()
if
_is_col_major
(
layer
.
w2_weight_scale_inv
):
layer
.
w2_weight_scale_inv
=
\
dg
.
get_col_major_tma_aligned_tensor
(
layer
.
w2_weight_scale_inv
).
contiguous
()
return
# If checkpoint is fp16, quantize in place.
...
...
@@ -581,6 +647,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
if
is_rocm_aiter_moe_enabled
():
# reshaping weights is required for aiter moe kernel.
w13_scales
,
w2_scales
=
expand_weights
(
layer
.
w13_weight_scale
.
data
,
layer
.
w2_weight_scale
.
data
,
expansion_dims
=
[
layer
.
w13_weight
.
shape
[
1
],
layer
.
w2_weight
.
shape
[
1
]
])
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
w13_scales
.
contiguous
(),
requires_grad
=
False
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_scales
.
contiguous
(),
requires_grad
=
False
)
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
return
# If checkpoint is fp8, we need to handle that the
...
...
@@ -648,6 +734,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
if
is_rocm_aiter_moe_enabled
():
# reshaping weights is required for aiter moe kernel.
expansion_dims
=
[
layer
.
w13_weight
.
shape
[
1
],
layer
.
w2_weight
.
shape
[
1
]
]
max_w13_scales
,
w2_scales
=
expand_weights
(
max_w13_scales
,
layer
.
w2_weight_scale
.
data
,
expansion_dims
=
expansion_dims
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_scales
.
contiguous
(),
requires_grad
=
False
)
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
shuffled_w13
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
shuffled_w2
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
return
...
...
@@ -667,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -694,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
activation
=
activation
,
use_fp8_w8a8
=
True
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
w1_scale
=
(
layer
.
w13_weight_scale_inv
if
self
.
block_quant
else
layer
.
w13_weight_scale
),
...
...
@@ -702,6 +810,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
block_shape
=
self
.
quant_config
.
weight_block_size
,
allow_deep_gemm
=
self
.
allow_deep_gemm
,
)
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
fcfc474d
...
...
@@ -117,7 +117,7 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
elif
qweight_type
in
DEQUANT_TYPES
:
block_size
,
type_size
=
gguf
.
GGML_QUANT_SIZES
[
qweight_type
]
shape
=
(
qweight
.
shape
[
0
],
qweight
.
shape
[
1
]
//
type_size
*
block_size
)
weight
=
ops
.
ggml_dequantize
(
qweight
,
qweight_type
,
*
shape
)
weight
=
ops
.
ggml_dequantize
(
qweight
,
qweight_type
,
*
shape
,
x
.
dtype
)
y
=
x
@
weight
.
T
else
:
# Raise an error if the quantization type is not supported.
...
...
@@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
):
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
"Apply router weight on input is not supported for"
"fused GGUF MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -377,7 +383,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
x_flat
=
x
.
flatten
()
quant
=
torch
.
index_select
(
qweight
,
dim
=
0
,
index
=
x_flat
)
dequant
=
ops
.
ggml_dequantize
(
quant
,
qweight_type
,
hidden_size
,
x_flat
.
shape
[
0
]
).
to
(
self
.
params_dtype
)
x_flat
.
shape
[
0
]
,
self
.
params_dtype
)
return
dequant
.
view
(
*
x
.
shape
,
hidden_size
)
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment