Commit de889cb6 authored by zhuwenwen
Browse files

sync v0.15.1

parent c721b814
...@@ -197,7 +197,7 @@ def bench_run( ...@@ -197,7 +197,7 @@ def bench_run(
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4( CutlassExpertsFp4(
make_dummy_moe_config(), make_dummy_moe_config(),
quant_config=quant_config, quant_config=quant_config,
...@@ -242,7 +242,7 @@ def bench_run( ...@@ -242,7 +242,7 @@ def bench_run(
) )
kernel = mk.FusedMoEModularKernel( kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
CutlassExpertsFp4( CutlassExpertsFp4(
make_dummy_moe_config(), make_dummy_moe_config(),
quant_config=quant_config, quant_config=quant_config,
...@@ -520,4 +520,4 @@ if __name__ == "__main__": ...@@ -520,4 +520,4 @@ if __name__ == "__main__":
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
\ No newline at end of file
...@@ -10,6 +10,8 @@ from transformers import AutoConfig ...@@ -10,6 +10,8 @@ from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute,
_moe_unpermute_and_reduce,
moe_permute, moe_permute,
moe_unpermute, moe_unpermute,
) )
...@@ -39,6 +41,7 @@ def benchmark_permute( ...@@ -39,6 +41,7 @@ def benchmark_permute(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
use_customized_permute: bool = False,
) -> float: ) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
...@@ -61,14 +64,29 @@ def benchmark_permute( ...@@ -61,14 +64,29 @@ def benchmark_permute(
input_gating.copy_(gating_output[i]) input_gating.copy_(gating_output[i])
def run(): def run():
moe_permute( if use_customized_permute:
qhidden_states, (
a1q_scale=None, permuted_hidden_states,
topk_ids=topk_ids, a1q_scale,
n_expert=num_experts, first_token_off,
expert_map=None, inv_perm_idx,
align_block_size=align_block_size, m_indices,
) ) = moe_permute(
qhidden_states,
a1q_scale=None,
topk_ids=topk_ids,
n_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
else:
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(qhidden_states, None, topk_ids, num_experts, None, 16)
# JIT compilation & warmup # JIT compilation & warmup
run() run()
...@@ -113,9 +131,11 @@ def benchmark_unpermute( ...@@ -113,9 +131,11 @@ def benchmark_unpermute(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
use_customized_permute: bool = False,
) -> float: ) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
output_hidden_states = torch.empty_like(hidden_states)
if use_fp8_w8a8: if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None) qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
...@@ -130,37 +150,78 @@ def benchmark_unpermute( ...@@ -130,37 +150,78 @@ def benchmark_unpermute(
) )
def prepare(): def prepare():
( if use_customized_permute:
permuted_hidden_states, (
_, permuted_hidden_states,
first_token_off, a1q_scale,
inv_perm_idx, first_token_off,
_, inv_perm_idx,
) = moe_permute( m_indices,
qhidden_states, ) = moe_permute(
a1q_scale=None, qhidden_states,
topk_ids=topk_ids, a1q_scale=None,
n_expert=num_experts, topk_ids=topk_ids,
expert_map=None, n_expert=num_experts,
align_block_size=align_block_size, expert_map=None,
) align_block_size=align_block_size,
# convert to fp16/bf16 as gemm output )
return ( # convert to fp16/bf16 as gemm output
permuted_hidden_states.to(dtype), return (
first_token_off, permuted_hidden_states.to(dtype),
inv_perm_idx, first_token_off,
) inv_perm_idx,
m_indices,
)
else:
(
permuted_qhidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = _moe_permute(
qhidden_states, None, topk_ids, num_experts, None, block_m=16
)
# convert to fp16/bf16 as gemm output
return (
permuted_qhidden_states.to(dtype),
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
)
def run(input: tuple): def run(input: tuple):
(permuted_hidden_states, first_token_off, inv_perm_idx) = input if use_customized_permute:
output = torch.empty_like(hidden_states) (
moe_unpermute( permuted_hidden_states,
output, first_token_off,
permuted_hidden_states, inv_perm_idx,
topk_weights, m_indices,
inv_perm_idx, ) = input
first_token_off, output = torch.empty_like(hidden_states)
) moe_unpermute(
output,
permuted_hidden_states,
topk_weights,
inv_perm_idx,
first_token_off,
)
else:
(
permuted_hidden_states,
a1q_scale,
sorted_token_ids,
expert_ids,
inv_perm,
) = input
_moe_unpermute_and_reduce(
output_hidden_states,
permuted_hidden_states,
inv_perm,
topk_weights,
True,
)
# JIT compilation & warmup # JIT compilation & warmup
input = prepare() input = prepare()
...@@ -215,7 +276,8 @@ class BenchmarkWorker: ...@@ -215,7 +276,8 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
) -> tuple[float, float]: use_customized_permute: bool = False,
) -> tuple[dict[str, int], float]:
set_random_seed(self.seed) set_random_seed(self.seed)
permute_time = benchmark_permute( permute_time = benchmark_permute(
...@@ -227,6 +289,7 @@ class BenchmarkWorker: ...@@ -227,6 +289,7 @@ class BenchmarkWorker:
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=100, num_iters=100,
use_customized_permute=use_customized_permute,
) )
unpermute_time = benchmark_unpermute( unpermute_time = benchmark_unpermute(
num_tokens, num_tokens,
...@@ -237,6 +300,7 @@ class BenchmarkWorker: ...@@ -237,6 +300,7 @@ class BenchmarkWorker:
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=100, num_iters=100,
use_customized_permute=use_customized_permute,
) )
return permute_time, unpermute_time return permute_time, unpermute_time
...@@ -283,6 +347,7 @@ def main(args: argparse.Namespace): ...@@ -283,6 +347,7 @@ def main(args: argparse.Namespace):
dtype = torch.float16 if current_platform.is_rocm() else config.dtype dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute
if args.batch_size is None: if args.batch_size is None:
batch_sizes = [ batch_sizes = [
...@@ -334,6 +399,7 @@ def main(args: argparse.Namespace): ...@@ -334,6 +399,7 @@ def main(args: argparse.Namespace):
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_customized_permute,
) )
for batch_size in batch_sizes for batch_size in batch_sizes
], ],
...@@ -353,9 +419,10 @@ if __name__ == "__main__": ...@@ -353,9 +419,10 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
) )
parser.add_argument("--use-customized-permute", action="store_true")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--trust-remote-code", action="store_true") parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
\ No newline at end of file
...@@ -607,7 +607,6 @@ def make_modular_kernel( ...@@ -607,7 +607,6 @@ def make_modular_kernel(
prepare_finalize = make_prepare_finalize( prepare_finalize = make_prepare_finalize(
config.prepare_finalize_type, config.all2all_backend(), moe, quant_config config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
) )
assert prepare_finalize is not None
fused_experts = make_fused_experts( fused_experts = make_fused_experts(
config.fused_experts_type, config.fused_experts_type,
......
...@@ -445,7 +445,6 @@ def make_prepare_finalize( ...@@ -445,7 +445,6 @@ def make_prepare_finalize(
) )
else: else:
return MoEPrepareAndFinalizeNoEP() return MoEPrepareAndFinalizeNoEP()
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts s = rank * num_local_experts
......
...@@ -20,6 +20,7 @@ def test_bert_models( ...@@ -20,6 +20,7 @@ def test_bert_models(
model: str, model: str,
dtype: str, dtype: str,
) -> None: ) -> None:
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.token_classify(example_prompts) vllm_outputs = vllm_model.token_classify(example_prompts)
......
...@@ -573,6 +573,21 @@ VLM_TEST_SETTINGS = { ...@@ -573,6 +573,21 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc=model_utils.kimiv_vl_vllm_to_hf_output, vllm_output_post_proc=model_utils.kimiv_vl_vllm_to_hf_output,
marks=[large_gpu_mark(min_gb=48)], marks=[large_gpu_mark(min_gb=48)],
), ),
"llama4": VLMTestInfo(
models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501
img_idx_to_prompt=lambda _: "<|image|>",
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
distributed_executor_backend="mp",
image_size_factors=[(0.25, 0.5, 1.0)],
hf_model_kwargs={"device_map": "auto"},
max_model_len=8192,
max_num_seqs=4,
dtype="bfloat16",
auto_cls=AutoModelForImageTextToText,
tensor_parallel_size=4,
marks=multi_gpu_marks(num_gpus=4),
),
"llava_next": VLMTestInfo( "llava_next": VLMTestInfo(
models=["llava-hf/llava-v1.6-mistral-7b-hf"], models=["llava-hf/llava-v1.6-mistral-7b-hf"],
test_type=(VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS), test_type=(VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS),
...@@ -954,22 +969,6 @@ VLM_TEST_SETTINGS = { ...@@ -954,22 +969,6 @@ VLM_TEST_SETTINGS = {
) )
], ],
), ),
"llama4": VLMTestInfo(
models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501
img_idx_to_prompt=lambda _: "<|image|>",
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
distributed_executor_backend="mp",
image_size_factors=[(.25, 0.5, 1.0)],
hf_model_kwargs={"device_map": "auto"},
max_model_len=8192,
max_num_seqs=4,
dtype="bfloat16",
auto_cls=AutoModelForImageTextToText,
tensor_parallel_size=8,
vllm_runner_kwargs={"gpu_memory_utilization": 0.8},
marks=[large_gpu_mark(min_gb=80), multi_gpu_marks(num_gpus=8)],
),
} }
...@@ -1322,4 +1321,4 @@ def test_custom_inputs_models_heavy( ...@@ -1322,4 +1321,4 @@ def test_custom_inputs_models_heavy(
test_case=test_case, test_case=test_case,
hf_runner=hf_runner, hf_runner=hf_runner,
vllm_runner=vllm_runner, vllm_runner=vllm_runner,
) )
\ No newline at end of file
...@@ -1061,7 +1061,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -1061,7 +1061,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeLiteMTPModel": _HfExamplesInfo( "Glm4MoeLiteMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash", "zai-org/GLM-4.7-Flash",
speculative_model="zai-org/GLM-4.7-Flash", speculative_model="zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0", is_available_online=False,
), ),
"LongCatFlashMTPModel": _HfExamplesInfo( "LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat", "meituan-longcat/LongCat-Flash-Chat",
...@@ -1165,4 +1165,4 @@ class HfExampleModels: ...@@ -1165,4 +1165,4 @@ class HfExampleModels:
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS) AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS)
\ No newline at end of file
...@@ -88,7 +88,6 @@ def can_initialize( ...@@ -88,7 +88,6 @@ def can_initialize(
[10 * GiB_bytes], [10 * GiB_bytes],
) )
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config return 1, 0, scheduler_kv_cache_config
......
...@@ -2866,7 +2866,7 @@ def onednn_mm( ...@@ -2866,7 +2866,7 @@ def onednn_mm(
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
torch.ops._C.onednn_mm( torch.ops._C.onednn_mm(
output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler_tensor output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler
) )
return output return output
......
...@@ -130,20 +130,19 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -130,20 +130,19 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]: ) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src) return self.dist_module.recv_tensor_dict(src)
def dispatch( def dispatch( # type: ignore[override]
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors, # type: ignore[call-arg]
) )
def combine( def combine(
...@@ -251,4 +250,4 @@ class _CPUSHMDistributed: ...@@ -251,4 +250,4 @@ class _CPUSHMDistributed:
tensor_dict: dict[str, torch.Tensor] = {} tensor_dict: dict[str, torch.Tensor] = {}
for key, size, t in zip(key_list, size_list, value_list): for key, size, t in zip(key_list, size_list, value_list):
tensor_dict[key] = t.view(size) tensor_dict[key] = t.view(size)
return tensor_dict return tensor_dict
\ No newline at end of file
...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list return output_list
def dispatch( def dispatch( # type: ignore[override]
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -332,7 +332,6 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -332,7 +332,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor] tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch(
hidden_states, hidden_states,
...@@ -348,4 +347,4 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -348,4 +347,4 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states = self.all2all_manager.combine( hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states, is_sequence_parallel
) )
return hidden_states return hidden_states
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.distributed as dist import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend from flashinfer.comm.mnnvl import CommBackend as CommBackend
...@@ -23,3 +22,6 @@ class CustomCommunicator(CommBackend): ...@@ -23,3 +22,6 @@ class CustomCommunicator(CommBackend):
gathered = [None] * self.Get_size() gathered = [None] * self.Get_size()
dist.all_gather_object(gathered, data, group=self._group) dist.all_gather_object(gathered, data, group=self._group)
return gathered return gathered
def Split(self, color: int, key: int) -> "CustomCommunicator":
return self
\ No newline at end of file
...@@ -930,8 +930,8 @@ async def run_server_worker( ...@@ -930,8 +930,8 @@ async def run_server_worker(
if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)
# Get uvicorn log config (from file or with endpoint filter) # Load logging config for uvicorn if specified
log_config = get_uvicorn_log_config(args) log_config = load_log_config(args.log_config_file)
if log_config is not None: if log_config is not None:
uvicorn_kwargs["log_config"] = log_config uvicorn_kwargs["log_config"] = log_config
...@@ -988,4 +988,4 @@ if __name__ == "__main__": ...@@ -988,4 +988,4 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
validate_parsed_serve_args(args) validate_parsed_serve_args(args)
uvloop.run(run_server(args)) uvloop.run(run_server(args))
\ No newline at end of file
...@@ -36,7 +36,6 @@ from vllm.entrypoints.renderer import RenderConfig ...@@ -36,7 +36,6 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -744,4 +743,4 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -744,4 +743,4 @@ class OpenAIServingCompletion(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt, cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo and not request.return_token_ids), needs_detokenization=bool(request.echo and not request.return_token_ids),
) )
\ No newline at end of file
...@@ -881,7 +881,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -881,7 +881,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Deprecated, see profiler_config. # Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"), "VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"),
# Enable torch profiler to record shapes if set to 1. # Enable torch profiler to record shapes if set to 1.
# Deprecated, see profiler_config. # Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: ( "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: (
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES") os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES")
), ),
......
...@@ -176,9 +176,7 @@ def _fused_moe_lora_kernel( ...@@ -176,9 +176,7 @@ def _fused_moe_lora_kernel(
# GDC wait waits for ALL programs in the prior kernel to complete # GDC wait waits for ALL programs in the prior kernel to complete
# before continuing. # before continuing.
# pre-fetch lora weight # pre-fetch lora weight
# add (offs_bn < N) mask; optional .ca for B
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
if USE_GDC and not IS_PRIMARY: if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait() tl.extra.cuda.gdc_wait()
a = tl.load( a = tl.load(
...@@ -683,4 +681,4 @@ try: ...@@ -683,4 +681,4 @@ try:
except AttributeError: except AttributeError:
fused_moe_lora = _fused_moe_lora fused_moe_lora = _fused_moe_lora
fused_moe_lora_shrink = _fused_moe_lora_shrink fused_moe_lora_shrink = _fused_moe_lora_shrink
fused_moe_lora_expand = _fused_moe_lora_expand fused_moe_lora_expand = _fused_moe_lora_expand
\ No newline at end of file
...@@ -438,7 +438,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): ...@@ -438,7 +438,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,
input_ids: torch.Tenso, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -316,11 +316,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -316,11 +316,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
# Determine split axis based on op type # Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0 # gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1 # down: RowParallel → split along dim 1
split_dim = ( split_dim = 1 if "down_proj.weight" in name else 0
1
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
else 0
)
total = loaded_weight.shape[split_dim] total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, ( assert total % num_chunks == 0, (
f"Shared expert weight dim {total} " f"Shared expert weight dim {total} "
...@@ -448,4 +444,4 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -448,4 +444,4 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
elif shared_weight: elif shared_weight:
# treat shared weights as top level weights # treat shared weights as top level weights
name = name.replace(f"model.layers.{spec_layer}.", "model.") name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name return name
\ No newline at end of file
...@@ -1375,7 +1375,7 @@ class DeepseekV2ForCausalLM( ...@@ -1375,7 +1375,7 @@ class DeepseekV2ForCausalLM(
break break
else: else:
is_expert_weight = False is_expert_weight = False
# Special handling: when AITER fusion_shared_experts is enabled, # Special handling: when AITER fusion_shared_experts is enabled,
# checkpoints may provide a single widened shared_experts tensor # checkpoints may provide a single widened shared_experts tensor
# without explicit expert indices # without explicit expert indices
...@@ -1487,6 +1487,7 @@ class DeepseekV2ForCausalLM( ...@@ -1487,6 +1487,7 @@ class DeepseekV2ForCausalLM(
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer: if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
...@@ -1511,4 +1512,4 @@ def get_spec_layer_idx_from_weight_name( ...@@ -1511,4 +1512,4 @@ def get_spec_layer_idx_from_weight_name(
for i in range(config.num_nextn_predict_layers): for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx + i}."): if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i return layer_idx + i
return None return None
\ No newline at end of file
...@@ -459,7 +459,7 @@ class FalconH1Model(nn.Module): ...@@ -459,7 +459,7 @@ class FalconH1Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -678,4 +678,4 @@ class FalconH1ForCausalLM( ...@@ -678,4 +678,4 @@ class FalconH1ForCausalLM(
if self.tie_word_embeddings: if self.tie_word_embeddings:
loaded_params.add("lm_head.weight") loaded_params.add("lm_head.weight")
return loaded_params return loaded_params
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment