Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bd363067
Commit
bd363067
authored
Jun 05, 2025
by
lizhigong
Browse files
Merge branch 'v0.8.5.post1-dev' into v0.8.5-zero_overhead
parents
87ef4618
d36deb1a
Changes
106
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
741 additions
and
275 deletions
+741
-275
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+28
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+153
-79
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+6
-9
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+3
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+8
-2
vllm/envs.py
vllm/envs.py
+20
-10
vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=K100_AI_nn.json
.../fused_moe/configs/E=64,N=640,device_name=K100_AI_nn.json
+164
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=K100_AI_nn.json
.../fused_moe/configs/E=8,N=2048,device_name=K100_AI_nn.json
+60
-60
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=K100_AI_nn.json
.../fused_moe/configs/E=8,N=3584,device_name=K100_AI_nn.json
+164
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+5
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+51
-15
vllm/model_executor/layers/quantization/blockwise_int8.py
vllm/model_executor/layers/quantization/blockwise_int8.py
+5
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+28
-3
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+5
-1
vllm/model_executor/layers/quantization/utils/int8_utils.py
vllm/model_executor/layers/quantization/utils/int8_utils.py
+28
-36
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+2
-2
vllm/model_executor/layers/quantization/w8a8_int8.py
vllm/model_executor/layers/quantization/w8a8_int8.py
+5
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+3
-2
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+1
-50
No files found.
tests/worker/test_model_runner.py
View file @
bd363067
...
...
@@ -27,7 +27,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
def
test_deepseek_mla_attn_backend_module
():
model_runner
=
_create_model_runner
(
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
,
os
.
path
.
join
(
models_path_prefix
,
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
)
,
trust_remote_code
=
True
,
enable_chunked_prefill
=
False
,
)
...
...
@@ -383,4 +383,4 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert
attr_expected
[
1
]
==
attr_actual
[
1
]
for
attr_expected
,
attr_actual
in
zip
(
vars
(
attn_metadata
.
decode_metadata
),
vars
(
decode_meta_actual
)):
assert
attr_expected
[
1
]
==
attr_actual
[
1
]
assert
attr_expected
[
1
]
==
attr_actual
[
1
]
\ No newline at end of file
vllm/_custom_ops.py
View file @
bd363067
...
...
@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache(
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale)
# return out
def
moe_fused_gate
(
input_tensor
,
bias
,
num_expert_group
,
topk_group
,
topk
,
n_share_experts_fusion
=
0
,
routed_scaling_factor
=
0
,
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# as the group weight to select exerpt groups and then select topk experts within the selected groups
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
# routed_scaling_factor: if > 0, the last expert will be scaled by this factor
return
torch
.
ops
.
_moe_C
.
moe_fused_gate
(
input_tensor
,
bias
,
num_expert_group
,
topk_group
,
topk
,
n_share_experts_fusion
,
routed_scaling_factor
,
)
vllm/attention/backends/flash_attn.py
View file @
bd363067
...
...
@@ -27,8 +27,13 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_rocm
():
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
else
:
from
flash_attn
import
(
flash_attn_varlen_func
,
vllm_flash_attn_varlen_func
,
flash_attn_with_kvcache
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
...
@@ -807,25 +812,41 @@ class FlashAttentionImpl(AttentionImpl):
(
num_kv_tokens
,
num_kv_heads
,
head_size
))
descale_shape
=
(
q_seq_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
q_seq_start_loc
,
cu_seqlens_k
=
k_seq_start_loc
,
max_seqlen_q
=
q_seq_len
,
max_seqlen_k
=
k_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
_get_causal_option
(
attn_type
),
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
q_seq_start_loc
,
cu_seqlens_k
=
k_seq_start_loc
,
max_seqlen_q
=
q_seq_len
,
max_seqlen_k
=
k_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
_get_causal_option
(
attn_type
),
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
prefill_output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
q_seq_start_loc
,
cu_seqlens_k
=
k_seq_start_loc
,
max_seqlen_q
=
q_seq_len
,
max_seqlen_k
=
k_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
_get_causal_option
(
attn_type
),
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
)
else
:
# prefix-enabled attention
assert
attn_type
==
AttentionType
.
DECODER
,
(
...
...
@@ -835,26 +856,48 @@ class FlashAttentionImpl(AttentionImpl):
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
descale_shape
=
(
prefill_meta
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
seqused_k
=
prefill_meta
.
seq_lens_tensor
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
seqused_k
=
prefill_meta
.
seq_lens_tensor
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
vllm_flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
seqused_k
=
prefill_meta
.
seq_lens_tensor
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
...
...
@@ -870,26 +913,43 @@ class FlashAttentionImpl(AttentionImpl):
assert
decode_meta
.
query_start_loc
is
not
None
descale_shape
=
(
decode_meta
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
q
=
decode_query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
seqused_k
=
decode_meta
.
seq_lens_tensor
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
out
=
decode_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
not
current_platform
.
is_rocm
():
flash_attn_varlen_func
(
q
=
decode_query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
seqused_k
=
decode_meta
.
seq_lens_tensor
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
out
=
decode_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
decode_output
=
flash_attn_varlen_func
(
q
=
decode_query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
seqused_k
=
decode_meta
.
seq_lens_tensor
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
)
else
:
# Use flash_attn_with_kvcache for normal decoding.
(
...
...
@@ -898,23 +958,37 @@ class FlashAttentionImpl(AttentionImpl):
block_tables_arg
,
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
descale_shape
=
(
seq_lens_arg
.
shape
[
0
],
key_cache
.
shape
[
-
2
])
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
block_table
=
block_tables_arg
,
cache_seqlens
=
seq_lens_arg
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
decode_output
.
unsqueeze
(
1
),
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
not
current_platform
.
is_rocm
():
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
block_table
=
block_tables_arg
,
cache_seqlens
=
seq_lens_arg
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
decode_output
.
unsqueeze
(
1
),
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
decode_output
=
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
block_table
=
block_tables_arg
,
cache_seqlens
=
seq_lens_arg
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
)
return
output
...
...
@@ -998,4 +1072,4 @@ def _get_causal_option(attn_type: str) -> bool:
"""
return
not
(
attn_type
==
AttentionType
.
ENCODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
or
attn_type
==
AttentionType
.
ENCODER_DECODER
)
or
attn_type
==
AttentionType
.
ENCODER_DECODER
)
\ No newline at end of file
vllm/attention/backends/rocm_flash_attn.py
View file @
bd363067
...
...
@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
use_rocm_custom_paged_attention
from
vllm.utils
import
SUPPORT_TC
from
vllm.utils
import
SUPPORT_TC
,
gpuname
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
...
...
@@ -578,7 +578,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try
:
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
fa_attn_func
=
flash_attn_varlen_func
if
not
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
if
not
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
and
gpuname
.
startswith
(
'K100_AI'
)
:
from
flash_attn
import
vllm_flash_attn_varlen_func
self
.
fa_prefix_attn_func
=
vllm_flash_attn_varlen_func
...
...
@@ -852,9 +852,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
# prefix-enabled attention -
# not applicable for encoder-only models
# if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
# self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
if
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
if
envs
.
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
or
gpuname
.
startswith
(
'BW'
):
version_key
=
triton_key
()
if
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
:
output
[:
num_prefill_tokens
]
=
paged_attn
.
forward_prefix
(
...
...
@@ -922,10 +920,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
block_size
=
value_cache
.
shape
[
3
]
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
# use_custom = use_rocm_custom_paged_attention(
# decode_query.dtype, head_size, block_size, gqa_ratio,
# decode_meta.max_decode_seq_len, self.sliding_window)
use_custom
=
False
use_custom
=
use_rocm_custom_paged_attention
(
decode_query
.
dtype
,
head_size
,
block_size
,
gqa_ratio
,
decode_meta
.
max_decode_seq_len
,
self
.
sliding_window
)
if
use_custom
:
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
...
...
vllm/engine/arg_utils.py
View file @
bd363067
# SPDX-License-Identifier: Apache-2.0
# yapf: disable
import
os
import
argparse
import
dataclasses
import
json
...
...
@@ -35,12 +36,12 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
GiB_bytes
,
is_in_ray_actor
# yapf: enable
logger
=
init_logger
(
__name__
)
ALLOWED_DETAILED_TRACE_MODULES
=
[
"model"
,
"worker"
,
"all"
]
models_path_prefix
=
os
.
getenv
(
'VLLM_OPTEST_MODELS_PATH'
)
or
os
.
getenv
(
"OPTEST_MODELS_PATH"
)
# object is used to allow for special typing forms
T
=
TypeVar
(
"T"
)
...
...
@@ -203,7 +204,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
@
dataclass
class
EngineArgs
:
"""Arguments for vLLM engine."""
model
:
str
=
'facebook/opt-125m'
model
:
str
=
os
.
path
.
join
(
models_path_prefix
,
'facebook/opt-125m'
)
if
models_path_prefix
is
not
None
else
'facebook/opt-125m'
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
tokenizer
:
Optional
[
str
]
=
None
hf_config_path
:
Optional
[
str
]
=
None
...
...
vllm/engine/llm_engine.py
View file @
bd363067
...
...
@@ -2062,8 +2062,14 @@ class LLMEngine:
prompt_type
:
Literal
[
"encoder"
,
"decoder"
],
):
model_config
=
self
.
model_config
tokenizer
=
(
None
if
self
.
tokenizer
is
None
else
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
))
if
self
.
tokenizer
is
None
:
tokenizer
=
None
elif
self
.
model_config
.
tokenizer_mode
!=
"cpm"
:
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
else
:
tokenizer
=
self
.
tokenizer
# tokenizer = (None if self.tokenizer is None else
# self.tokenizer.get_lora_tokenizer(lora_request))
prompt_ids
=
prompt_inputs
[
"prompt_token_ids"
]
if
not
prompt_ids
:
...
...
vllm/envs.py
View file @
bd363067
...
...
@@ -124,11 +124,12 @@ if TYPE_CHECKING:
VLLM_PCIE_USE_CUSTOM_ALLREDUCE
:
bool
=
False
VLLM_ENFORCE_EAGER_BS_THRESHOLD
:
Optional
[
int
]
=
None
VLLM_HAS_CONTEXT_DEFAULT
:
bool
=
False
VLLM_FLASH_ATTN_BACKEND
:
bool
=
False
VLLM_ENABLE_TBO
:
bool
=
False
VLLM_TBO_REQ_DELAY_MS
:
int
=
0
VLLM_ZERO_OVERHEAD
:
bool
=
False
VLLM_ENABLE_MOE_FUSED_GATE
:
bool
=
False
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -800,11 +801,24 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_ENFORCE_EAGER_BS_THRESHOLD"
,
"-1"
)),
# 'has_comtext' is a variable in common.py, which is calculated
# by metadata by default. However, it may introduce synchronization
# and affect performance, so it is directly assigned as False.
# If there are any problems during use, use environment variables
# to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_HAS_CONTEXT_DEFAULT"
,
"0"
))),
# If set, vLLM will use FlashAttention Backend for attention computation on rocm
"VLLM_FLASH_ATTN_BACKEND"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_FLASH_ATTN_BACKEND"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# Enable two batch overlap.
"VLLM_ENABLE_TBO"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_TBO"
,
"0"
))),
# set delay on server when only one requet, the purpose is to merge a larger batch.
"VLLM_TBO_REQ_DELAY_MS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_TBO_REQ_DELAY_MS"
,
"0"
)),
...
...
@@ -813,13 +827,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ZERO_OVERHEAD"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ZERO_OVERHEAD"
,
"0"
))),
# 'has_comtext' is a variable in common.py, which is calculated
# by metadata by default. However, it may introduce synchronization
# and affect performance, so it is directly assigned as False.
# If there are any problems during use, use environment variables
# to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_HAS_CONTEXT_DEFAULT"
,
"0"
))),
# If set, vLLM will enable the moe_fused_gate kernel.
"VLLM_ENABLE_MOE_FUSED_GATE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_MOE_FUSED_GATE"
,
"1"
))),
}
# end-env-vars-definition
...
...
vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=K100_AI_nn.json
0 → 100644
View file @
bd363067
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
5
,
"num_ldmatrixes"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
5
,
"num_ldmatrixes"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
5
,
"num_ldmatrixes"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
1
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
1
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
1
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"1536"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=K100_AI_nn.json
View file @
bd363067
...
...
@@ -3,97 +3,97 @@
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
6
,
"num_warps"
:
2
,
"num_stages"
:
4
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
2
,
"num_stages"
:
3
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"GROUP_SIZE_M"
:
1
6
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
1
28
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"BLOCK_SIZE_M"
:
1
6
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"48"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
6
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"96"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"GROUP_SIZE_M"
:
1
6
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
6
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
...
...
@@ -102,14 +102,14 @@
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"GROUP_SIZE_M"
:
1
6
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
...
...
@@ -117,17 +117,17 @@
"num_ldmatrixes"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"1536"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
...
...
@@ -135,8 +135,8 @@
"num_ldmatrixes"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
...
...
@@ -144,20 +144,20 @@
"num_ldmatrixes"
:
1
},
"3072"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
}
...
...
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=K100_AI_nn.json
0 → 100644
View file @
bd363067
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
1
},
"48"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"96"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
,
"num_ldmatrixes"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
1
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
bd363067
...
...
@@ -3,6 +3,7 @@
import
functools
import
json
import
os
import
math
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -1182,6 +1183,10 @@ def fused_topk(
return
topk_weights
,
topk_ids
def
is_power_of_two
(
n
):
return
n
>
0
and
math
.
log2
(
n
).
is_integer
()
# This is used by the Deepseek-V2 and Deepseek-V3 model
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
grouped_topk
(
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
bd363067
...
...
@@ -19,10 +19,13 @@ from vllm.logger import init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
grouped_topk
,
is_power_of_two
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.utils
import
direct_register_custom_op
from
vllm
import
_custom_ops
as
ops
if
current_platform
.
is_cuda_alike
():
from
.fused_moe
import
fused_experts
...
...
@@ -174,6 +177,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
...
...
@@ -191,7 +196,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias
=
e_score_correction_bias
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_nn_moe
=
use_nn_moe
)
use_nn_moe
=
use_nn_moe
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
def
forward_cuda
(
self
,
...
...
@@ -211,6 +218,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
@@ -222,7 +231,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
return
fused_experts
(
hidden_states
=
x
,
...
...
@@ -255,6 +266,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
kwargs
,
):
assert
activation
==
"silu"
,
f
"
{
activation
}
is not supported."
...
...
@@ -290,6 +303,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
assert
not
use_grouped_topk
assert
num_expert_group
is
None
...
...
@@ -436,6 +451,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
):
super
().
__init__
()
...
...
@@ -505,6 +521,7 @@ class FusedMoE(torch.nn.Module):
self
.
scoring_func
=
scoring_func
self
.
e_score_correction_bias
=
e_score_correction_bias
self
.
activation
=
activation
self
.
routed_scaling_factor
=
routed_scaling_factor
if
self
.
scoring_func
!=
"softmax"
and
not
self
.
use_grouped_topk
:
raise
ValueError
(
"Only softmax scoring function is supported for "
...
...
@@ -554,9 +571,16 @@ class FusedMoE(torch.nn.Module):
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
# moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
self
.
use_fused_gate
=
envs
.
VLLM_ENABLE_MOE_FUSED_GATE
\
and
self
.
e_score_correction_bias
is
not
None
\
and
self
.
global_num_experts
//
num_expert_group
<=
32
\
and
is_power_of_two
(
e_score_correction_bias
.
shape
[
0
])
def
_load_per_tensor_weight_scale
(
self
,
shard_id
:
str
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
...
...
@@ -839,23 +863,33 @@ class FusedMoE(torch.nn.Module):
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
):
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
grouped_topk
)
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
):
# DeekSeekv2 uses grouped_top_k
if
use_grouped_topk
:
assert
topk_group
is
not
None
assert
num_expert_group
is
not
None
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
top_k
,
renormalize
=
renormalize
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
if
use_fused_gate
:
topk_weights
,
topk_ids
=
ops
.
moe_fused_gate
(
router_logits
,
e_score_correction_bias
,
num_expert_group
,
topk_group
,
top_k
,
routed_scaling_factor
=
routed_scaling_factor
,
n_share_experts_fusion
=
0
,
)
else
:
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
topk
=
top_k
,
renormalize
=
renormalize
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
elif
custom_routing_function
is
None
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
...
...
@@ -927,6 +961,8 @@ class FusedMoE(torch.nn.Module):
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
use_fused_gate
=
self
.
use_fused_gate
)
if
self
.
dp_size
>
1
:
...
...
vllm/model_executor/layers/quantization/blockwise_int8.py
View file @
bd363067
...
...
@@ -384,6 +384,8 @@ class BlockInt8MoEMethod:
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -399,7 +401,9 @@ class BlockInt8MoEMethod:
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
# Expert fusion with INT8 quantization
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
bd363067
...
...
@@ -31,7 +31,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
current_platform
from
vllm.utils
import
W8a8GetCacheJSON
import
os
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
__all__
=
[
"CompressedTensorsLinearMethod"
]
...
...
@@ -540,10 +543,32 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quantization_config
:
CompressedTensorsConfig
):
self
.
quantization_config
=
quantization_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
scheme
.
process_weights_after_loading
(
layer
)
n
=
layer
.
weight
.
shape
[
0
]
k
=
layer
.
weight
.
shape
[
1
]
if
self
.
w8a8_strategy
==
1
:
if
{
n
,
k
}
not
in
self
.
tritonsingleton
.
weight_shapes
:
self
.
tritonsingleton
.
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
self
.
tritonsingleton
.
triton_json_dict
.
update
(
configs_dict
)
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
else
:
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
layer
.
weight
.
data
=
_weight
layer
.
scheme
.
process_weights_after_loading
(
layer
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
bd363067
...
...
@@ -296,6 +296,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
...
...
@@ -309,7 +311,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
weight_bits
=
self
.
quant_config
.
weight_bits
has_zp
=
self
.
quant_config
.
has_zp
...
...
vllm/model_executor/layers/quantization/utils/int8_utils.py
View file @
bd363067
...
...
@@ -300,8 +300,8 @@ def _w8a8_block_int8_matmul(
GROUP_SIZE_M
:
tl
.
constexpr
,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization,
and
store the result in output tensor `C`.
product) on input tensors `A` and `B` with block-wise quantization,
and
store the result in output tensor `C`.
"""
pid
=
tl
.
program_id
(
axis
=
0
)
...
...
@@ -316,16 +316,29 @@ def _w8a8_block_int8_matmul(
offs_am
=
(
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
))
%
M
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
# offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_bsn
=
pid_n
*
BLOCK_SIZE_N
//
group_n
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
A
+
(
offs_am
[:,
None
]
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
b_ptrs
=
B
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
# a_ptrs = A + (offs_am[:, None] * stride_am)
# b_ptrs = B + (offs_bn[None, :] * stride_bn)
As_ptrs
=
As
+
offs_am
*
stride_As_m
offs_bsn
=
offs_bn
//
group_n
#
offs_bsn = offs_bn // group_n
Bs_ptrs
=
Bs
+
offs_bsn
*
stride_Bs_n
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_s
=
tl
.
load
(
As_ptrs
+
offs_ks
*
stride_As_k
)
b_s
=
tl
.
load
(
Bs_ptrs
+
offs_ks
*
stride_Bs_k
)
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
...
...
@@ -333,16 +346,13 @@ def _w8a8_block_int8_matmul(
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_s
=
tl
.
load
(
As_ptrs
+
offs_ks
*
stride_As_k
)
b_s
=
tl
.
load
(
Bs_ptrs
+
offs_ks
*
stride_Bs_k
)
accumulator
+=
tl
.
dot
(
a
,
b
).
to
(
tl
.
float32
)
*
a_s
[:,
None
]
*
b_s
[
None
,
:]
accumulator
+=
tl
.
dot
(
a
,
b
).
to
(
tl
.
float32
)
*
a_s
[:,
None
]
*
b_s
[
None
,
:]
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
C
.
dtype
.
element_ty
==
tl
.
bfloat16
:
c
=
accumulator
.
to
(
tl
.
bfloat16
)
elif
C
.
dtype
.
element_ty
==
tl
.
float16
:
...
...
@@ -435,30 +445,11 @@ def w8a8_block_int8_matmul(
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,
)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
# configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
# if configs:
# # If an optimal configuration map has been found, look up the
# # optimal config
# config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
# else:
# # Default config
# # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
# config = {
# "BLOCK_SIZE_M": 64,
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
#print("W8A8_TRITONJSON.triton_json_dict[0]:",W8A8_TRITONJSON.triton_json_dict[0])
if
len
(
W8A8_TRITONJSON
.
triton_json_dict
)
==
0
:
if
len
(
W8A8_TRITONJSON
.
triton_json_list
)
==
0
:
config
=
None
#print("len(W8A8_TRITONJSON.triton_json_dict)=0:",len(W8A8_TRITONJSON.triton_json_dict))
elif
f
"1_
{
N
}
_
{
K
}
_block[
{
block_n
}
,
{
block_k
}
]"
in
W8A8_TRITONJSON
.
triton_json_dict
[
0
]:
elif
f
"1_
{
N
}
_
{
K
}
_block[
{
block_n
}
,
{
block_k
}
]"
in
W8A8_TRITONJSON
.
triton_json_list
[
0
]:
if
M
<=
16
:
m_
=
M
elif
M
<=
64
:
...
...
@@ -480,12 +471,13 @@ def w8a8_block_int8_matmul(
m_
=
4096
else
:
m_
=
8192
#print("==================m:{},n:{},k:{}".format(M,N,K))
config
=
W8A8_TRITONJSON
.
triton_json_
dic
t
[
0
][
f
"
{
m_
}
_
{
N
}
_
{
K
}
_block[
{
block_n
}
,
{
block_k
}
]"
]
config
=
W8A8_TRITONJSON
.
triton_json_
lis
t
[
0
][
f
"
{
m_
}
_
{
N
}
_
{
K
}
_block[
{
block_n
}
,
{
block_k
}
]"
]
else
:
config
=
None
config
=
None
if
config
==
None
:
# print("m:{},n:{},k:{}".format(M,N,K))
# print("config not found!")
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
bd363067
...
...
@@ -420,7 +420,7 @@ def apply_int8_linear(
if
len
(
W8A8_TRITONJSON
.
triton_json_dict
)
==
0
:
best_config
=
None
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
[
0
]
:
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
if
m
<=
16
:
m_
=
m
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
...
...
@@ -444,7 +444,7 @@ def apply_int8_linear(
else
:
m_
=
8192
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
0
][
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
...
...
vllm/model_executor/layers/quantization/w8a8_int8.py
View file @
bd363067
...
...
@@ -252,6 +252,8 @@ class W8A8Int8MoEMethod:
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -266,7 +268,9 @@ class W8A8Int8MoEMethod:
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
return
fused_experts
(
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
bd363067
...
...
@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module):
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,)
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
routed_scaling_factor
=
self
.
routed_scaling_factor
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
...
...
@@ -961,7 +962,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
if
configs_dict
:
all_json
.
update
(
configs_dict
)
self
.
tritonsingleton
.
triton_json_
dic
t
.
append
(
all_json
)
self
.
tritonsingleton
.
triton_json_
lis
t
.
append
(
all_json
)
#print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0]))
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
...
...
vllm/model_executor/models/gpt_neox.py
View file @
bd363067
...
...
@@ -46,7 +46,6 @@ from .interfaces import SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
from
vllm.utils
import
is_hip
,
W8a8GetCacheJSON
from
vllm
import
_custom_ops
as
ops
class
GPTNeoXAttention
(
nn
.
Module
):
...
...
@@ -219,12 +218,10 @@ class GPTNeoXModel(nn.Module):
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_in
(
input_ids
)
...
...
@@ -287,53 +284,7 @@ class GPTNeoXModel(nn.Module):
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
#当为triton支持推理的时候不能进行处理
if
self
.
quant_method
==
"compressed_tensors"
:
os
.
environ
[
'LM_NN'
]
=
'0'
lay_key_words
=
[
"attention.query_key_value.weight"
,
"attention.dense.weight"
,
"mlp.dense_h_to_4h.weight"
,
"mlp.dense_4h_to_h.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
matched_key_words
=
set
()
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
n
=
weight_data
.
shape
[
0
]
k
=
weight_data
.
shape
[
1
]
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
matched_key_words
)
<
4
and
matches
[
0
]
not
in
matched_key_words
:
matched_key_words
.
add
(
matches
[
0
])
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
return
loaded_params
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment