Commit 38d80967 authored by zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
...@@ -594,7 +594,11 @@ def main(args: argparse.Namespace): ...@@ -594,7 +594,11 @@ def main(args: argparse.Namespace):
E = config.n_routed_experts E = config.n_routed_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"): elif config.architectures[0] in (
"Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM",
"Qwen3NextForCausalLM",
):
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
...@@ -678,7 +682,11 @@ def main(args: argparse.Namespace): ...@@ -678,7 +682,11 @@ def main(args: argparse.Namespace):
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = get_configs_compute_bound(is_fp16, block_quant_shape) search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
print(f"Start tuning over {len(search_space)} configurations...") print(f"Start tuning over {len(search_space)} configurations...")
if use_deep_gemm:
raise ValueError(
"Tuning with --use-deep-gemm is not supported as it only tunes Triton "
"kernels. Please remove the flag."
)
start = time.time() start = time.time()
configs = _distribute( configs = _distribute(
"tune", "tune",
......
This diff is collapsed.
...@@ -259,6 +259,7 @@ if __name__ == "__main__": ...@@ -259,6 +259,7 @@ if __name__ == "__main__":
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype) # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None), (None, None, None),
(None, FP8_DTYPE, None), (None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
] ]
......
...@@ -274,6 +274,7 @@ if __name__ == "__main__": ...@@ -274,6 +274,7 @@ if __name__ == "__main__":
quant_dtypes = [ quant_dtypes = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype) # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None), (None, None, None),
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
] ]
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
#pragma once #pragma once
#include "cutlass_extensions/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp"
namespace cutlass::gemm::collective { namespace cutlass::gemm::collective {
using namespace cute; using namespace cute;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment